diff --git a/internal/internal_event_handlers.go b/internal/internal_event_handlers.go index d201601f3..57447d766 100644 --- a/internal/internal_event_handlers.go +++ b/internal/internal_event_handlers.go @@ -894,19 +894,13 @@ func (wc *workflowEnvironmentImpl) QueueUpdate(name string, f func()) { wc.bufferedUpdateRequests[name] = append(wc.bufferedUpdateRequests[name], f) } -func (wc *workflowEnvironmentImpl) HandleUpdates(name string) bool { - if !wc.sdkFlags.tryUse(SDKPriorityUpdateHandling, !wc.isReplay) { - return false - } - updatesHandled := false +func (wc *workflowEnvironmentImpl) HandleQueuedUpdates(name string) { if bufferedUpdateRequests, ok := wc.bufferedUpdateRequests[name]; ok { for _, request := range bufferedUpdateRequests { request() - updatesHandled = true } delete(wc.bufferedUpdateRequests, name) } - return updatesHandled } func (wc *workflowEnvironmentImpl) DrainUnhandledUpdates() bool { diff --git a/internal/internal_update_test.go b/internal/internal_update_test.go index 91f1607f6..e19c5478f 100644 --- a/internal/internal_update_test.go +++ b/internal/internal_update_test.go @@ -89,8 +89,6 @@ var testSDKFlags = newSDKFlags( ) func TestUpdateHandlerPanicHandling(t *testing.T) { - t.Parallel() - env := &workflowEnvironmentImpl{ sdkFlags: testSDKFlags, commandsHelper: newCommandsHelper(), @@ -100,29 +98,45 @@ func TestUpdateHandlerPanicHandling(t *testing.T) { TaskQueueName: "taskqueue:" + t.Name(), }, } - interceptor, ctx, err := newWorkflowContext(env, nil) - require.NoError(t, err) - dispatcher, ctx := newDispatcher( - ctx, - interceptor, - func(ctx Context) {}, - func() bool { return false }) - dispatcher.executing = true - - panicFunc := func() error { panic("intentional") } - mustSetUpdateHandler(t, ctx, t.Name(), panicFunc, UpdateHandlerOptions{Validator: panicFunc}) - in := UpdateInput{Name: t.Name(), Args: []interface{}{}} t.Run("ValidateUpdate", func(t *testing.T) { - err = interceptor.inboundInterceptor.ValidateUpdate(ctx, &in) + interceptor, ctx, err := newWorkflowContext(env, nil) + require.NoError(t, err) + + panicFunc := func() error { panic("intentional") } + dispatcher, _ := newDispatcher( + ctx, + interceptor, + func(ctx Context) { + mustSetUpdateHandler(t, ctx, t.Name(), panicFunc, UpdateHandlerOptions{Validator: panicFunc}) + in := UpdateInput{Name: t.Name(), Args: []interface{}{}} + err = interceptor.inboundInterceptor.ValidateUpdate(ctx, &in) + }, + func() bool { return false }) + + require.NoError(t, dispatcher.ExecuteUntilAllBlocked(10*time.Second)) var panicerr *PanicError require.ErrorAs(t, err, &panicerr, "panic during validate should be converted to an error to fail the update") }) t.Run("ExecuteUpdate", func(t *testing.T) { - require.Panics(t, func() { - _, _ = interceptor.inboundInterceptor.ExecuteUpdate(ctx, &in) - }, "panic during execution should be propagated to reach the WorkflowPanicPolicy") + interceptor, ctx, err := newWorkflowContext(env, nil) + require.NoError(t, err) + + panicFunc := func() error { panic("intentional") } + dispatcher, _ := newDispatcher( + ctx, + interceptor, + func(ctx Context) { + mustSetUpdateHandler(t, ctx, t.Name(), panicFunc, UpdateHandlerOptions{}) + in := UpdateInput{Name: t.Name(), Args: []interface{}{}} + err = interceptor.inboundInterceptor.ValidateUpdate(ctx, &in) + require.Panics(t, func() { + _, _ = interceptor.inboundInterceptor.ExecuteUpdate(ctx, &in) + }, "panic during execution should be propagated to reach the WorkflowPanicPolicy") + }, + func() bool { return false }) + require.NoError(t, dispatcher.ExecuteUntilAllBlocked(10*time.Second)) }) } @@ -171,25 +185,21 @@ func TestUpdateValidatorFnValidation(t *testing.T) { } func TestDefaultUpdateHandler(t *testing.T) { + t.Parallel() + dc := converter.GetDefaultDataConverter() - env := &workflowEnvironmentImpl{ - sdkFlags: testSDKFlags, - commandsHelper: newCommandsHelper(), - dataConverter: dc, - workflowInfo: &WorkflowInfo{ - Namespace: "namespace:" + t.Name(), - TaskQueueName: "taskqueue:" + t.Name(), - }, - bufferedUpdateRequests: make(map[string][]func()), + createTestWfEnv := func() *workflowEnvironmentImpl { + return &workflowEnvironmentImpl{ + sdkFlags: testSDKFlags, + commandsHelper: newCommandsHelper(), + dataConverter: dc, + workflowInfo: &WorkflowInfo{ + Namespace: "namespace:" + t.Name(), + TaskQueueName: "taskqueue:" + t.Name(), + }, + bufferedUpdateRequests: make(map[string][]func()), + } } - interceptor, ctx, err := newWorkflowContext(env, nil) - require.NoError(t, err) - dispatcher, ctx := newDispatcher( - ctx, - interceptor, - func(ctx Context) {}, - env.DrainUnhandledUpdates) - dispatcher.executing = true hdr := &commonpb.Header{Fields: map[string]*commonpb.Payload{}} argStr := t.Name() @@ -197,74 +207,126 @@ func TestDefaultUpdateHandler(t *testing.T) { require.NoError(t, err) t.Run("no handler registered", func(t *testing.T) { - mustSetUpdateHandler( - t, + env := createTestWfEnv() + interceptor, ctx, err := newWorkflowContext(env, nil) + require.NoError(t, err) + + dispatcher, ctx := newDispatcher( ctx, - "unused_handler", - func() error { panic("should not be called") }, - UpdateHandlerOptions{}, - ) + interceptor, + func(ctx Context) { + mustSetUpdateHandler( + t, + ctx, + "unused_handler", + func() error { panic("should not be called") }, + UpdateHandlerOptions{}, + ) + }, + env.DrainUnhandledUpdates) var rejectErr error defaultUpdateHandler(ctx, "will_not_be_found", "testID", args, hdr, &testUpdateCallbacks{ RejectImpl: func(err error) { rejectErr = err }, }, runOnCallingThread) + require.NoError(t, dispatcher.ExecuteUntilAllBlocked(10*time.Second)) require.ErrorContains(t, rejectErr, "unknown update") require.ErrorContains(t, rejectErr, "unused_handler", "handler not found error should include a list of the registered handlers") }) t.Run("malformed serialized input", func(t *testing.T) { - mustSetUpdateHandler( - t, + env := createTestWfEnv() + interceptor, ctx, err := newWorkflowContext(env, nil) + require.NoError(t, err) + + dispatcher, ctx := newDispatcher( ctx, - t.Name(), - func(Context, int) error { return nil }, - UpdateHandlerOptions{}, - ) + interceptor, + func(ctx Context) { + mustSetUpdateHandler( + t, + ctx, + t.Name(), + func(Context, int) error { return nil }, + UpdateHandlerOptions{}, + ) + }, + env.DrainUnhandledUpdates) + junkArgs := &commonpb.Payloads{Payloads: []*commonpb.Payload{&commonpb.Payload{}}} var rejectErr error defaultUpdateHandler(ctx, t.Name(), "testID", junkArgs, hdr, &testUpdateCallbacks{ RejectImpl: func(err error) { rejectErr = err }, }, runOnCallingThread) + require.NoError(t, dispatcher.ExecuteUntilAllBlocked(10*time.Second)) require.ErrorContains(t, rejectErr, "unable to decode") }) t.Run("reject from validator", func(t *testing.T) { + env := createTestWfEnv() + interceptor, ctx, err := newWorkflowContext(env, nil) + require.NoError(t, err) + updateFunc := func(Context, string) error { panic("should not get called") } validatorFunc := func(Context, string) error { return errors.New("expected") } - mustSetUpdateHandler( - t, + dispatcher, ctx := newDispatcher( ctx, - t.Name(), - updateFunc, - UpdateHandlerOptions{Validator: validatorFunc}, - ) + interceptor, + func(ctx Context) { + mustSetUpdateHandler( + t, + ctx, + t.Name(), + updateFunc, + UpdateHandlerOptions{Validator: validatorFunc}, + ) + }, + env.DrainUnhandledUpdates) var rejectErr error defaultUpdateHandler(ctx, t.Name(), "testID", args, hdr, &testUpdateCallbacks{ RejectImpl: func(err error) { rejectErr = err }, }, runOnCallingThread) + require.NoError(t, dispatcher.ExecuteUntilAllBlocked(10*time.Second)) require.Equal(t, validatorFunc(ctx, argStr), rejectErr) }) t.Run("illegal state panic from validator", func(t *testing.T) { + env := createTestWfEnv() + interceptor, ctx, err := newWorkflowContext(env, nil) + require.NoError(t, err) + updateFunc := func(Context, string) error { panic("should not get called") } validatorFunc := func(Context, string) error { panic(panicIllegalAccessCoroutineState) } - mustSetUpdateHandler( - t, + dispatcher, ctx := newDispatcher( ctx, - t.Name(), - updateFunc, - UpdateHandlerOptions{Validator: validatorFunc}, - ) - - require.Panics(t, func() { - defaultUpdateHandler(ctx, t.Name(), "testID", args, hdr, &testUpdateCallbacks{}, runOnCallingThread) - }) + interceptor, + func(ctx Context) { + mustSetUpdateHandler( + t, + ctx, + t.Name(), + updateFunc, + UpdateHandlerOptions{Validator: validatorFunc}, + ) + }, + env.DrainUnhandledUpdates) + defaultUpdateHandler(ctx, t.Name(), "testID", args, hdr, &testUpdateCallbacks{}, runOnCallingThread) + require.Error(t, dispatcher.ExecuteUntilAllBlocked(10*time.Second)) }) t.Run("error from update func", func(t *testing.T) { + env := createTestWfEnv() + interceptor, ctx, err := newWorkflowContext(env, nil) + require.NoError(t, err) + updateFunc := func(Context, string) error { return errors.New("expected") } - mustSetUpdateHandler(t, ctx, t.Name(), updateFunc, UpdateHandlerOptions{}) + dispatcher, ctx := newDispatcher( + ctx, + interceptor, + func(ctx Context) { + mustSetUpdateHandler(t, ctx, t.Name(), updateFunc, UpdateHandlerOptions{}) + }, + env.DrainUnhandledUpdates) var ( resultErr error accepted bool @@ -277,14 +339,26 @@ func TestDefaultUpdateHandler(t *testing.T) { result = success }, }, runOnCallingThread) + + require.NoError(t, dispatcher.ExecuteUntilAllBlocked(10*time.Second)) require.True(t, accepted) require.Equal(t, updateFunc(ctx, argStr), resultErr) require.Nil(t, result) }) t.Run("update success", func(t *testing.T) { + env := createTestWfEnv() + interceptor, ctx, err := newWorkflowContext(env, nil) + require.NoError(t, err) + updateFunc := func(ctx Context, s string) (string, error) { return s + " success!", nil } - mustSetUpdateHandler(t, ctx, t.Name(), updateFunc, UpdateHandlerOptions{}) + dispatcher, ctx := newDispatcher( + ctx, + interceptor, + func(ctx Context) { + mustSetUpdateHandler(t, ctx, t.Name(), updateFunc, UpdateHandlerOptions{}) + }, + env.DrainUnhandledUpdates) var ( resultErr error accepted bool @@ -297,6 +371,7 @@ func TestDefaultUpdateHandler(t *testing.T) { result = success }, }, runOnCallingThread) + require.NoError(t, dispatcher.ExecuteUntilAllBlocked(10*time.Second)) require.True(t, accepted) require.Nil(t, resultErr) @@ -305,6 +380,7 @@ func TestDefaultUpdateHandler(t *testing.T) { }) t.Run("update before handlers registered", func(t *testing.T) { + env := createTestWfEnv() // same test as above except that we don't set the update handler for // t.Name() until the first Yield. This emulates the situation where // there is an update in the first WFT of a workflow so the SDK needs to diff --git a/internal/internal_worker_base.go b/internal/internal_worker_base.go index f102ce94f..6750638de 100644 --- a/internal/internal_worker_base.go +++ b/internal/internal_worker_base.go @@ -126,9 +126,8 @@ type ( GetRegistry() *registry // QueueUpdate request of type name QueueUpdate(name string, f func()) - // HandleUpdates unblock all updates of type name - // returns true if any update was unblocked - HandleUpdates(name string) bool + // HandleQueuedUpdates unblocks all queued updates of type name + HandleQueuedUpdates(name string) // DrainUnhandledUpdates unblocks all updates, meant to be used to drain // all unhandled updates at the end of a workflow task // returns true if any update was unblocked diff --git a/internal/internal_workflow.go b/internal/internal_workflow.go index c1253c6a8..3344ef469 100644 --- a/internal/internal_workflow.go +++ b/internal/internal_workflow.go @@ -513,6 +513,7 @@ func (d *syncWorkflowDefinition) Execute(env WorkflowEnvironment, header *common // we are yielding. state := getState(d.rootCtx) state.yield("yield before executing to setup state") + state.unblocked() // TODO: @shreyassrivatsan - add workflow trace span here r.workflowResult, r.error = d.workflow.Execute(d.rootCtx, input) @@ -1516,8 +1517,11 @@ func setUpdateHandler(ctx Context, updateName string, handler interface{}, opts return err } getWorkflowEnvOptions(ctx).updateHandlers[updateName] = uh - if getWorkflowEnvironment(ctx).HandleUpdates(updateName) { - getState(ctx).yield("letting any updates waiting on a handler run") + if getWorkflowEnvironment(ctx).TryUse(SDKPriorityUpdateHandling) { + getWorkflowEnvironment(ctx).HandleQueuedUpdates(updateName) + state := getState(ctx) + defer state.unblocked() + state.yield("letting any updates waiting on a handler run") } return nil } diff --git a/internal/internal_workflow_testsuite.go b/internal/internal_workflow_testsuite.go index 3aa9d1c7e..17e8720d9 100644 --- a/internal/internal_workflow_testsuite.go +++ b/internal/internal_workflow_testsuite.go @@ -554,16 +554,13 @@ func (env *testWorkflowEnvironmentImpl) QueueUpdate(name string, f func()) { env.bufferedUpdateRequests[name] = append(env.bufferedUpdateRequests[name], f) } -func (env *testWorkflowEnvironmentImpl) HandleUpdates(name string) bool { - updatesHandled := false +func (env *testWorkflowEnvironmentImpl) HandleQueuedUpdates(name string) { if bufferedUpdateRequests, ok := env.bufferedUpdateRequests[name]; ok { for _, requests := range bufferedUpdateRequests { requests() - updatesHandled = true } delete(env.bufferedUpdateRequests, name) } - return updatesHandled } func (env *testWorkflowEnvironmentImpl) DrainUnhandledUpdates() bool { diff --git a/test/integration_test.go b/test/integration_test.go index ab48735cb..863ef97e7 100644 --- a/test/integration_test.go +++ b/test/integration_test.go @@ -1334,6 +1334,42 @@ func (ts *IntegrationTestSuite) TestUpdateInfo() { ts.NoError(run.Get(ctx, nil)) } +func (ts *IntegrationTestSuite) TestUpdateValidatorRejectedFirstWFT() { + ctx := context.Background() + wfOptions := ts.startWorkflowOptions("test-update-validator-rejected-first-wft") + // Add start delay to make sure the update is in the first WFT + wfOptions.StartDelay = time.Hour + run, err := ts.client.ExecuteWorkflow(ctx, + wfOptions, ts.workflows.UpdateWithValidatorWorkflow) + ts.Nil(err) + // Send a bad update request that will get rejected + handler, err := ts.client.UpdateWorkflow(ctx, run.GetID(), run.GetRunID(), "update", "") + ts.NoError(err) + err = handler.Get(ctx, nil) + ts.Error(err) + // complete workflow + ts.NoError(ts.client.SignalWorkflow(ctx, run.GetID(), run.GetRunID(), "finish", "finished")) + ts.NoError(run.Get(ctx, nil)) +} + +func (ts *IntegrationTestSuite) TestUpdateValidatorRejected() { + ctx := context.Background() + wfOptions := ts.startWorkflowOptions("test-update-validator-rejected") + run, err := ts.client.ExecuteWorkflow(ctx, + wfOptions, ts.workflows.UpdateWithValidatorWorkflow) + ts.Nil(err) + _, err = ts.client.QueryWorkflow(ctx, run.GetID(), run.GetRunID(), "__stack_trace") + ts.NoError(err) + // Send a bad update request that will get rejected + handler, err := ts.client.UpdateWorkflow(ctx, run.GetID(), run.GetRunID(), "update", "") + ts.NoError(err) + err = handler.Get(ctx, nil) + ts.Error(err) + // complete workflow + ts.NoError(ts.client.SignalWorkflow(ctx, run.GetID(), run.GetRunID(), "finish", "finished")) + ts.NoError(run.Get(ctx, nil)) +} + func (ts *IntegrationTestSuite) TestBasicSession() { var expected []string err := ts.executeWorkflow("test-basic-session", ts.workflows.BasicSession, &expected) diff --git a/test/workflow_test.go b/test/workflow_test.go index 0aa37cd96..ea26f1107 100644 --- a/test/workflow_test.go +++ b/test/workflow_test.go @@ -323,6 +323,43 @@ func (w *Workflows) UpdateInfoWorkflow(ctx workflow.Context) error { return nil } +func (w *Workflows) UpdateWithValidatorWorkflow(ctx workflow.Context) error { + workflow.Go(ctx, func(ctx workflow.Context) { + _ = workflow.Sleep(ctx, time.Minute) + }) + err := workflow.SetUpdateHandlerWithOptions(ctx, "update", func(ctx workflow.Context, id string) (string, error) { + ctx = workflow.WithActivityOptions(ctx, w.defaultActivityOptions()) + var activities *Activities + activityFut := workflow.ExecuteActivity(ctx, activities.Echo, 0, 0) + err := activityFut.Get(ctx, nil) + if err != nil { + return "", err + } + return id, nil + }, workflow.UpdateHandlerOptions{ + Validator: func(ctx workflow.Context, id string) error { + if id != "testID" { + return errors.New("invalid ID") + } + return nil + }, + }) + if err != nil { + return errors.New("failed to register update handler") + } + + ctx = workflow.WithActivityOptions(ctx, w.defaultActivityOptions()) + var activities *Activities + activityFut := workflow.ExecuteActivity(ctx, activities.Sleep, time.Second) + err = activityFut.Get(ctx, nil) + if err != nil { + return err + } + + workflow.GetSignalChannel(ctx, "finish").Receive(ctx, nil) + return nil +} + func (w *Workflows) ActivityHeartbeatWithRetry(ctx workflow.Context) (heartbeatCounts int, err error) { // Make retries fast opts := w.defaultActivityOptions() @@ -2556,6 +2593,7 @@ func (w *Workflows) register(worker worker.Worker) { worker.RegisterWorkflow(w.SessionFailedStateWorkflow) worker.RegisterWorkflow(w.VersionLoopWorkflow) worker.RegisterWorkflow(w.RaceOnCacheEviction) + worker.RegisterWorkflow(w.UpdateWithValidatorWorkflow) worker.RegisterWorkflow(w.child) worker.RegisterWorkflow(w.childForMemoAndSearchAttr)