Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add combining state support #22826

Merged
merged 2 commits into from
Aug 24, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions sdks/go/pkg/beam/core/graph/fn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1274,10 +1274,10 @@ func validateState(fn *DoFn, numIn mainInputs) error {
"unique per DoFn", k, orig, s)
}
t := s.StateType()
if t != state.StateTypeValue && t != state.StateTypeBag {
if t != state.StateTypeValue && t != state.StateTypeBag && t != state.StateTypeCombining {
err := errors.Errorf("Unrecognized state type %v for state %v", t, s)
return errors.SetTopLevelMsgf(err, "Unrecognized state type %v for state %v. Currently the only supported state"+
"type is state.Value and state.Bag", t, s)
"types are state.Value, state.Combining, and state.Bag", t, s)
}
stateKeys[k] = s
}
Expand Down
11 changes: 11 additions & 0 deletions sdks/go/pkg/beam/core/graph/fn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ func TestNewDoFn(t *testing.T) {
{dfn: &GoodDoFnCoGbk1wSide{}, opt: NumMainInputs(MainKv)},
{dfn: &GoodStatefulDoFn{State1: state.MakeValueState[int]("state1")}, opt: NumMainInputs(MainKv)},
{dfn: &GoodStatefulDoFn2{State1: state.MakeBagState[int]("state1")}, opt: NumMainInputs(MainKv)},
{dfn: &GoodStatefulDoFn3{State1: state.MakeCombiningState[int, int, int]("state1", func(a, b int) int {
return a * b
})}, opt: NumMainInputs(MainKv)},
}

