From a58f3d2c89fc88166d3f4af7c2d9b4c463d7718a Mon Sep 17 00:00:00 2001 From: Danny McCormick Date: Thu, 25 Aug 2022 12:46:21 -0400 Subject: [PATCH] Combining state integration test (#22846) * Add combining state support * WIP: combining state integration test * Check for errors * accum -> combine * Bad string replace * Missed error handling --- sdks/go/test/integration/integration.go | 3 + sdks/go/test/integration/primitives/state.go | 128 +++++++++++++++++- .../test/integration/primitives/state_test.go | 5 + 3 files changed, 135 insertions(+), 1 deletion(-) diff --git a/sdks/go/test/integration/integration.go b/sdks/go/test/integration/integration.go index e41ae96ee730e..0802ba12aeb7c 100644 --- a/sdks/go/test/integration/integration.go +++ b/sdks/go/test/integration/integration.go @@ -91,6 +91,7 @@ var directFilters = []string{ // The direct runner does not support user state. "TestValueState", "TestBagState", + "TestCombiningState", } var portableFilters = []string{ @@ -114,6 +115,7 @@ var portableFilters = []string{ // The portable runner does not support user state. "TestValueState", "TestBagState", + "TestCombiningState", } var flinkFilters = []string{ @@ -158,6 +160,7 @@ var samzaFilters = []string{ // The samza runner does not support user state. "TestValueState", "TestBagState", + "TestCombiningState", } var sparkFilters = []string{ diff --git a/sdks/go/test/integration/primitives/state.go b/sdks/go/test/integration/primitives/state.go index 314d6280fffb3..5ab01fb64f188 100644 --- a/sdks/go/test/integration/primitives/state.go +++ b/sdks/go/test/integration/primitives/state.go @@ -17,6 +17,7 @@ package primitives import ( "fmt" + "strconv" "strings" "github.com/apache/beam/sdks/v2/go/pkg/beam" @@ -28,7 +29,12 @@ import ( func init() { register.DoFn3x1[state.Provider, string, int, string](&valueStateFn{}) register.DoFn3x1[state.Provider, string, int, string](&bagStateFn{}) + register.DoFn3x1[state.Provider, string, int, string](&combiningStateFn{}) register.Emitter2[string, int]() + register.Combiner1[int](&combine1{}) + register.Combiner2[string, int](&combine2{}) + register.Combiner2[string, int](&combine3{}) + register.Combiner1[int](&combine4{}) } type valueStateFn struct { @@ -113,7 +119,7 @@ func (f *bagStateFn) ProcessElement(s state.Provider, w string, c int) string { return fmt.Sprintf("%s: %v, %s", w, sum, strings.Join(j, ",")) } -// ValueStateParDo tests a DoFn that uses value state. +// BagStateParDo tests a DoFn that uses value state. func BagStateParDo() *beam.Pipeline { p, s := beam.NewPipelineWithRoot() @@ -126,3 +132,123 @@ func BagStateParDo() *beam.Pipeline { return p } + +type combiningStateFn struct { + State0 state.Combining[int, int, int] + State1 state.Combining[int, int, int] + State2 state.Combining[string, string, int] + State3 state.Combining[string, string, int] + State4 state.Combining[int, int, int] +} + +type combine1 struct{} + +func (ac *combine1) MergeAccumulators(a, b int) int { + return a + b +} + +type combine2 struct{} + +func (ac *combine2) MergeAccumulators(a, b string) string { + ai, _ := strconv.Atoi(a) + bi, _ := strconv.Atoi(b) + return strconv.Itoa(ai + bi) +} + +func (ac *combine2) ExtractOutput(a string) int { + ai, _ := strconv.Atoi(a) + return ai +} + +type combine3 struct{} + +func (ac *combine3) CreateAccumulator() string { + return "0" +} + +func (ac *combine3) MergeAccumulators(a string, b string) string { + ai, _ := strconv.Atoi(a) + bi, _ := strconv.Atoi(b) + return strconv.Itoa(ai + bi) +} + +func (ac *combine3) ExtractOutput(a string) int { + ai, _ := strconv.Atoi(a) + return ai +} + +type combine4 struct{} + +func (ac *combine4) AddInput(a, b int) int { + return a + b +} + +func (ac *combine4) MergeAccumulators(a, b int) int { + return a + b +} + +func (f *combiningStateFn) ProcessElement(s state.Provider, w string, c int) string { + i, _, err := f.State0.Read(s) + if err != nil { + panic(err) + } + err = f.State0.Add(s, 1) + if err != nil { + panic(err) + } + i1, _, err := f.State1.Read(s) + if err != nil { + panic(err) + } + err = f.State1.Add(s, 1) + if err != nil { + panic(err) + } + i2, _, err := f.State2.Read(s) + if err != nil { + panic(err) + } + err = f.State2.Add(s, "1") + if err != nil { + panic(err) + } + i3, _, err := f.State3.Read(s) + if err != nil { + panic(err) + } + err = f.State3.Add(s, "1") + if err != nil { + panic(err) + } + i4, _, err := f.State4.Read(s) + if err != nil { + panic(err) + } + err = f.State4.Add(s, 1) + if err != nil { + panic(err) + } + return fmt.Sprintf("%s: %v %v %v %v %v", w, i, i1, i2, i3, i4) +} + +// CombiningStateParDo tests a DoFn that uses value state. +func CombiningStateParDo() *beam.Pipeline { + p, s := beam.NewPipelineWithRoot() + + in := beam.Create(s, "apple", "pear", "peach", "apple", "apple", "pear") + keyed := beam.ParDo(s, func(w string, emit func(string, int)) { + emit(w, 1) + }, in) + counts := beam.ParDo(s, &combiningStateFn{ + State0: state.MakeCombiningState[int, int, int]("key0", func(a, b int) int { + return a + b + }), + State1: state.Combining[int, int, int](state.MakeCombiningState[int, int, int]("key1", &combine1{})), + State2: state.Combining[string, string, int](state.MakeCombiningState[string, string, int]("key2", &combine2{})), + State3: state.Combining[string, string, int](state.MakeCombiningState[string, string, int]("key3", &combine3{})), + State4: state.Combining[int, int, int](state.MakeCombiningState[int, int, int]("key4", &combine4{}))}, + keyed) + passert.Equals(s, counts, "apple: 0 0 0 0 0", "pear: 0 0 0 0 0", "peach: 0 0 0 0 0", "apple: 1 1 1 1 1", "apple: 2 2 2 2 2", "pear: 1 1 1 1 1") + + return p +} diff --git a/sdks/go/test/integration/primitives/state_test.go b/sdks/go/test/integration/primitives/state_test.go index 6bf108b3df92e..562c04356250a 100644 --- a/sdks/go/test/integration/primitives/state_test.go +++ b/sdks/go/test/integration/primitives/state_test.go @@ -31,3 +31,8 @@ func TestBagState(t *testing.T) { integration.CheckFilters(t) ptest.RunAndValidate(t, BagStateParDo()) } + +func TestCombiningState(t *testing.T) { + integration.CheckFilters(t) + ptest.RunAndValidate(t, CombiningStateParDo()) +}