diff --git a/.gitignore b/.gitignore index ff8a9b166d..9ccd65b879 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ vendor bin .DS_Store +_test diff --git a/pkg/controller/nodes/dynamic/handler.go b/pkg/controller/nodes/dynamic/handler.go index e5f3762816..4da147c3bf 100644 --- a/pkg/controller/nodes/dynamic/handler.go +++ b/pkg/controller/nodes/dynamic/handler.go @@ -191,11 +191,16 @@ func (d dynamicNodeTaskNodeHandler) Abort(ctx context.Context, nCtx handler.Node // This is a weird method. We should always finalize before we set the dynamic parent node phase as complete? func (d dynamicNodeTaskNodeHandler) Finalize(ctx context.Context, nCtx handler.NodeExecutionContext) error { + // We should always finalize the parent node success of failure. + // If we use the state to decide the finalize then we will never invoke the finalizer for the parent. + logger.Infof(ctx, "Finalizing Parent node") + if err := d.TaskNodeHandler.Finalize(ctx, nCtx); err != nil { + logger.Errorf(ctx, "Failed to finalize Dynamic Nodes Parent.") + return err + } + ds := nCtx.NodeStateReader().GetDynamicNodeState() - switch ds.Phase { - case v1alpha1.DynamicNodePhaseFailing: - fallthrough - case v1alpha1.DynamicNodePhaseExecuting: + if ds.Phase == v1alpha1.DynamicNodePhaseFailing || ds.Phase == v1alpha1.DynamicNodePhaseExecuting { logger.Infof(ctx, "Finalizing dynamic workflow") dynamicWF, isDynamic, err := d.buildContextualDynamicWorkflow(ctx, nCtx) if err != nil { @@ -205,12 +210,10 @@ func (d dynamicNodeTaskNodeHandler) Finalize(ctx context.Context, nCtx handler.N if !isDynamic { return nil } - return d.nodeExecutor.FinalizeHandler(ctx, dynamicWF, dynamicWF.StartNode()) - default: - logger.Infof(ctx, "Finalizing regular node") - return d.TaskNodeHandler.Finalize(ctx, nCtx) } + + return nil } func (d dynamicNodeTaskNodeHandler) buildDynamicWorkflowTemplate(ctx context.Context, djSpec *core.DynamicJobSpec, diff --git a/pkg/controller/nodes/dynamic/handler_test.go b/pkg/controller/nodes/dynamic/handler_test.go index b4edf9aaa4..934ca3cb5a 100644 --- a/pkg/controller/nodes/dynamic/handler_test.go +++ b/pkg/controller/nodes/dynamic/handler_test.go @@ -443,6 +443,160 @@ func Test_dynamicNodeHandler_Handle_SubTask(t *testing.T) { } } +func TestDynamicNodeTaskNodeHandler_Finalize(t *testing.T) { + ctx := context.TODO() + + t.Run("dynamicnodephase-none", func(t *testing.T) { + s := handler.DynamicNodeState{ + Phase: v1alpha1.DynamicNodePhaseNone, + Reason: "", + } + nCtx := &nodeMocks.NodeExecutionContext{} + sr := &nodeMocks.NodeStateReader{} + sr.OnGetDynamicNodeState().Return(s) + nCtx.OnNodeStateReader().Return(sr) + + h := &mocks.TaskNodeHandler{} + h.OnFinalize(ctx, nCtx).Return(nil) + n := &executorMocks.Node{} + d := New(h, n, promutils.NewTestScope()) + assert.NoError(t, d.Finalize(ctx, nCtx)) + assert.NotZero(t, len(h.ExpectedCalls)) + assert.Equal(t, "Finalize", h.ExpectedCalls[0].Method) + }) + + createNodeContext := func(ttype string, finalOutput storage.DataReference) *nodeMocks.NodeExecutionContext { + ctx := context.TODO() + + wfExecID := &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", + } + + nm := &nodeMocks.NodeExecutionMetadata{} + nm.On("GetAnnotations").Return(map[string]string{}) + nm.On("GetExecutionID").Return(v1alpha1.WorkflowExecutionIdentifier{ + WorkflowExecutionIdentifier: wfExecID, + }) + nm.On("GetK8sServiceAccount").Return("service-account") + nm.On("GetLabels").Return(map[string]string{}) + nm.On("GetNamespace").Return("namespace") + nm.On("GetOwnerID").Return(types.NamespacedName{Namespace: "namespace", Name: "name"}) + nm.On("GetOwnerReference").Return(v1.OwnerReference{ + Kind: "sample", + Name: "name", + }) + + taskID := &core.Identifier{} + tk := &core.TaskTemplate{ + Id: taskID, + Type: "test", + Metadata: &core.TaskMetadata{ + Discoverable: true, + }, + Interface: &core.TypedInterface{ + Outputs: &core.VariableMap{ + Variables: map[string]*core.Variable{ + "x": { + Type: &core.LiteralType{ + Type: &core.LiteralType_Simple{ + Simple: core.SimpleType_INTEGER, + }, + }, + }, + }, + }, + }, + } + tr := &nodeMocks.TaskReader{} + tr.On("GetTaskID").Return(taskID) + tr.On("GetTaskType").Return(ttype) + tr.On("Read", mock.Anything).Return(tk, nil) + + n := &flyteMocks.ExecutableNode{} + tID := "task-1" + n.On("GetTaskID").Return(&tID) + + dataStore, err := storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope()) + assert.NoError(t, err) + + ir := &ioMocks.InputReader{} + nCtx := &nodeMocks.NodeExecutionContext{} + nCtx.On("NodeExecutionMetadata").Return(nm) + nCtx.On("Node").Return(n) + nCtx.On("InputReader").Return(ir) + nCtx.On("DataReferenceConstructor").Return(storage.NewDataStore(&storage.Config{Type: storage.TypeMemory}, promutils.NewTestScope())) + nCtx.On("CurrentAttempt").Return(uint32(1)) + nCtx.On("TaskReader").Return(tr) + nCtx.On("MaxDatasetSizeBytes").Return(int64(1)) + nCtx.On("NodeID").Return("n1") + nCtx.On("EnqueueOwnerFunc").Return(func() error { return nil }) + nCtx.OnDataStore().Return(dataStore) + + endNodeStatus := &flyteMocks.ExecutableNodeStatus{} + endNodeStatus.On("GetDataDir").Return(storage.DataReference("end-node")) + endNodeStatus.On("GetOutputDir").Return(storage.DataReference("end-node")) + + subNs := &flyteMocks.ExecutableNodeStatus{} + subNs.On("SetDataDir", mock.Anything).Return() + subNs.On("SetOutputDir", mock.Anything).Return() + subNs.On("ResetDirty").Return() + subNs.On("GetOutputDir").Return(finalOutput) + subNs.On("SetParentTaskID", mock.Anything).Return() + subNs.OnGetAttempts().Return(0) + + dynamicNS := &flyteMocks.ExecutableNodeStatus{} + dynamicNS.On("SetDataDir", mock.Anything).Return() + dynamicNS.On("SetOutputDir", mock.Anything).Return() + dynamicNS.On("SetParentTaskID", mock.Anything).Return() + dynamicNS.OnGetNodeExecutionStatus(ctx, "n1-1-Node_1").Return(subNs) + dynamicNS.OnGetNodeExecutionStatus(ctx, "n1-1-Node_2").Return(subNs) + dynamicNS.OnGetNodeExecutionStatus(ctx, "n1-1-Node_3").Return(subNs) + dynamicNS.OnGetNodeExecutionStatus(ctx, v1alpha1.EndNodeID).Return(endNodeStatus) + + ns := &flyteMocks.ExecutableNodeStatus{} + ns.On("GetDataDir").Return(storage.DataReference("data-dir")) + ns.On("GetOutputDir").Return(storage.DataReference("output-dir")) + ns.On("GetNodeExecutionStatus", dynamicNodeID).Return(dynamicNS) + ns.OnGetNodeExecutionStatus(ctx, dynamicNodeID).Return(dynamicNS) + nCtx.On("NodeStatus").Return(ns) + + w := &flyteMocks.ExecutableWorkflow{} + ws := &flyteMocks.ExecutableWorkflowStatus{} + ws.OnGetNodeExecutionStatus(ctx, "n1").Return(ns) + w.On("GetExecutionStatus").Return(ws) + nCtx.On("Workflow").Return(w) + + r := &nodeMocks.NodeStateReader{} + r.On("GetDynamicNodeState").Return(handler.DynamicNodeState{ + Phase: v1alpha1.DynamicNodePhaseExecuting, + }) + nCtx.On("NodeStateReader").Return(r) + return nCtx + } + + t.Run("dynamicnodephase-executing", func(t *testing.T) { + + nCtx := createNodeContext("test", "x") + f, err := nCtx.DataStore().ConstructReference(context.TODO(), nCtx.NodeStatus().GetOutputDir(), "futures.pb") + assert.NoError(t, err) + dj := createDynamicJobSpec() + assert.NoError(t, nCtx.DataStore().WriteProtobuf(context.TODO(), f, storage.Options{}, dj)) + + h := &mocks.TaskNodeHandler{} + h.OnFinalize(ctx, nCtx).Return(nil) + n := &executorMocks.Node{} + n.OnFinalizeHandlerMatch(ctx, mock.Anything, mock.Anything).Return(nil) + d := New(h, n, promutils.NewTestScope()) + assert.NoError(t, d.Finalize(ctx, nCtx)) + assert.NotZero(t, len(h.ExpectedCalls)) + assert.Equal(t, "Finalize", h.ExpectedCalls[0].Method) + assert.NotZero(t, len(n.ExpectedCalls)) + assert.Equal(t, "FinalizeHandler", n.ExpectedCalls[0].Method) + }) +} + func init() { labeled.SetMetricKeys(contextutils.ProjectKey, contextutils.DomainKey, contextutils.WorkflowIDKey, contextutils.TaskIDKey) diff --git a/pkg/controller/nodes/task/handler.go b/pkg/controller/nodes/task/handler.go index 2f07e54f7c..010afe9ca7 100644 --- a/pkg/controller/nodes/task/handler.go +++ b/pkg/controller/nodes/task/handler.go @@ -573,7 +573,7 @@ func (t Handler) Finalize(ctx context.Context, nCtx handler.NodeExecutionContext if r := recover(); r != nil { t.metrics.pluginPanics.Inc(ctx) stack := debug.Stack() - logger.Errorf(ctx, "Panic in plugin.Abort for TaskType [%s]", tCtx.tr.GetTaskType()) + logger.Errorf(ctx, "Panic in plugin.Finalize for TaskType [%s]", tCtx.tr.GetTaskType()) err = fmt.Errorf("panic when executing a plugin for TaskType [%s]. Stack: [%s]", tCtx.tr.GetTaskType(), string(stack)) } }()