for _, test := range tests {
Expand Down Expand Up @@ -1096,6 +1099,14 @@ func (fn *GoodStatefulDoFn2) ProcessElement(state.Provider, int, int) int {
return 0
}

type GoodStatefulDoFn3 struct {
State1 state.Combining[int, int, int]
}

func (fn *GoodStatefulDoFn3) ProcessElement(state.Provider, int, int) int {
return 0
}

// Examples of incorrect SDF signatures.
// Examples with missing methods.

Expand Down
20 changes: 18 additions & 2 deletions sdks/go/pkg/beam/core/runtime/exec/translate.go
Original file line number Diff line number Diff line change
Expand Up @@ -467,13 +467,29 @@ func (b *builder) makeLink(from string, id linkID) (Node, error) {

if len(userState) > 0 {
stateIDToCoder := make(map[string]*coder.Coder)
stateIDToCombineFn := make(map[string]*graph.CombineFn)
for key, spec := range userState {
// TODO(#22736) - this will eventually need to be aware of which type of state its modifying to support non-Value state types.
var cID string
if rmw := spec.GetReadModifyWriteSpec(); rmw != nil {
cID = rmw.CoderId
} else if bs := spec.GetBagSpec(); bs != nil {
cID = bs.ElementCoderId
} else if cs := spec.GetCombiningSpec(); cs != nil {
cID = cs.AccumulatorCoderId
cmbData := string(cs.GetCombineFn().GetPayload())
var cmbTp v1pb.TransformPayload
if err := protox.DecodeBase64(cmbData, &cmbTp); err != nil {
return nil, errors.Wrapf(err, "invalid transform payload %v for %v", cmbData, transform)
}
_, fn, _, _, _, err := graphx.DecodeMultiEdge(cmbTp.GetEdge())
if err != nil {
return nil, err
}
cfn, err := graph.AsCombineFn(fn)
if err != nil {
return nil, err
}
stateIDToCombineFn[key] = cfn
}
c, err := b.coders.Coder(cID)
if err != nil {
Expand All @@ -489,7 +505,7 @@ func (b *builder) makeLink(from string, id linkID) (Node, error) {
if err != nil {
return nil, err
}
n.UState = NewUserStateAdapter(sid, coder.NewW(ec, wc), stateIDToCoder)
n.UState = NewUserStateAdapter(sid, coder.NewW(ec, wc), stateIDToCoder, stateIDToCombineFn)
}
}

Expand Down
53 changes: 46 additions & 7 deletions sdks/go/pkg/beam/core/runtime/exec/userstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ import (
"fmt"
"io"

"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/state"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
)

type stateProvider struct {
Expand All @@ -39,6 +41,7 @@ type stateProvider struct {
appendersByKey map[string]io.Writer
clearersByKey map[string]io.Writer
codersByKey map[string]*coder.Coder
combineFnsByKey map[string]*graph.CombineFn
}

// ReadValueState reads a value state from the State API
Expand Down Expand Up @@ -159,6 +162,40 @@ func (s *stateProvider) WriteBagState(val state.Transaction) error {
return nil
}

func (s *stateProvider) CreateAccumulatorFn(userStateID string) reflectx.Func {
a := s.combineFnsByKey[userStateID]
if ca := a.CreateAccumulatorFn(); ca != nil {
return ca.Fn
}
return nil
}

func (s *stateProvider) AddInputFn(userStateID string) reflectx.Func {
a := s.combineFnsByKey[userStateID]
if ai := a.AddInputFn(); ai != nil {
return ai.Fn
}

return nil
}

func (s *stateProvider) MergeAccumulatorsFn(userStateID string) reflectx.Func {
a := s.combineFnsByKey[userStateID]
if ma := a.MergeAccumulatorsFn(); ma != nil {
return ma.Fn
}

return nil
}

func (s *stateProvider) ExtractOutputFn(userStateID string) reflectx.Func {
a := s.combineFnsByKey[userStateID]
if eo := a.ExtractOutputFn(); eo != nil {
return eo.Fn
}
return nil
}

func (s *stateProvider) getReader(userStateID string) (io.ReadCloser, error) {
if r, ok := s.readersByKey[userStateID]; ok {
return r, nil
Expand Down Expand Up @@ -203,16 +240,17 @@ type UserStateAdapter interface {
}

type userStateAdapter struct {
sid StreamID
wc WindowEncoder
kc ElementEncoder
stateIDToCoder map[string]*coder.Coder
c *coder.Coder
sid StreamID
wc WindowEncoder
kc ElementEncoder
stateIDToCoder map[string]*coder.Coder
stateIDToCombineFn map[string]*graph.CombineFn
c *coder.Coder
}

// NewUserStateAdapter returns a user state adapter for the given StreamID and coder.
// It expects a W<V> or W<KV<K,V>> coder, because the protocol requires windowing information.
func NewUserStateAdapter(sid StreamID, c *coder.Coder, stateIDToCoder map[string]*coder.Coder) UserStateAdapter {
func NewUserStateAdapter(sid StreamID, c *coder.Coder, stateIDToCoder map[string]*coder.Coder, stateIDToCombineFn map[string]*graph.CombineFn) UserStateAdapter {
if !coder.IsW(c) {
panic(fmt.Sprintf("expected WV coder for user state %v: %v", sid, c))
}
Expand All @@ -222,7 +260,7 @@ func NewUserStateAdapter(sid StreamID, c *coder.Coder, stateIDToCoder map[string
if coder.IsKV(coder.SkipW(c)) {
kc = MakeElementEncoder(coder.SkipW(c).Components[0])
}
return &userStateAdapter{sid: sid, wc: wc, kc: kc, c: c, stateIDToCoder: stateIDToCoder}
return &userStateAdapter{sid: sid, wc: wc, kc: kc, c: c, stateIDToCoder: stateIDToCoder, stateIDToCombineFn: stateIDToCombineFn}
}

// NewStateProvider creates a stateProvider with the ability to talk to the state API.
Expand Down Expand Up @@ -251,6 +289,7 @@ func (s *userStateAdapter) NewStateProvider(ctx context.Context, reader StateRea
readersByKey: make(map[string]io.ReadCloser),
appendersByKey: make(map[string]io.Writer),
clearersByKey: make(map[string]io.Writer),
combineFnsByKey: s.stateIDToCombineFn,
codersByKey: s.stateIDToCoder,
}

Expand Down
4 changes: 3 additions & 1 deletion sdks/go/pkg/beam/core/runtime/exec/userstate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"reflect"
"testing"

"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/coderx"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/state"
Expand Down Expand Up @@ -86,7 +87,8 @@ func buildStateProvider() stateProvider {
readersByKey: make(map[string]io.ReadCloser),
appendersByKey: make(map[string]io.Writer),
clearersByKey: make(map[string]io.Writer),
codersByKey: make(map[string]*coder.Coder), // Each test can specify coders as needed
combineFnsByKey: make(map[string]*graph.CombineFn), // Each test can specify coders as needed
codersByKey: make(map[string]*coder.Coder), // Each test can specify coders as needed
}
}

Expand Down
32 changes: 32 additions & 0 deletions sdks/go/pkg/beam/core/runtime/graphx/translate.go
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,38 @@ func (m *marshaller) addMultiEdge(edge NamedEdge) ([]string, error) {
Urn: URNBagUserState,
},
}
case state.StateTypeCombining:
cps := ps.(state.CombiningPipelineState).GetCombineFn()
f, err := graph.NewFn(cps)
if err != nil {
return handleErr(err)
}
cf, err := graph.AsCombineFn(f)
if err != nil {
return handleErr(err)
}
me := graph.MultiEdge{
Op: graph.Combine,
CombineFn: cf,
}
mustEncodeMultiEdge, err := mustEncodeMultiEdgeBase64(&me)
if err != nil {
return handleErr(err)
}
stateSpecs[ps.StateKey()] = &pipepb.StateSpec{
Spec: &pipepb.StateSpec_CombiningSpec{
CombiningSpec: &pipepb.CombiningStateSpec{
AccumulatorCoderId: coderID,
CombineFn: &pipepb.FunctionSpec{
Urn: "beam:combinefn:gosdk:v1",
Payload: []byte(mustEncodeMultiEdge),
},
},
},
Protocol: &pipepb.FunctionSpec{
Urn: URNBagUserState,
},
}
default:
return nil, errors.Errorf("State type %v not recognized for state %v", ps.StateKey(), ps)
}
Expand Down
Loading