Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Go stateful DoFns user side changes #22761

Merged
merged 4 commits into from
Aug 18, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion sdks/go/pkg/beam/core/funcx/fn.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"reflect"

"github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/state"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
"github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors"
Expand Down Expand Up @@ -82,6 +83,8 @@ const (
FnBundleFinalization FnParamKind = 0x800
// FnWatermarkEstimator indicates a function input parameter that implements sdf.WatermarkEstimator
FnWatermarkEstimator FnParamKind = 0x1000
// FnStateProvider indicates a function input parameter that implements state.Provider
FnStateProvider FnParamKind = 0x2000
)

func (k FnParamKind) String() string {
Expand Down Expand Up @@ -112,6 +115,8 @@ func (k FnParamKind) String() string {
return "BundleFinalization"
case FnWatermarkEstimator:
return "WatermarkEstimator"
case FnStateProvider:
return "StateProvider"
default:
return fmt.Sprintf("%v", int(k))
}
Expand Down Expand Up @@ -289,6 +294,17 @@ func (u *Fn) BundleFinalization() (pos int, exists bool) {
return -1, false
}

// StateProvider reports whether the function takes a state.Provider
// parameter. It returns the position of the first such parameter and
// true when one is present, or (-1, false) otherwise.
func (u *Fn) StateProvider() (pos int, exists bool) {
	for idx := range u.Param {
		if u.Param[idx].Kind != FnStateProvider {
			continue
		}
		return idx, true
	}
	return -1, false
}

// WatermarkEstimator returns (index, true) iff the function expects a
// parameter that implements sdf.WatermarkEstimator.
func (u *Fn) WatermarkEstimator() (pos int, exists bool) {
Expand Down Expand Up @@ -374,6 +390,8 @@ func New(fn reflectx.Func) (*Fn, error) {
kind = FnWindow
case t == typex.BundleFinalizationType:
kind = FnBundleFinalization
case t == state.ProviderType:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we have the types for state and timers defined in the typex package? I was thinking the same for timers as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd probably vote we leave it as is unless there's a reason to switch - the typex package generally only exports types for structs it defines, and IMO it makes sense to keep it in the state package to stay consistent with that. I don't feel super strongly though (@lostluck might have opinions too).

kind = FnStateProvider
case t == reflectx.Type:
kind = FnType
case t.Implements(reflect.TypeOf((*sdf.RTracker)(nil)).Elem()):
Expand Down Expand Up @@ -464,7 +482,7 @@ func SubReturns(list []ReturnParam, indices ...int) []ReturnParam {
}

// The order of present parameters and return values must be as follows:
// func(FnContext?, FnPane?, FnWindow?, FnEventTime?, FnWatermarkEstimator?, FnType?, FnBundleFinalization?, FnRTracker?, (FnValue, SideInput*)?, FnEmit*) (RetEventTime?, RetOutput?, RetError?)
// func(FnContext?, FnPane?, FnWindow?, FnEventTime?, FnWatermarkEstimator?, FnType?, FnBundleFinalization?, FnRTracker?, FnStateProvider?, (FnValue, SideInput*)?, FnEmit*) (RetEventTime?, RetOutput?, RetError?)
// where ? indicates 0 or 1, and * indicates any number.
// and a SideInput is one of FnValue or FnIter or FnReIter
// Note: Fns with inputs must have at least one FnValue as the main input.
Expand Down Expand Up @@ -496,6 +514,7 @@ var (
errReflectTypePrecedence = errors.New("may only have a single reflect.Type parameter and it must precede the main input parameter")
errRTrackerPrecedence = errors.New("may only have a single sdf.RTracker parameter and it must precede the main input parameter")
errBundleFinalizationPrecedence = errors.New("may only have a single BundleFinalization parameter and it must precede the main input parameter")
errStateProviderPrecedence = errors.New("may only have a single state.Provider parameter and it must precede the main input parameter")
errInputPrecedence = errors.New("inputs parameters must precede emit function parameters")
)

Expand All @@ -513,6 +532,7 @@ const (
psOutput
psRTracker
psBundleFinalization
psStateProvider
)

func nextParamState(cur paramState, transition FnParamKind) (paramState, error) {
Expand All @@ -535,6 +555,8 @@ func nextParamState(cur paramState, transition FnParamKind) (paramState, error)
return psBundleFinalization, nil
case FnRTracker:
return psRTracker, nil
case FnStateProvider:
return psStateProvider, nil
}
case psContext:
switch transition {
Expand All @@ -552,6 +574,8 @@ func nextParamState(cur paramState, transition FnParamKind) (paramState, error)
return psBundleFinalization, nil
case FnRTracker:
return psRTracker, nil
case FnStateProvider:
return psStateProvider, nil
}
case psPane:
switch transition {
Expand All @@ -567,6 +591,8 @@ func nextParamState(cur paramState, transition FnParamKind) (paramState, error)
return psBundleFinalization, nil
case FnRTracker:
return psRTracker, nil
case FnStateProvider:
return psStateProvider, nil
}
case psWindow:
switch transition {
Expand All @@ -580,6 +606,8 @@ func nextParamState(cur paramState, transition FnParamKind) (paramState, error)
return psBundleFinalization, nil
case FnRTracker:
return psRTracker, nil
case FnStateProvider:
return psStateProvider, nil
}
case psEventTime:
switch transition {
Expand All @@ -591,6 +619,8 @@ func nextParamState(cur paramState, transition FnParamKind) (paramState, error)
return psBundleFinalization, nil
case FnRTracker:
return psRTracker, nil
case FnStateProvider:
return psStateProvider, nil
}
case psWatermarkEstimator:
switch transition {
Expand All @@ -600,20 +630,31 @@ func nextParamState(cur paramState, transition FnParamKind) (paramState, error)
return psBundleFinalization, nil
case FnRTracker:
return psRTracker, nil
case FnStateProvider:
return psStateProvider, nil
}
case psType:
switch transition {
case FnBundleFinalization:
return psBundleFinalization, nil
case FnRTracker:
return psRTracker, nil
case FnStateProvider:
return psStateProvider, nil
}
case psBundleFinalization:
switch transition {
case FnRTracker:
return psRTracker, nil
case FnStateProvider:
return psStateProvider, nil
}
case psRTracker:
switch transition {
case FnStateProvider:
return psStateProvider, nil
}
case psStateProvider:
// Completely handled by the default clause
case psInput:
switch transition {
Expand Down Expand Up @@ -644,6 +685,8 @@ func nextParamState(cur paramState, transition FnParamKind) (paramState, error)
return -1, errBundleFinalizationPrecedence
case FnRTracker:
return -1, errRTrackerPrecedence
case FnStateProvider:
return -1, errStateProviderPrecedence
case FnIter, FnReIter, FnValue, FnMultiMap:
return psInput, nil
case FnEmit:
Expand Down
16 changes: 16 additions & 0 deletions sdks/go/pkg/beam/core/funcx/fn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (

"github.com/apache/beam/sdks/v2/go/pkg/beam/core/graph/mtime"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/state"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
)
Expand Down Expand Up @@ -108,6 +109,11 @@ func TestNew(t *testing.T) {
Fn: func(typex.PaneInfo, typex.Window, typex.EventTime, sdf.WatermarkEstimator, reflect.Type, []byte) {},
Param: []FnParamKind{FnPane, FnWindow, FnEventTime, FnWatermarkEstimator, FnType, FnValue},
},
{
Name: "good10",
Fn: func(sdf.RTracker, state.Provider, []byte) {},
Param: []FnParamKind{FnRTracker, FnStateProvider, FnValue},
},
{
Name: "good-method",
Fn: foo{1}.Do,
Expand Down Expand Up @@ -211,6 +217,16 @@ func TestNew(t *testing.T) {
Fn: func(int, func(int), func() func(*int) bool) {},
Err: errInputPrecedence,
},
{
Name: "errInputPrecedence- StateProvider before RTracker",
Fn: func(state.Provider, sdf.RTracker, int) {},
Err: errRTrackerPrecedence,
},
{
Name: "errInputPrecedence- StateProvider after output",
Fn: func(int, state.Provider) {},
Err: errStateProviderPrecedence,
},
{
Name: "errInputPrecedence- input after output",
Fn: func(int, func(int), int) {},
Expand Down
17 changes: 9 additions & 8 deletions sdks/go/pkg/beam/core/graph/edge.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,15 @@ type MultiEdge struct {
parent *Scope

Op Opcode
DoFn *DoFn // ParDo
RestrictionCoder *coder.Coder // SplittableParDo
CombineFn *CombineFn // Combine
AccumCoder *coder.Coder // Combine
Value []byte // Impulse
External *ExternalTransform // Current External Transforms API
Payload *Payload // Legacy External Transforms API
WindowFn *window.Fn // WindowInto
DoFn *DoFn // ParDo
RestrictionCoder *coder.Coder // SplittableParDo
StateCoders map[string]*coder.Coder // Stateful ParDo
CombineFn *CombineFn // Combine
AccumCoder *coder.Coder // Combine
Value []byte // Impulse
External *ExternalTransform // Current External Transforms API
Payload *Payload // Legacy External Transforms API
WindowFn *window.Fn // WindowInto

Input []*Inbound
Output []*Outbound
Expand Down
75 changes: 74 additions & 1 deletion sdks/go/pkg/beam/core/graph/fn.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (

"github.com/apache/beam/sdks/v2/go/pkg/beam/core/funcx"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/sdf"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/state"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/typex"
"github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx"
"github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors"
Expand Down Expand Up @@ -282,6 +283,27 @@ func (f *DoFn) IsSplittable() bool {
return ok
}

// PipelineState returns the state.PipelineState objects attached to the DoFn's
// receiver as struct fields (if any). These are the objects used to
// access/mutate global pipeline state.
//
// Only fields whose values can be obtained via reflection (i.e. exported
// fields) are considered. It returns nil when the DoFn has no receiver, or
// when the receiver is not a struct or a non-nil pointer to a struct.
func (f *DoFn) PipelineState() []state.PipelineState {
	if f.Recv == nil {
		return nil
	}

	// Dereference a pointer receiver. A typed nil pointer yields an invalid
	// Value and a non-struct receiver has no fields; NumField would panic on
	// either, so treat both as "no state attached".
	v := reflect.Indirect(reflect.ValueOf(f.Recv))
	if v.Kind() != reflect.Struct {
		return nil
	}

	var s []state.PipelineState
	for i := 0; i < v.NumField(); i++ {
		field := v.Field(i)
		// Interface() panics on values reached through unexported fields.
		if !field.CanInterface() {
			continue
		}
		if ps, ok := field.Interface().(state.PipelineState); ok {
			s = append(s, ps)
		}
	}

	return s
}

// SplittableDoFn represents a DoFn implementing SDF methods.
type SplittableDoFn DoFn

Expand Down Expand Up @@ -561,7 +583,14 @@ func AsDoFn(fn *Fn, numMainIn mainInputs) (*DoFn, error) {
}
}

return (*DoFn)(fn), nil
doFn := (*DoFn)(fn)

err = validateState(doFn, numMainIn)
if err != nil {
return nil, addContext(err, fn)
}

return doFn, nil
}

// validateMainInputs checks that a method has the given number of main inputs
Expand Down Expand Up @@ -1221,6 +1250,50 @@ func validateStatefulWatermarkSig(fn *Fn, numMainIn int) error {
return nil
}

// validateState checks that a DoFn's use of pipeline state is consistent.
//
// If ProcessElement does not take a StateProvider, the DoFn must not have any
// state.PipelineState fields attached. If it does take one, then:
//   - numIn must not be MainSingle (reported to the user as "not keyed"),
//   - at least one state.PipelineState field must be attached to the DoFn, and
//   - every attached state's StateKey() must be unique within the DoFn.
//
// Even when all of those checks pass, a DoFn using a StateProvider is
// currently rejected because state is not yet supported (see the TODO below).
// It returns nil only for DoFns that do not use state at all.
func validateState(fn *DoFn, numIn mainInputs) error {
	ps := fn.PipelineState()

	if _, hasSp := fn.methods[processElementName].StateProvider(); !hasSp {
		if len(ps) > 0 {
			err := errors.Errorf("ProcessElement doesn't use a StateProvider, but State structs are attached to "+
				"the DoFn: %v", ps)
			return errors.SetTopLevelMsgf(err, "ProcessElement doesn't use a StateProvider, but State structs are "+
				"attached to the DoFn: %v\nEnsure that you are using the StateProvider to perform any reads or writes "+
				"of pipeline state.", ps)
		}
		return nil
	}

	if numIn == MainSingle {
		err := errors.Errorf("ProcessElement uses a StateProvider, but is not keyed")
		return errors.SetTopLevelMsgf(err, "ProcessElement uses a StateProvider, but is not keyed. "+
			"All stateful DoFns must take a key/value pair as an input.")
	}
	if len(ps) == 0 {
		err := errors.Errorf("ProcessElement uses a StateProvider, but no State structs are attached to the DoFn")
		return errors.SetTopLevelMsgf(err, "ProcessElement uses a StateProvider, but no State structs are "+
			"attached to the DoFn. Ensure that you are including the State structs you're using to read/write "+
			"global state as public uppercase member variables.")
	}

	// Track the first state seen for each key so a collision can name both offenders.
	stateKeys := make(map[string]state.PipelineState, len(ps))
	for _, s := range ps {
		k := s.StateKey()
		if orig, ok := stateKeys[k]; ok {
			err := errors.Errorf("Duplicate state key %v", k)
			return errors.SetTopLevelMsgf(err, "Duplicate state key %v used by %v and %v. Ensure that state keys are "+
				"unique per DoFn", k, orig, s)
		}
		stateKeys[k] = s
	}

	// TODO(#22736) - Remove this once state is fully supported
	err := errors.Errorf("ProcessElement uses a StateProvider, but state is not supported in this release.")
	return errors.SetTopLevelMsgf(err, "ProcessElement uses a StateProvider, but state is not supported in this release. "+
		"Please try upgrading to a newer release if one exists or wait for state support to be released.")
}

// CombineFn represents a CombineFn.
type CombineFn Fn

Expand Down
Loading