Skip to content

Commit

Permalink
Combining state integration test (apache#22846)
Browse files Browse the repository at this point in the history
* Add combining state support

* WIP: combining state integration test

* Check for errors

* accum -> combine

* Bad string replace

* Missed error handling
  • Loading branch information
damccorm authored Aug 25, 2022
1 parent f973acc commit a58f3d2
Show file tree
Hide file tree
Showing 3 changed files with 135 additions and 1 deletion.
3 changes: 3 additions & 0 deletions sdks/go/test/integration/integration.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ var directFilters = []string{
// The direct runner does not support user state.
"TestValueState",
"TestBagState",
"TestCombiningState",
}

var portableFilters = []string{
Expand All @@ -114,6 +115,7 @@ var portableFilters = []string{
// The portable runner does not support user state.
"TestValueState",
"TestBagState",
"TestCombiningState",
}

var flinkFilters = []string{
Expand Down Expand Up @@ -158,6 +160,7 @@ var samzaFilters = []string{
// The samza runner does not support user state.
"TestValueState",
"TestBagState",
"TestCombiningState",
}

var sparkFilters = []string{
Expand Down
128 changes: 127 additions & 1 deletion sdks/go/test/integration/primitives/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package primitives

import (
"fmt"
"strconv"
"strings"

"github.com/apache/beam/sdks/v2/go/pkg/beam"
Expand All @@ -28,7 +29,12 @@ import (
func init() {
register.DoFn3x1[state.Provider, string, int, string](&valueStateFn{})
register.DoFn3x1[state.Provider, string, int, string](&bagStateFn{})
register.DoFn3x1[state.Provider, string, int, string](&combiningStateFn{})
register.Emitter2[string, int]()
register.Combiner1[int](&combine1{})
register.Combiner2[string, int](&combine2{})
register.Combiner2[string, int](&combine3{})
register.Combiner1[int](&combine4{})
}

type valueStateFn struct {
Expand Down Expand Up @@ -113,7 +119,7 @@ func (f *bagStateFn) ProcessElement(s state.Provider, w string, c int) string {
return fmt.Sprintf("%s: %v, %s", w, sum, strings.Join(j, ","))
}

// ValueStateParDo tests a DoFn that uses value state.
// BagStateParDo tests a DoFn that uses value state.
func BagStateParDo() *beam.Pipeline {
p, s := beam.NewPipelineWithRoot()

Expand All @@ -126,3 +132,123 @@ func BagStateParDo() *beam.Pipeline {

return p
}

type combiningStateFn struct {
State0 state.Combining[int, int, int]
State1 state.Combining[int, int, int]
State2 state.Combining[string, string, int]
State3 state.Combining[string, string, int]
State4 state.Combining[int, int, int]
}

type combine1 struct{}

func (ac *combine1) MergeAccumulators(a, b int) int {
return a + b
}

type combine2 struct{}

func (ac *combine2) MergeAccumulators(a, b string) string {
ai, _ := strconv.Atoi(a)
bi, _ := strconv.Atoi(b)
return strconv.Itoa(ai + bi)
}

func (ac *combine2) ExtractOutput(a string) int {
ai, _ := strconv.Atoi(a)
return ai
}

type combine3 struct{}

func (ac *combine3) CreateAccumulator() string {
return "0"
}

func (ac *combine3) MergeAccumulators(a string, b string) string {
ai, _ := strconv.Atoi(a)
bi, _ := strconv.Atoi(b)
return strconv.Itoa(ai + bi)
}

func (ac *combine3) ExtractOutput(a string) int {
ai, _ := strconv.Atoi(a)
return ai
}

type combine4 struct{}

func (ac *combine4) AddInput(a, b int) int {
return a + b
}

func (ac *combine4) MergeAccumulators(a, b int) int {
return a + b
}

func (f *combiningStateFn) ProcessElement(s state.Provider, w string, c int) string {
i, _, err := f.State0.Read(s)
if err != nil {
panic(err)
}
err = f.State0.Add(s, 1)
if err != nil {
panic(err)
}
i1, _, err := f.State1.Read(s)
if err != nil {
panic(err)
}
err = f.State1.Add(s, 1)
if err != nil {
panic(err)
}
i2, _, err := f.State2.Read(s)
if err != nil {
panic(err)
}
err = f.State2.Add(s, "1")
if err != nil {
panic(err)
}
i3, _, err := f.State3.Read(s)
if err != nil {
panic(err)
}
err = f.State3.Add(s, "1")
if err != nil {
panic(err)
}
i4, _, err := f.State4.Read(s)
if err != nil {
panic(err)
}
err = f.State4.Add(s, 1)
if err != nil {
panic(err)
}
return fmt.Sprintf("%s: %v %v %v %v %v", w, i, i1, i2, i3, i4)
}

// CombiningStateParDo tests a DoFn that uses value state.
func CombiningStateParDo() *beam.Pipeline {
p, s := beam.NewPipelineWithRoot()

in := beam.Create(s, "apple", "pear", "peach", "apple", "apple", "pear")
keyed := beam.ParDo(s, func(w string, emit func(string, int)) {
emit(w, 1)
}, in)
counts := beam.ParDo(s, &combiningStateFn{
State0: state.MakeCombiningState[int, int, int]("key0", func(a, b int) int {
return a + b
}),
State1: state.Combining[int, int, int](state.MakeCombiningState[int, int, int]("key1", &combine1{})),
State2: state.Combining[string, string, int](state.MakeCombiningState[string, string, int]("key2", &combine2{})),
State3: state.Combining[string, string, int](state.MakeCombiningState[string, string, int]("key3", &combine3{})),
State4: state.Combining[int, int, int](state.MakeCombiningState[int, int, int]("key4", &combine4{}))},
keyed)
passert.Equals(s, counts, "apple: 0 0 0 0 0", "pear: 0 0 0 0 0", "peach: 0 0 0 0 0", "apple: 1 1 1 1 1", "apple: 2 2 2 2 2", "pear: 1 1 1 1 1")

return p
}
5 changes: 5 additions & 0 deletions sdks/go/test/integration/primitives/state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,8 @@ func TestBagState(t *testing.T) {
integration.CheckFilters(t)
ptest.RunAndValidate(t, BagStateParDo())
}

func TestCombiningState(t *testing.T) {
integration.CheckFilters(t)
ptest.RunAndValidate(t, CombiningStateParDo())
}

0 comments on commit a58f3d2

Please sign in to comment.