Skip to content

Commit

Permalink
Go stateful DoFns user side changes (apache#22761)
Browse files Browse the repository at this point in the history
* Go stateful DoFns user side changes

* Fix static check violation

* Small cleanup

* Doc comments
  • Loading branch information
damccorm authored and MarcoRob committed Aug 26, 2022
1 parent d422d1c commit 1645a2d
Show file tree
Hide file tree
Showing 14 changed files with 715 additions and 179 deletions.
45 changes: 44 additions & 1 deletion sdks/go/pkg/beam/core/funcx/fn.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"reflect"

"github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/state"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
"github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors"
Expand Down Expand Up @@ -82,6 +83,8 @@ const (
FnBundleFinalization FnParamKind = 0x800
// FnWatermarkEstimator indicates a function input parameter that implements sdf.WatermarkEstimator
FnWatermarkEstimator FnParamKind = 0x1000
// FnState indicates a function input parameter that implements state.Provider
FnStateProvider FnParamKind = 0x2000
)

func (k FnParamKind) String() string {
Expand Down Expand Up @@ -112,6 +115,8 @@ func (k FnParamKind) String() string {
return "BundleFinalization"
case FnWatermarkEstimator:
return "WatermarkEstimator"
case FnStateProvider:
return "StateProvider"
default:
return fmt.Sprintf("%v", int(k))
}
Expand Down Expand Up @@ -289,6 +294,17 @@ func (u *Fn) BundleFinalization() (pos int, exists bool) {
return -1, false
}

// StateProvider returns (index, true) iff the function expects a
// parameter that implements state.Provider.
func (u *Fn) StateProvider() (pos int, exists bool) {
for i, p := range u.Param {
if p.Kind == FnStateProvider {
return i, true
}
}
return -1, false
}

// WatermarkEstimator returns (index, true) iff the function expects a
// parameter that implements sdf.WatermarkEstimator.
func (u *Fn) WatermarkEstimator() (pos int, exists bool) {
Expand Down Expand Up @@ -374,6 +390,8 @@ func New(fn reflectx.Func) (*Fn, error) {
kind = FnWindow
case t == typex.BundleFinalizationType:
kind = FnBundleFinalization
case t == state.ProviderType:
kind = FnStateProvider
case t == reflectx.Type:
kind = FnType
case t.Implements(reflect.TypeOf((*sdf.RTracker)(nil)).Elem()):
Expand Down Expand Up @@ -464,7 +482,7 @@ func SubReturns(list []ReturnParam, indices ...int) []ReturnParam {
}

// The order of present parameters and return values must be as follows:
// func(FnContext?, FnPane?, FnWindow?, FnEventTime?, FnWatermarkEstimator?, FnType?, FnBundleFinalization?, FnRTracker?, (FnValue, SideInput*)?, FnEmit*) (RetEventTime?, RetOutput?, RetError?)
// func(FnContext?, FnPane?, FnWindow?, FnEventTime?, FnWatermarkEstimator?, FnType?, FnBundleFinalization?, FnRTracker?, FnStateProvider?, (FnValue, SideInput*)?, FnEmit*) (RetEventTime?, RetOutput?, RetError?)
// where ? indicates 0 or 1, and * indicates any number.
// and a SideInput is one of FnValue or FnIter or FnReIter
// Note: Fns with inputs must have at least one FnValue as the main input.
Expand Down Expand Up @@ -496,6 +514,7 @@ var (
errReflectTypePrecedence = errors.New("may only have a single reflect.Type parameter and it must precede the main input parameter")
errRTrackerPrecedence = errors.New("may only have a single sdf.RTracker parameter and it must precede the main input parameter")
errBundleFinalizationPrecedence = errors.New("may only have a single BundleFinalization parameter and it must precede the main input parameter")
errStateProviderPrecedence = errors.New("may only have a single state.Provider parameter and it must precede the main input parameter")
errInputPrecedence = errors.New("inputs parameters must precede emit function parameters")
)

Expand All @@ -513,6 +532,7 @@ const (
psOutput
psRTracker
psBundleFinalization
psStateProvider
)

func nextParamState(cur paramState, transition FnParamKind) (paramState, error) {
Expand All @@ -535,6 +555,8 @@ func nextParamState(cur paramState, transition FnParamKind) (paramState, error)
return psBundleFinalization, nil
case FnRTracker:
return psRTracker, nil
case FnStateProvider:
return psStateProvider, nil
}
case psContext:
switch transition {
Expand All @@ -552,6 +574,8 @@ func nextParamState(cur paramState, transition FnParamKind) (paramState, error)
return psBundleFinalization, nil
case FnRTracker:
return psRTracker, nil
case FnStateProvider:
return psStateProvider, nil
}
case psPane:
switch transition {
Expand All @@ -567,6 +591,8 @@ func nextParamState(cur paramState, transition FnParamKind) (paramState, error)
return psBundleFinalization, nil
case FnRTracker:
return psRTracker, nil
case FnStateProvider:
return psStateProvider, nil
}
case psWindow:
switch transition {
Expand All @@ -580,6 +606,8 @@ func nextParamState(cur paramState, transition FnParamKind) (paramState, error)
return psBundleFinalization, nil
case FnRTracker:
return psRTracker, nil
case FnStateProvider:
return psStateProvider, nil
}
case psEventTime:
switch transition {
Expand All @@ -591,6 +619,8 @@ func nextParamState(cur paramState, transition FnParamKind) (paramState, error)
return psBundleFinalization, nil
case FnRTracker:
return psRTracker, nil
case FnStateProvider:
return psStateProvider, nil
}
case psWatermarkEstimator:
switch transition {
Expand All @@ -600,20 +630,31 @@ func nextParamState(cur paramState, transition FnParamKind) (paramState, error)
return psBundleFinalization, nil
case FnRTracker:
return psRTracker, nil
case FnStateProvider:
return psStateProvider, nil
}
case psType:
switch transition {
case FnBundleFinalization:
return psBundleFinalization, nil
case FnRTracker:
return psRTracker, nil
case FnStateProvider:
return psStateProvider, nil
}
case psBundleFinalization:
switch transition {
case FnRTracker:
return psRTracker, nil
case FnStateProvider:
return psStateProvider, nil
}
case psRTracker:
switch transition {
case FnStateProvider:
return psStateProvider, nil
}
case psStateProvider:
// Completely handled by the default clause
case psInput:
switch transition {
Expand Down Expand Up @@ -644,6 +685,8 @@ func nextParamState(cur paramState, transition FnParamKind) (paramState, error)
return -1, errBundleFinalizationPrecedence
case FnRTracker:
return -1, errRTrackerPrecedence
case FnStateProvider:
return -1, errStateProviderPrecedence
case FnIter, FnReIter, FnValue, FnMultiMap:
return psInput, nil
case FnEmit:
Expand Down
16 changes: 16 additions & 0 deletions sdks/go/pkg/beam/core/funcx/fn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (

"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/mtime"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/state"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
)
Expand Down Expand Up @@ -108,6 +109,11 @@ func TestNew(t *testing.T) {
Fn: func(typex.PaneInfo, typex.Window, typex.EventTime, sdf.WatermarkEstimator, reflect.Type, []byte) {},
Param: []FnParamKind{FnPane, FnWindow, FnEventTime, FnWatermarkEstimator, FnType, FnValue},
},
{
Name: "good10",
Fn: func(sdf.RTracker, state.Provider, []byte) {},
Param: []FnParamKind{FnRTracker, FnStateProvider, FnValue},
},
{
Name: "good-method",
Fn: foo{1}.Do,
Expand Down Expand Up @@ -211,6 +217,16 @@ func TestNew(t *testing.T) {
Fn: func(int, func(int), func() func(*int) bool) {},
Err: errInputPrecedence,
},
{
Name: "errInputPrecedence- StateProvider before RTracker",
Fn: func(state.Provider, sdf.RTracker, int) {},
Err: errRTrackerPrecedence,
},
{
Name: "errInputPrecedence- StateProvider after output",
Fn: func(int, state.Provider) {},
Err: errStateProviderPrecedence,
},
{
Name: "errInputPrecedence- input after output",
Fn: func(int, func(int), int) {},
Expand Down
17 changes: 9 additions & 8 deletions sdks/go/pkg/beam/core/graph/edge.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,15 @@ type MultiEdge struct {
parent *Scope

Op Opcode
DoFn *DoFn // ParDo
RestrictionCoder *coder.Coder // SplittableParDo
CombineFn *CombineFn // Combine
AccumCoder *coder.Coder // Combine
Value []byte // Impulse
External *ExternalTransform // Current External Transforms API
Payload *Payload // Legacy External Transforms API
WindowFn *window.Fn // WindowInto
DoFn *DoFn // ParDo
RestrictionCoder *coder.Coder // SplittableParDo
StateCoders map[string]*coder.Coder // Stateful ParDo
CombineFn *CombineFn // Combine
AccumCoder *coder.Coder // Combine
Value []byte // Impulse
External *ExternalTransform // Current External Transforms API
Payload *Payload // Legacy External Transforms API
WindowFn *window.Fn // WindowInto

Input []*Inbound
Output []*Outbound
Expand Down
75 changes: 74 additions & 1 deletion sdks/go/pkg/beam/core/graph/fn.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

"github.com/apache/beam/sdks/v2/go/pkg/beam/core/funcx"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/state"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
"github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors"
Expand Down Expand Up @@ -282,6 +283,27 @@ func (f *DoFn) IsSplittable() bool {
return ok
}

// PipelineState returns a list of PipelineState objects used to access/mutate global pipeline state (if any).
func (f *DoFn) PipelineState() []state.PipelineState {
var s []state.PipelineState
if f.Recv == nil {
return s
}

v := reflect.Indirect(reflect.ValueOf(f.Recv))

for i := 0; i < v.NumField(); i++ {
f := v.Field(i)
if f.CanInterface() {
if ps, ok := f.Interface().(state.PipelineState); ok {
s = append(s, ps)
}
}
}

return s
}

// SplittableDoFn represents a DoFn implementing SDF methods.
type SplittableDoFn DoFn

Expand Down Expand Up @@ -561,7 +583,14 @@ func AsDoFn(fn *Fn, numMainIn mainInputs) (*DoFn, error) {
}
}

return (*DoFn)(fn), nil
doFn := (*DoFn)(fn)

err = validateState(doFn, numMainIn)
if err != nil {
return nil, addContext(err, fn)
}

return doFn, nil
}

// validateMainInputs checks that a method has the given number of main inputs
Expand Down Expand Up @@ -1221,6 +1250,50 @@ func validateStatefulWatermarkSig(fn *Fn, numMainIn int) error {
return nil
}

func validateState(fn *DoFn, numIn mainInputs) error {
ps := fn.PipelineState()

if _, hasSp := fn.methods[processElementName].StateProvider(); hasSp {
if numIn == MainSingle {
err := errors.Errorf("ProcessElement uses a StateProvider, but is not keyed")
return errors.SetTopLevelMsgf(err, "ProcessElement uses a StateProvider, but is not keyed. "+
"All stateful DoFns must take a key/value pair as an input.")
}
if len(ps) == 0 {
err := errors.Errorf("ProcessElement uses a StateProvider, but noState structs are attached to the DoFn")
return errors.SetTopLevelMsgf(err, "ProcessElement uses a StateProvider, but no State structs are "+
"attached to the DoFn. Ensure that you are including the State structs you're using to read/write"+
"global state as public uppercase member variables.")
}
stateKeys := make(map[string]state.PipelineState)
for _, s := range ps {
k := s.StateKey()
if orig, ok := stateKeys[k]; ok {
err := errors.Errorf("Duplicate state key %v", k)
return errors.SetTopLevelMsgf(err, "Duplicate state key %v used by %v and %v. Ensure that state keys are"+
"unique per DoFn", k, orig, s)
} else {
stateKeys[k] = s
}
}

// TODO(#22736) - Remove this once state is fully supported
err := errors.Errorf("ProcessElement uses a StateProvider, but state is not supported in this release.")
return errors.SetTopLevelMsgf(err, "ProcessElement uses a StateProvider, but state is not supported in this release. "+
"Please try upgrading to a newer release if one exists or wait for state support to be released.")
} else {
if len(ps) > 0 {
err := errors.Errorf("ProcessElement doesn't use a StateProvider, but State structs are attached to "+
"the DoFn: %v", ps)
return errors.SetTopLevelMsgf(err, "ProcessElement doesn't use a StateProvider, but State structs are "+
"attached to the DoFn: %v\nEnsure that you are using the StateProvider to perform any reads or writes"+
"of pipeline state.", ps)
}
}

return nil
}

// CombineFn represents a CombineFn.
type CombineFn Fn

Expand Down
Loading

0 comments on commit 1645a2d

Please sign in to comment.