Skip to content

Commit

Permalink
Add combining state support (apache#22826)
Browse files Browse the repository at this point in the history
  • Loading branch information
damccorm authored and Kanishk Karanawat committed Sep 29, 2022
1 parent 2bbf8c1 commit 2d9f807
Show file tree
Hide file tree
Showing 8 changed files with 485 additions and 22 deletions.
4 changes: 2 additions & 2 deletions sdks/go/pkg/beam/core/graph/fn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1274,10 +1274,10 @@ func validateState(fn *DoFn, numIn mainInputs) error {
"unique per DoFn", k, orig, s)
}
t := s.StateType()
if t != state.TypeValue && t != state.TypeBag {
if t != state.TypeValue && t != state.TypeBag && t != state.TypeCombining {
err := errors.Errorf("Unrecognized state type %v for state %v", t, s)
return errors.SetTopLevelMsgf(err, "Unrecognized state type %v for state %v. Currently the only supported state"+
"type is state.Value and state.Bag", t, s)
"types are state.Value, state.Combining, and state.Bag", t, s)
}
stateKeys[k] = s
}
Expand Down
11 changes: 11 additions & 0 deletions sdks/go/pkg/beam/core/graph/fn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ func TestNewDoFn(t *testing.T) {
{dfn: &GoodDoFnCoGbk1wSide{}, opt: NumMainInputs(MainKv)},
{dfn: &GoodStatefulDoFn{State1: state.MakeValueState[int]("state1")}, opt: NumMainInputs(MainKv)},
{dfn: &GoodStatefulDoFn2{State1: state.MakeBagState[int]("state1")}, opt: NumMainInputs(MainKv)},
{dfn: &GoodStatefulDoFn3{State1: state.MakeCombiningState[int, int, int]("state1", func(a, b int) int {
return a * b
})}, opt: NumMainInputs(MainKv)},
}

