diff --git a/pkg/manager/impl/node_execution_manager_test.go b/pkg/manager/impl/node_execution_manager_test.go index c6b4edb27..a1c43c36b 100644 --- a/pkg/manager/impl/node_execution_manager_test.go +++ b/pkg/manager/impl/node_execution_manager_test.go @@ -136,6 +136,9 @@ func TestCreateNodeEvent(t *testing.T) { StartedAt: occurredAtProto, CreatedAt: occurredAtProto, UpdatedAt: occurredAtProto, + TargetMetadata: &admin.NodeExecutionClosure_TaskNodeMetadata{ + TaskNodeMetadata: &admin.TaskNodeMetadata{}, + }, } closureBytes, _ := proto.Marshal(&expectedClosure) repository.NodeExecutionRepo().(*repositoryMocks.MockNodeExecutionRepo).SetCreateCallback( diff --git a/pkg/repositories/transformers/node_execution.go b/pkg/repositories/transformers/node_execution.go index a427412e3..f1d90361f 100644 --- a/pkg/repositories/transformers/node_execution.go +++ b/pkg/repositories/transformers/node_execution.go @@ -143,12 +143,30 @@ func CreateNodeExecutionModel(ctx context.Context, input ToNodeExecutionModelInp return nil, err } } + if common.IsNodeExecutionTerminal(input.Request.Event.Phase) { err := addTerminalState(ctx, input.Request, nodeExecution, &closure, input.InlineEventDataPolicy, input.StorageClient) if err != nil { return nil, err } } + + // Update TaskNodeMetadata, which includes caching information today. + if input.Request.Event.GetTaskNodeMetadata() != nil { + targetMetadata := &admin.NodeExecutionClosure_TaskNodeMetadata{ + TaskNodeMetadata: &admin.TaskNodeMetadata{ + CheckpointUri: input.Request.Event.GetTaskNodeMetadata().CheckpointUri, + }, + } + if input.Request.Event.GetTaskNodeMetadata().CatalogKey != nil { + st := input.Request.Event.GetTaskNodeMetadata().GetCacheStatus().String() + targetMetadata.TaskNodeMetadata.CacheStatus = input.Request.Event.GetTaskNodeMetadata().GetCacheStatus() + targetMetadata.TaskNodeMetadata.CatalogKey = input.Request.Event.GetTaskNodeMetadata().GetCatalogKey() + nodeExecution.CacheStatus = &st + } + closure.TargetMetadata = targetMetadata + } + marshaledClosure, err := proto.Marshal(&closure) if err != nil { return nil, errors.NewFlyteAdminErrorf( diff --git a/pkg/repositories/transformers/node_execution_test.go b/pkg/repositories/transformers/node_execution_test.go index 268e63c97..88ef8fc26 100644 --- a/pkg/repositories/transformers/node_execution_test.go +++ b/pkg/repositories/transformers/node_execution_test.go @@ -199,35 +199,51 @@ func TestAddTerminalState_Error(t *testing.T) { func TestCreateNodeExecutionModel(t *testing.T) { parentTaskExecID := uint(8) - nodeExecutionModel, err := CreateNodeExecutionModel(context.TODO(), ToNodeExecutionModelInput{ - Request: &admin.NodeExecutionEventRequest{ - Event: &event.NodeExecutionEvent{ - Id: &core.NodeExecutionIdentifier{ - NodeId: "node id", - ExecutionId: &core.WorkflowExecutionIdentifier{ - Project: "project", - Domain: "domain", - Name: "name", - }, - }, - Phase: core.NodeExecution_RUNNING, - InputValue: &event.NodeExecutionEvent_InputUri{ - InputUri: testInputURI, - }, - OutputResult: &event.NodeExecutionEvent_OutputUri{ - OutputUri: "output uri", + request := &admin.NodeExecutionEventRequest{ + Event: &event.NodeExecutionEvent{ + Id: &core.NodeExecutionIdentifier{ + NodeId: "node id", + ExecutionId: &core.WorkflowExecutionIdentifier{ + Project: "project", + Domain: "domain", + Name: "name", }, - OccurredAt: occurredAtProto, - ParentTaskMetadata: &event.ParentTaskExecutionMetadata{ - Id: &core.TaskExecutionIdentifier{ - RetryAttempt: 1, + }, + Phase: core.NodeExecution_RUNNING, + InputValue: &event.NodeExecutionEvent_InputUri{ + InputUri: testInputURI, + }, + OutputResult: &event.NodeExecutionEvent_OutputUri{ + OutputUri: "output uri", + }, + OccurredAt: occurredAtProto, + TargetMetadata: &event.NodeExecutionEvent_TaskNodeMetadata{ + TaskNodeMetadata: &event.TaskNodeMetadata{ + CacheStatus: core.CatalogCacheStatus_CACHE_POPULATED, + CatalogKey: &core.CatalogMetadata{ + DatasetId: &core.Identifier{ + ResourceType: core.ResourceType_DATASET, + Name: "x", + Project: "proj", + Domain: "domain", + }, }, + CheckpointUri: "last checkpoint uri", }, - IsParent: true, - IsDynamic: true, - EventVersion: 2, }, + ParentTaskMetadata: &event.ParentTaskExecutionMetadata{ + Id: &core.TaskExecutionIdentifier{ + RetryAttempt: 1, + }, + }, + IsParent: true, + IsDynamic: true, + EventVersion: 2, }, + } + + nodeExecutionModel, err := CreateNodeExecutionModel(context.TODO(), ToNodeExecutionModelInput{ + Request: request, ParentTaskExecutionID: &parentTaskExecID, }) assert.Nil(t, err) @@ -237,6 +253,13 @@ func TestCreateNodeExecutionModel(t *testing.T) { StartedAt: occurredAtProto, CreatedAt: occurredAtProto, UpdatedAt: occurredAtProto, + TargetMetadata: &admin.NodeExecutionClosure_TaskNodeMetadata{ + TaskNodeMetadata: &admin.TaskNodeMetadata{ + CacheStatus: request.Event.GetTaskNodeMetadata().CacheStatus, + CatalogKey: request.Event.GetTaskNodeMetadata().CatalogKey, + CheckpointUri: request.Event.GetTaskNodeMetadata().CheckpointUri, + }, + }, } var closureBytes, _ = proto.Marshal(closure) var nodeExecutionMetadata, _ = proto.Marshal(&admin.NodeExecutionMetaData{ @@ -247,6 +270,7 @@ func TestCreateNodeExecutionModel(t *testing.T) { EventVersion: 2, } internalDataBytes, _ := proto.Marshal(internalData) + cacheStatus := request.Event.GetTaskNodeMetadata().CacheStatus.String() assert.Equal(t, &models.NodeExecution{ NodeExecutionKey: models.NodeExecutionKey{ NodeID: "node id", @@ -264,6 +288,7 @@ func TestCreateNodeExecutionModel(t *testing.T) { NodeExecutionUpdatedAt: &occurredAt, NodeExecutionMetadata: nodeExecutionMetadata, ParentTaskExecutionID: &parentTaskExecID, + CacheStatus: &cacheStatus, InternalData: internalDataBytes, }, nodeExecutionModel) }