diff --git a/internal/internal_task_handlers.go b/internal/internal_task_handlers.go index 70b718f84..4d62620fd 100644 --- a/internal/internal_task_handlers.go +++ b/internal/internal_task_handlers.go @@ -115,8 +115,10 @@ type ( isWorkflowCompleted bool result *commonpb.Payloads err error - + // previousStartedEventID is the event ID of the workflow task started event of the previous workflow task. previousStartedEventID int64 + // lastHandledEventID is the event ID of the last event that the workflow state machine processed. + lastHandledEventID int64 newCommands []*commandpb.Command newMessages []*protocolpb.Message @@ -170,18 +172,19 @@ type ( // history wrapper method to help information about events. history struct { - workflowTask *workflowTask - eventsHandler *workflowExecutionEventHandlerImpl - loadedEvents []*historypb.HistoryEvent - currentIndex int - nextEventID int64 // next expected eventID for sanity - lastEventID int64 // last expected eventID, zero indicates read until end of stream - next []*historypb.HistoryEvent - nextMessages []*protocolpb.Message - nextFlags []sdkFlag - binaryChecksum string - sdkVersion string - sdkName string + workflowTask *workflowTask + eventsHandler *workflowExecutionEventHandlerImpl + loadedEvents []*historypb.HistoryEvent + currentIndex int + nextEventID int64 // next expected eventID for sanity + lastEventID int64 // last expected eventID, zero indicates read until end of stream + lastHandledEventID int64 // last event ID that was processed + next []*historypb.HistoryEvent + nextMessages []*protocolpb.Message + nextFlags []sdkFlag + binaryChecksum string + sdkVersion string + sdkName string } workflowTaskHeartbeatError struct { @@ -219,13 +222,14 @@ type ( } ) -func newHistory(task *workflowTask, eventsHandler *workflowExecutionEventHandlerImpl) *history { +func newHistory(lastHandledEventID int64, task *workflowTask, eventsHandler *workflowExecutionEventHandlerImpl) *history { result := &history{ - workflowTask: task, - eventsHandler: eventsHandler, - loadedEvents: task.task.History.Events, - currentIndex: 0, - lastEventID: task.task.GetStartedEventId(), + workflowTask: task, + eventsHandler: eventsHandler, + loadedEvents: task.task.History.Events, + currentIndex: 0, + lastEventID: task.task.GetStartedEventId(), + lastHandledEventID: lastHandledEventID, } if len(result.loadedEvents) > 0 { result.nextEventID = result.loadedEvents[0].GetEventId() @@ -454,6 +458,11 @@ OrderEvents: } eh.nextEventID++ + if eventID <= eh.lastHandledEventID { + eh.currentIndex++ + continue + } + eh.lastHandledEventID = eventID switch event.GetEventType() { case enumspb.EVENT_TYPE_WORKFLOW_TASK_STARTED: @@ -583,16 +592,15 @@ func (w *workflowExecutionContextImpl) Unlock(err error) { defer w.mutex.Unlock() if err != nil || w.err != nil || w.isWorkflowCompleted || (w.wth.cache.MaxWorkflowCacheSize() <= 0 && !w.hasPendingLocalActivityWork()) { - // TODO: in case of closed, it asumes the close command always succeed. need server side change to return + // TODO: in case of closed, it assumes the close command always succeed. need server side change to return // error to indicate the close failure case. This should be rare case. For now, always remove the cache, and // if the close command failed, the next command will have to rebuild the state. if w.wth.cache.getWorkflowCache().Exist(w.workflowInfo.WorkflowExecution.RunID) { w.wth.cache.removeWorkflowContext(w.workflowInfo.WorkflowExecution.RunID) w.cached = false - } else { - // sticky is disabled, manually clear the workflow state. - w.clearState() } + // Clear the state so other tasks waiting on the context know it should be discarded. + w.clearState() } else if !w.cached { // Clear the state if we never cached the workflow so coroutines can be // exited @@ -638,6 +646,7 @@ func (w *workflowExecutionContextImpl) clearState() { w.result = nil w.err = nil w.previousStartedEventID = 0 + w.lastHandledEventID = 0 w.newCommands = nil w.newMessages = nil @@ -755,10 +764,10 @@ func (wth *workflowTaskHandlerImpl) GetOrCreateWorkflowContext( // Verify the cached state is current and for the correct worker if workflowContext != nil { workflowContext.Lock() - if task.Query != nil && !isFullHistory && wth == workflowContext.wth { + if task.Query != nil && !isFullHistory && wth == workflowContext.wth && !workflowContext.IsDestroyed() { // query task and we have a valid cached state metricsHandler.Counter(metrics.StickyCacheHit).Inc(1) - } else if history.Events[0].GetEventId() == workflowContext.previousStartedEventID+1 && wth == workflowContext.wth { + } else if history.Events[0].GetEventId() == workflowContext.previousStartedEventID+1 && wth == workflowContext.wth && !workflowContext.IsDestroyed() { // non query task and we have a valid cached state metricsHandler.Counter(metrics.StickyCacheHit).Inc(1) } else { @@ -989,7 +998,14 @@ func (w *workflowExecutionContextImpl) ProcessWorkflowTask(workflowTask *workflo w.SetCurrentTask(task) eventHandler := w.getEventHandler() - reorderedHistory := newHistory(workflowTask, eventHandler) + reorderedHistory := newHistory(w.lastHandledEventID, workflowTask, eventHandler) + defer func() { + // After processing the workflow task, update the last handled event ID + // to the last event ID in the history. We do this regardless of whether the workflow task + // was successfully processed or not. This is because a failed workflow task will cause the + // cache to be evicted and the next workflow task will start from the beginning of the history. + w.lastHandledEventID = reorderedHistory.lastHandledEventID + }() var replayOutbox []outboxEntry var replayCommands []*commandpb.Command var respondEvents []*historypb.HistoryEvent @@ -1400,6 +1416,12 @@ func (w *workflowExecutionContextImpl) SetCurrentTask(task *workflowservice.Poll } func (w *workflowExecutionContextImpl) SetPreviousStartedEventID(eventID int64) { + // We must reset the last event we handled to be after the last WFT we really completed + // + any command events (since the SDK "processed" those when it emitted the commands). This + // is also equal to what we just processed in the speculative task, minus two, since we + // would've just handled the most recent WFT started event, and we need to drop that & the + // schedule event just before it. + w.lastHandledEventID = w.lastHandledEventID - 2 w.previousStartedEventID = eventID } diff --git a/internal/internal_task_handlers_interfaces_test.go b/internal/internal_task_handlers_interfaces_test.go index d020bf99a..78e15ccc8 100644 --- a/internal/internal_task_handlers_interfaces_test.go +++ b/internal/internal_task_handlers_interfaces_test.go @@ -187,7 +187,7 @@ func (s *PollLayerInterfacesTestSuite) TestGetNextCommands() { workflowTask := &workflowTask{task: task, historyIterator: historyIterator} - eh := newHistory(workflowTask, nil) + eh := newHistory(0, workflowTask, nil) nextTask, err := eh.nextTask() @@ -232,7 +232,7 @@ func (s *PollLayerInterfacesTestSuite) TestGetNextCommandsSdkFlags() { workflowTask := &workflowTask{task: task, historyIterator: historyIterator} - eh := newHistory(workflowTask, nil) + eh := newHistory(0, workflowTask, nil) nextTask, err := eh.nextTask() @@ -301,7 +301,7 @@ func (s *PollLayerInterfacesTestSuite) TestMessageCommands() { workflowTask := &workflowTask{task: task, historyIterator: historyIterator} - eh := newHistory(workflowTask, nil) + eh := newHistory(0, workflowTask, nil) nextTask, err := eh.nextTask() s.NoError(err) @@ -370,7 +370,7 @@ func (s *PollLayerInterfacesTestSuite) TestEmptyPages() { } workflowTask := &workflowTask{task: task, historyIterator: historyIterator} - eh := newHistory(workflowTask, nil) + eh := newHistory(0, workflowTask, nil) type result struct { events []*historypb.HistoryEvent diff --git a/internal/internal_task_handlers_test.go b/internal/internal_task_handlers_test.go index 518e27193..9cb5d605e 100644 --- a/internal/internal_task_handlers_test.go +++ b/internal/internal_task_handlers_test.go @@ -426,6 +426,15 @@ func createTestUpsertWorkflowSearchAttributesForChangeVersion(eventID int64, wor } } +func createTestProtocolMessageUpdateRequest(ID string, eventID int64, request *updatepb.Request) *protocolpb.Message { + return &protocolpb.Message{ + Id: uuid.New(), + ProtocolInstanceId: ID, + SequencingId: &protocolpb.Message_EventId{EventId: eventID}, + Body: protocol.MustMarshalAny(request), + } +} + func createWorkflowTask( events []*historypb.HistoryEvent, previousStartEventID int64, diff --git a/internal/internal_task_pollers_test.go b/internal/internal_task_pollers_test.go index 26692a133..d34a7e4f9 100644 --- a/internal/internal_task_pollers_test.go +++ b/internal/internal_task_pollers_test.go @@ -30,15 +30,19 @@ import ( "errors" "sync/atomic" "testing" + "time" "github.com/golang/mock/gomock" "github.com/stretchr/testify/require" commonpb "go.temporal.io/api/common/v1" historypb "go.temporal.io/api/history/v1" + protocolpb "go.temporal.io/api/protocol/v1" taskqueuepb "go.temporal.io/api/taskqueue/v1" + "go.temporal.io/api/update/v1" "go.temporal.io/api/workflowservice/v1" "go.temporal.io/api/workflowservicemock/v1" "google.golang.org/grpc" + "google.golang.org/protobuf/types/known/durationpb" ) type countingTaskHandler struct { @@ -222,3 +226,154 @@ func TestWFTCorruption(t *testing.T) { // Workflow should not be in cache require.Nil(t, cache.getWorkflowContext(runID)) } + +func TestWFTReset(t *testing.T) { + cache := NewWorkerCache() + params := workerExecutionParameters{ + cache: cache, + } + ensureRequiredParams(¶ms) + wfType := commonpb.WorkflowType{Name: t.Name() + "-workflow-type"} + reg := newRegistry() + reg.RegisterWorkflowWithOptions(func(ctx Context) error { + _ = SetUpdateHandler(ctx, "update", func(ctx Context) error { + return nil + }, UpdateHandlerOptions{ + Validator: func(ctx Context) error { + return errors.New("rejecting for test") + }, + }) + _ = Sleep(ctx, time.Second) + return Sleep(ctx, time.Second) + }, RegisterWorkflowOptions{ + Name: wfType.Name, + }) + var ( + taskQueue = taskqueuepb.TaskQueue{Name: t.Name() + "task-queue"} + history0 = historypb.History{Events: []*historypb.HistoryEvent{ + createTestEventWorkflowExecutionStarted(1, &historypb.WorkflowExecutionStartedEventAttributes{ + TaskQueue: &taskQueue, + }), + createTestEventWorkflowTaskScheduled(2, &historypb.WorkflowTaskScheduledEventAttributes{ + TaskQueue: &taskQueue, + StartToCloseTimeout: &durationpb.Duration{Seconds: 10}, + Attempt: 1, + }), + createTestEventWorkflowTaskStarted(3), + createTestEventWorkflowTaskCompleted(4, &historypb.WorkflowTaskCompletedEventAttributes{ + ScheduledEventId: 2, + StartedEventId: 3, + }), + createTestEventTimerStarted(5, 5), + createTestEventWorkflowTaskScheduled(6, &historypb.WorkflowTaskScheduledEventAttributes{ + TaskQueue: &taskQueue, + StartToCloseTimeout: &durationpb.Duration{Seconds: 10}, + Attempt: 1, + }), + createTestEventWorkflowTaskStarted(7), + }} + messages = []*protocolpb.Message{ + createTestProtocolMessageUpdateRequest("test-update", 6, &update.Request{ + Meta: &update.Meta{ + UpdateId: "test-update", + }, + Input: &update.Input{ + Name: "update", + }, + }), + } + history1 = historypb.History{Events: []*historypb.HistoryEvent{ + createTestEventWorkflowTaskCompleted(4, &historypb.WorkflowTaskCompletedEventAttributes{ + ScheduledEventId: 2, + StartedEventId: 3, + }), + createTestEventTimerStarted(5, 5), + createTestEventWorkflowTaskScheduled(6, &historypb.WorkflowTaskScheduledEventAttributes{ + TaskQueue: &taskQueue, + StartToCloseTimeout: &durationpb.Duration{Seconds: 10}, + Attempt: 1, + }), + createTestEventWorkflowTaskStarted(7), + }} + history2 = historypb.History{Events: []*historypb.HistoryEvent{ + createTestEventWorkflowTaskCompleted(4, &historypb.WorkflowTaskCompletedEventAttributes{ + ScheduledEventId: 2, + StartedEventId: 3, + }), + createTestEventTimerStarted(5, 5), + createTestEventTimerFired(6, 5), + createTestEventWorkflowTaskScheduled(7, &historypb.WorkflowTaskScheduledEventAttributes{ + TaskQueue: &taskQueue, + StartToCloseTimeout: &durationpb.Duration{Seconds: 10}, + Attempt: 1, + }), + createTestEventWorkflowTaskStarted(8), + }} + runID = t.Name() + "-run-id" + wfID = t.Name() + "-workflow-id" + wfe = commonpb.WorkflowExecution{RunId: runID, WorkflowId: wfID} + ctrl = gomock.NewController(t) + client = workflowservicemock.NewMockWorkflowServiceClient(ctrl) + innerTaskHandler = newWorkflowTaskHandler(params, nil, reg) + taskHandler = &countingTaskHandler{WorkflowTaskHandler: innerTaskHandler} + contextManager = taskHandler + pollResp0 = workflowservice.PollWorkflowTaskQueueResponse{ + Attempt: 1, + WorkflowExecution: &wfe, + WorkflowType: &wfType, + History: &history0, + Messages: messages, + PreviousStartedEventId: 3, + } + task0 = workflowTask{task: &pollResp0} + pollResp1 = workflowservice.PollWorkflowTaskQueueResponse{ + Attempt: 1, + WorkflowExecution: &wfe, + WorkflowType: &wfType, + History: &history1, + PreviousStartedEventId: 3, + } + task1 = workflowTask{task: &pollResp1} + pollResp2 = workflowservice.PollWorkflowTaskQueueResponse{ + Attempt: 1, + WorkflowExecution: &wfe, + WorkflowType: &wfType, + History: &history2, + PreviousStartedEventId: 3, + } + task2 = workflowTask{task: &pollResp2} + ) + + // Return a workflow task to reset the workflow to a previous state + client.EXPECT().RespondWorkflowTaskCompleted(gomock.Any(), gomock.Any()). + Return(&workflowservice.RespondWorkflowTaskCompletedResponse{ + ResetHistoryEventId: 3, + }, nil).Times(3) + // Return a workflow task to complete the workflow + client.EXPECT().RespondWorkflowTaskCompleted(gomock.Any(), gomock.Any()). + Return(&workflowservice.RespondWorkflowTaskCompletedResponse{}, nil) + + poller := newWorkflowTaskPoller(taskHandler, contextManager, client, params) + // Send a full history as part of the speculative WFT + require.NoError(t, poller.processWorkflowTask(&task0)) + originalCachedExecution := cache.getWorkflowContext(runID) + require.NotNil(t, originalCachedExecution) + require.Equal(t, int64(3), originalCachedExecution.previousStartedEventID) + require.Equal(t, int64(5), originalCachedExecution.lastHandledEventID) + // Send some fake speculative WFTs to ensure the workflow is reset properly + require.NoError(t, poller.processWorkflowTask(&task1)) + cachedExecution := cache.getWorkflowContext(runID) + require.True(t, originalCachedExecution == cachedExecution) + require.Equal(t, int64(3), cachedExecution.previousStartedEventID) + require.Equal(t, int64(5), cachedExecution.lastHandledEventID) + require.NoError(t, poller.processWorkflowTask(&task1)) + cachedExecution = cache.getWorkflowContext(runID) + // Check the cached execution is the same as the original + require.True(t, originalCachedExecution == cachedExecution) + require.Equal(t, int64(3), cachedExecution.previousStartedEventID) + require.Equal(t, int64(5), cachedExecution.lastHandledEventID) + // Send a real WFT with new events + require.NoError(t, poller.processWorkflowTask(&task2)) + cachedExecution = cache.getWorkflowContext(runID) + require.True(t, originalCachedExecution == cachedExecution) +}