for _, test := range tests {
Expand Down Expand Up @@ -1096,6 +1099,14 @@ func (fn *GoodStatefulDoFn2) ProcessElement(state.Provider, int, int) int {
return 0
}

// GoodStatefulDoFn3 is a test DoFn fixture whose only field is a combining
// state over ints (state.Combining[int, int, int]). TestNewDoFn constructs it
// with state.MakeCombiningState and the NumMainInputs(MainKv) option.
type GoodStatefulDoFn3 struct {
// State1 is the combining state declared on this DoFn.
State1 state.Combining[int, int, int]
}

// ProcessElement accepts a state.Provider followed by two ints, ignores all of
// its arguments, and always returns 0. Only the signature is of interest here;
// the body does no work.
func (fn *GoodStatefulDoFn3) ProcessElement(state.Provider, int, int) int {
return 0
}

// Examples of incorrect SDF signatures.
// Examples with missing methods.

Expand Down
20 changes: 18 additions & 2 deletions sdks/go/pkg/beam/core/runtime/exec/translate.go
Original file line number Diff line number Diff line change
Expand Up @@ -467,13 +467,29 @@ func (b *builder) makeLink(from string, id linkID) (Node, error) {

if len(userState) > 0 {
stateIDToCoder := make(map[string]*coder.Coder)
stateIDToCombineFn := make(map[string]*graph.CombineFn)
for key, spec := range userState {
// TODO(#22736) - this will eventually need to be aware of which type of state it's modifying to support non-Value state types.
var cID string
if rmw := spec.GetReadModifyWriteSpec(); rmw != nil {
cID = rmw.CoderId
} else if bs := spec.GetBagSpec(); bs != nil {
cID = bs.ElementCoderId
} else if cs := spec.GetCombiningSpec(); cs != nil {
cID = cs.AccumulatorCoderId
cmbData := string(cs.GetCombineFn().GetPayload())
var cmbTp v1pb.TransformPayload
if err := protox.DecodeBase64(cmbData, &cmbTp); err != nil {
return nil, errors.Wrapf(err, "invalid transform payload %v for %v", cmbData, transform)
}
_, fn, _, _, _, err := graphx.DecodeMultiEdge(cmbTp.GetEdge())
if err != nil {
return nil, err
}
cfn, err := graph.AsCombineFn(fn)
if err != nil {
return nil, err
}
stateIDToCombineFn[key] = cfn
}
c, err := b.coders.Coder(cID)
if err != nil {
Expand All @@ -489,7 +505,7 @@ func (b *builder) makeLink(from string, id linkID) (Node, error) {
if err != nil {
return nil, err
}
n.UState = NewUserStateAdapter(sid, coder.NewW(ec, wc), stateIDToCoder)
n.UState = NewUserStateAdapter(sid, coder.NewW(ec, wc), stateIDToCoder, stateIDToCombineFn)
}
}

Expand Down
53 changes: 46 additions & 7 deletions sdks/go/pkg/beam/core/runtime/exec/userstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@ import (
"fmt"
"io"

"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/state"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
)

type stateProvider struct {
Expand All @@ -39,6 +41,7 @@ type stateProvider struct {
appendersByKey map[string]io.Writer
clearersByKey map[string]io.Writer
codersByKey map[string]*coder.Coder
combineFnsByKey map[string]*graph.CombineFn
}

// ReadValueState reads a value state from the State API
Expand Down Expand Up @@ -159,6 +162,40 @@ func (s *stateProvider) WriteBagState(val state.Transaction) error {
return nil
}

// CreateAccumulatorFn returns the create-accumulator function of the CombineFn
// stored under userStateID in combineFnsByKey.
//
// It returns nil when no CombineFn is stored for userStateID, or when the
// stored CombineFn has no create-accumulator function. A map miss yields a nil
// *graph.CombineFn, so that case is checked explicitly instead of calling a
// method through the nil pointer.
func (s *stateProvider) CreateAccumulatorFn(userStateID string) reflectx.Func {
	cfn := s.combineFnsByKey[userStateID]
	if cfn == nil {
		return nil
	}
	if ca := cfn.CreateAccumulatorFn(); ca != nil {
		return ca.Fn
	}
	// Return a literal nil so callers comparing the interface against nil see "absent".
	return nil
}

// AddInputFn returns the add-input function of the CombineFn stored under
// userStateID in combineFnsByKey.
//
// It returns nil when no CombineFn is stored for userStateID, or when the
// stored CombineFn has no add-input function. A map miss yields a nil
// *graph.CombineFn, so that case is checked explicitly instead of calling a
// method through the nil pointer.
func (s *stateProvider) AddInputFn(userStateID string) reflectx.Func {
	cfn := s.combineFnsByKey[userStateID]
	if cfn == nil {
		return nil
	}
	if ai := cfn.AddInputFn(); ai != nil {
		return ai.Fn
	}
	// Return a literal nil so callers comparing the interface against nil see "absent".
	return nil
}

// MergeAccumulatorsFn returns the merge-accumulators function of the CombineFn
// stored under userStateID in combineFnsByKey.
//
// It returns nil when no CombineFn is stored for userStateID, or when the
// stored CombineFn has no merge-accumulators function. A map miss yields a nil
// *graph.CombineFn, so that case is checked explicitly instead of calling a
// method through the nil pointer.
func (s *stateProvider) MergeAccumulatorsFn(userStateID string) reflectx.Func {
	cfn := s.combineFnsByKey[userStateID]
	if cfn == nil {
		return nil
	}
	if ma := cfn.MergeAccumulatorsFn(); ma != nil {
		return ma.Fn
	}
	// Return a literal nil so callers comparing the interface against nil see "absent".
	return nil
}

// ExtractOutputFn returns the extract-output function of the CombineFn stored
// under userStateID in combineFnsByKey.
//
// It returns nil when no CombineFn is stored for userStateID, or when the
// stored CombineFn has no extract-output function. A map miss yields a nil
// *graph.CombineFn, so that case is checked explicitly instead of calling a
// method through the nil pointer.
func (s *stateProvider) ExtractOutputFn(userStateID string) reflectx.Func {
	cfn := s.combineFnsByKey[userStateID]
	if cfn == nil {
		return nil
	}
	if eo := cfn.ExtractOutputFn(); eo != nil {
		return eo.Fn
	}
	// Return a literal nil so callers comparing the interface against nil see "absent".
	return nil
}

func (s *stateProvider) getReader(userStateID string) (io.ReadCloser, error) {
if r, ok := s.readersByKey[userStateID]; ok {
return r, nil
Expand Down Expand Up @@ -201,16 +238,17 @@ type UserStateAdapter interface {
}

type userStateAdapter struct {
sid StreamID
wc WindowEncoder
kc ElementEncoder
stateIDToCoder map[string]*coder.Coder
c *coder.Coder
sid StreamID
wc WindowEncoder
kc ElementEncoder
stateIDToCoder map[string]*coder.Coder
stateIDToCombineFn map[string]*graph.CombineFn
c *coder.Coder
}

// NewUserStateAdapter returns a user state adapter for the given StreamID and coder.
// It expects a W<V> or W<KV<K,V>> coder, because the protocol requires windowing information.
func NewUserStateAdapter(sid StreamID, c *coder.Coder, stateIDToCoder map[string]*coder.Coder) UserStateAdapter {
func NewUserStateAdapter(sid StreamID, c *coder.Coder, stateIDToCoder map[string]*coder.Coder, stateIDToCombineFn map[string]*graph.CombineFn) UserStateAdapter {
if !coder.IsW(c) {
panic(fmt.Sprintf("expected WV coder for user state %v: %v", sid, c))
}
Expand All @@ -220,7 +258,7 @@ func NewUserStateAdapter(sid StreamID, c *coder.Coder, stateIDToCoder map[string
if coder.IsKV(coder.SkipW(c)) {
kc = MakeElementEncoder(coder.SkipW(c).Components[0])
}
return &userStateAdapter{sid: sid, wc: wc, kc: kc, c: c, stateIDToCoder: stateIDToCoder}
return &userStateAdapter{sid: sid, wc: wc, kc: kc, c: c, stateIDToCoder: stateIDToCoder, stateIDToCombineFn: stateIDToCombineFn}
}

// NewStateProvider creates a stateProvider with the ability to talk to the state API.
Expand Down Expand Up @@ -249,6 +287,7 @@ func (s *userStateAdapter) NewStateProvider(ctx context.Context, reader StateRea
readersByKey: make(map[string]io.ReadCloser),
appendersByKey: make(map[string]io.Writer),
clearersByKey: make(map[string]io.Writer),
combineFnsByKey: s.stateIDToCombineFn,
codersByKey: s.stateIDToCoder,
}

Expand Down
4 changes: 3 additions & 1 deletion sdks/go/pkg/beam/core/runtime/exec/userstate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"reflect"
"testing"

"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/coder"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/runtime/coderx"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/state"
Expand Down Expand Up @@ -86,7 +87,8 @@ func buildStateProvider() stateProvider {
readersByKey: make(map[string]io.ReadCloser),
appendersByKey: make(map[string]io.Writer),
clearersByKey: make(map[string]io.Writer),
codersByKey: make(map[string]*coder.Coder), // Each test can specify coders as needed
combineFnsByKey: make(map[string]*graph.CombineFn), // Each test can specify combine fns as needed
codersByKey: make(map[string]*coder.Coder), // Each test can specify coders as needed
}
}

Expand Down
32 changes: 32 additions & 0 deletions sdks/go/pkg/beam/core/runtime/graphx/translate.go
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,38 @@ func (m *marshaller) addMultiEdge(edge NamedEdge) ([]string, error) {
Urn: URNBagUserState,
},
}
case state.TypeCombining:
cps := ps.(state.CombiningPipelineState).GetCombineFn()
f, err := graph.NewFn(cps)
if err != nil {
return handleErr(err)
}
cf, err := graph.AsCombineFn(f)
if err != nil {
return handleErr(err)
}
me := graph.MultiEdge{
Op: graph.Combine,
CombineFn: cf,
}
mustEncodeMultiEdge, err := mustEncodeMultiEdgeBase64(&me)
if err != nil {
return handleErr(err)
}
stateSpecs[ps.StateKey()] = &pipepb.StateSpec{
Spec: &pipepb.StateSpec_CombiningSpec{
CombiningSpec: &pipepb.CombiningStateSpec{
AccumulatorCoderId: coderID,
CombineFn: &pipepb.FunctionSpec{
Urn: "beam:combinefn:gosdk:v1",
Payload: []byte(mustEncodeMultiEdge),
},
},
},
Protocol: &pipepb.FunctionSpec{
Urn: URNBagUserState,
},
}
default:
return nil, errors.Errorf("State type %v not recognized for state %v", ps.StateKey(), ps)
}
Expand Down
Loading

0 comments on commit 2d9f807

Please sign in to comment.