Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
Adding support for environment variables set on execution (#344)
Browse files Browse the repository at this point in the history
* added environment variables to TaskExecutionMetadata

Signed-off-by: Daniel Rammer <daniel@union.ai>

* added support for environment variables

Signed-off-by: Daniel Rammer <daniel@union.ai>

* implemented unit tests and fixed linter

Signed-off-by: Daniel Rammer <daniel@union.ai>

---------

Signed-off-by: Daniel Rammer <daniel@union.ai>
  • Loading branch information
hamersaw authored May 3, 2023
1 parent 63e1e45 commit 1e0faab
Show file tree
Hide file tree
Showing 21 changed files with 113 additions and 8 deletions.
1 change: 1 addition & 0 deletions go/tasks/pluginmachinery/core/exec_metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,5 @@ type TaskExecutionMetadata interface {
IsInterruptible() bool
GetPlatformResources() *v1.ResourceRequirements
GetInterruptibleFailureThreshold() uint32
GetEnvironmentVariables() map[string]string
}
34 changes: 34 additions & 0 deletions go/tasks/pluginmachinery/core/mocks/task_execution_metadata.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion go/tasks/pluginmachinery/flytek8s/container_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ func AddFlyteCustomizationsToContainer(ctx context.Context, parameters template.
}
container.Args = modifiedArgs

container.Env = DecorateEnvVars(ctx, container.Env, parameters.TaskExecMetadata.GetTaskExecutionID())
container.Env = DecorateEnvVars(ctx, container.Env, parameters.TaskExecMetadata.GetEnvironmentVariables(), parameters.TaskExecMetadata.GetTaskExecutionID())

if parameters.TaskExecMetadata.GetOverrides() != nil && parameters.TaskExecMetadata.GetOverrides().GetResources() != nil {
res := parameters.TaskExecMetadata.GetOverrides().GetResources()
Expand Down
8 changes: 8 additions & 0 deletions go/tasks/pluginmachinery/flytek8s/container_helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,9 @@ func TestToK8sContainer(t *testing.T) {
mockTaskExecutionID.OnGetGeneratedName().Return("gen_name")
mockTaskExecMetadata.OnGetTaskExecutionID().Return(&mockTaskExecutionID)
mockTaskExecMetadata.OnGetPlatformResources().Return(&v1.ResourceRequirements{})
mockTaskExecMetadata.OnGetEnvironmentVariables().Return(map[string]string{
"foo": "bar",
})

tCtx := &mocks.TaskExecutionContext{}
tCtx.OnTaskExecutionMetadata().Return(&mockTaskExecMetadata)
Expand Down Expand Up @@ -419,6 +422,10 @@ func TestToK8sContainer(t *testing.T) {
Name: "k",
Value: "v",
},
{
Name: "foo",
Value: "bar",
},
}, container.Env)
errs := validation.IsDNS1123Label(container.Name)
assert.Nil(t, errs)
Expand Down Expand Up @@ -454,6 +461,7 @@ func getTemplateParametersForTest(resourceRequirements, platformResources *v1.Re
mockOverrides.OnGetResources().Return(resourceRequirements)
mockTaskExecMetadata.OnGetOverrides().Return(&mockOverrides)
mockTaskExecMetadata.OnGetPlatformResources().Return(platformResources)
mockTaskExecMetadata.OnGetEnvironmentVariables().Return(nil)

mockInputReader := mocks2.InputReader{}
mockInputPath := storage.DataReference("s3://input/path")
Expand Down
5 changes: 4 additions & 1 deletion go/tasks/pluginmachinery/flytek8s/k8s_resource_adds.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,13 @@ func GetExecutionEnvVars(id pluginsCore.TaskExecutionID) []v1.EnvVar {
return envVars
}

func DecorateEnvVars(ctx context.Context, envVars []v1.EnvVar, id pluginsCore.TaskExecutionID) []v1.EnvVar {
func DecorateEnvVars(ctx context.Context, envVars []v1.EnvVar, taskEnvironmentVariables map[string]string, id pluginsCore.TaskExecutionID) []v1.EnvVar {
envVars = append(envVars, GetContextEnvVars(ctx)...)
envVars = append(envVars, GetExecutionEnvVars(id)...)

for k, v := range taskEnvironmentVariables {
envVars = append(envVars, v1.EnvVar{Name: k, Value: v})
}
for k, v := range config.GetK8sPluginConfig().DefaultEnvVars {
envVars = append(envVars, v1.EnvVar{Name: k, Value: v})
}
Expand Down
10 changes: 6 additions & 4 deletions go/tasks/pluginmachinery/flytek8s/k8s_resource_adds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -266,19 +266,21 @@ func TestDecorateEnvVars(t *testing.T) {
args args
additionEnvVar map[string]string
additionEnvVarFromEnv map[string]string
executionEnvVar map[string]string
want []v12.EnvVar
}{
{"no-additional", args{envVars: defaultEnv, id: mockTaskExecutionIdentifier{}}, emptyEnvVar, emptyEnvVar, expected},
{"with-additional", args{envVars: defaultEnv, id: mockTaskExecutionIdentifier{}}, additionalEnv, emptyEnvVar, aggregated},
{"from-env", args{envVars: defaultEnv, id: mockTaskExecutionIdentifier{}}, emptyEnvVar, envVarsFromEnv, aggregated},
{"no-additional", args{envVars: defaultEnv, id: mockTaskExecutionIdentifier{}}, emptyEnvVar, emptyEnvVar, emptyEnvVar, expected},
{"with-additional", args{envVars: defaultEnv, id: mockTaskExecutionIdentifier{}}, additionalEnv, emptyEnvVar, emptyEnvVar, aggregated},
{"from-env", args{envVars: defaultEnv, id: mockTaskExecutionIdentifier{}}, emptyEnvVar, envVarsFromEnv, emptyEnvVar, aggregated},
{"from-execution-metadata", args{envVars: defaultEnv, id: mockTaskExecutionIdentifier{}}, emptyEnvVar, emptyEnvVar, additionalEnv, aggregated},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.NoError(t, config.SetK8sPluginConfig(&config.K8sPluginConfig{
DefaultEnvVars: tt.additionEnvVar,
DefaultEnvVarsFromEnv: tt.additionEnvVarFromEnv,
}))
if got := DecorateEnvVars(ctx, tt.args.envVars, tt.args.id); !reflect.DeepEqual(got, tt.want) {
if got := DecorateEnvVars(ctx, tt.args.envVars, tt.executionEnvVar, tt.args.id); !reflect.DeepEqual(got, tt.want) {
t.Errorf("DecorateEnvVars() = %v, want %v", got, tt.want)
}
})
Expand Down
1 change: 1 addition & 0 deletions go/tasks/pluginmachinery/flytek8s/pod_helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ func dummyTaskExecutionMetadata(resources *v1.ResourceRequirements) pluginsCore.
taskExecutionMetadata.On("GetOverrides").Return(to)
taskExecutionMetadata.On("IsInterruptible").Return(true)
taskExecutionMetadata.OnGetPlatformResources().Return(&v1.ResourceRequirements{})
taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil)
return taskExecutionMetadata
}

Expand Down
2 changes: 1 addition & 1 deletion go/tasks/plugins/array/awsbatch/transformer.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ func UpdateBatchInputForArray(_ context.Context, batchInput *batch.SubmitJobInpu

func getEnvVarsForTask(ctx context.Context, execID pluginCore.TaskExecutionID, containerEnvVars []*core.KeyValuePair,
defaultEnvVars map[string]string) []v1.EnvVar {
envVars := flytek8s.DecorateEnvVars(ctx, flytek8s.ToK8sEnvVar(containerEnvVars), execID)
envVars := flytek8s.DecorateEnvVars(ctx, flytek8s.ToK8sEnvVar(containerEnvVars), nil, execID)
m := make(map[string]string, len(envVars))
for _, envVar := range envVars {
m[envVar.Name] = envVar.Value
Expand Down
1 change: 1 addition & 0 deletions go/tasks/plugins/array/k8s/management_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ func getMockTaskExecutionContext(ctx context.Context, parallelism int) *mocks.Ta
tMeta.OnGetOwnerReference().Return(metav1.OwnerReference{})
tMeta.OnGetPlatformResources().Return(&v1.ResourceRequirements{})
tMeta.OnGetInterruptibleFailureThreshold().Return(2)
tMeta.OnGetEnvironmentVariables().Return(nil)

ow := &mocks2.OutputWriter{}
ow.OnGetOutputPrefixPath().Return("/prefix/")
Expand Down
1 change: 1 addition & 0 deletions go/tasks/plugins/k8s/dask/dask_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ func dummyDaskTaskContext(taskTemplate *core.TaskTemplate, resources *v1.Resourc
taskExecutionMetadata.OnGetPlatformResources().Return(&testPlatformResources)
taskExecutionMetadata.OnGetMaxAttempts().Return(uint32(1))
taskExecutionMetadata.OnIsInterruptible().Return(isInterruptible)
taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil)
overrides := &mocks.TaskOverrides{}
overrides.OnGetResources().Return(resources)
taskExecutionMetadata.OnGetOverrides().Return(overrides)
Expand Down
1 change: 1 addition & 0 deletions go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ func dummyMPITaskContext(taskTemplate *core.TaskTemplate) pluginsCore.TaskExecut
taskExecutionMetadata.OnGetOverrides().Return(resources)
taskExecutionMetadata.OnGetK8sServiceAccount().Return(serviceAccount)
taskExecutionMetadata.OnGetPlatformResources().Return(&corev1.ResourceRequirements{})
taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil)
taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata)
return taskCtx
}
Expand Down
1 change: 1 addition & 0 deletions go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ func dummyPytorchTaskContext(taskTemplate *core.TaskTemplate) pluginsCore.TaskEx
taskExecutionMetadata.OnGetOverrides().Return(resources)
taskExecutionMetadata.OnGetK8sServiceAccount().Return(serviceAccount)
taskExecutionMetadata.OnGetPlatformResources().Return(&corev1.ResourceRequirements{})
taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil)
taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata)
return taskCtx
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ func dummyTensorFlowTaskContext(taskTemplate *core.TaskTemplate) pluginsCore.Tas
taskExecutionMetadata.OnGetOverrides().Return(resources)
taskExecutionMetadata.OnGetK8sServiceAccount().Return(serviceAccount)
taskExecutionMetadata.OnGetPlatformResources().Return(&corev1.ResourceRequirements{})
taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil)
taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata)
return taskCtx
}
Expand Down
1 change: 1 addition & 0 deletions go/tasks/plugins/k8s/pod/container_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ func dummyContainerTaskMetadata(resources *v1.ResourceRequirements) pluginsCore.
to.On("GetResources").Return(resources)
taskMetadata.On("GetOverrides").Return(to)
taskMetadata.On("IsInterruptible").Return(true)
taskMetadata.On("GetEnvironmentVariables").Return(nil)
return taskMetadata
}

Expand Down
1 change: 1 addition & 0 deletions go/tasks/plugins/k8s/pod/sidecar_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ func dummySidecarTaskMetadata(resources *v1.ResourceRequirements) pluginsCore.Ta
to := &pluginsCoreMock.TaskOverrides{}
to.On("GetResources").Return(resources)
taskMetadata.On("GetOverrides").Return(to)
taskMetadata.On("GetEnvironmentVariables").Return(nil)

return taskMetadata
}
Expand Down
3 changes: 3 additions & 0 deletions go/tasks/plugins/k8s/ray/config_flags.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

42 changes: 42 additions & 0 deletions go/tasks/plugins/k8s/ray/config_flags_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions go/tasks/plugins/k8s/ray/ray_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ func dummyRayTaskContext(taskTemplate *core.TaskTemplate) pluginsCore.TaskExecut
taskExecutionMetadata.OnGetSecurityContext().Return(core.SecurityContext{
RunAs: &core.Identity{K8SServiceAccount: serviceAccount},
})
taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil)
taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata)
return taskCtx
}
Expand Down
3 changes: 2 additions & 1 deletion go/tasks/plugins/k8s/spark/spark.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo
labels := utils.UnionMaps(config.GetK8sPluginConfig().DefaultLabels, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels()))
container := taskTemplate.GetContainer()

envVars := flytek8s.DecorateEnvVars(ctx, flytek8s.ToK8sEnvVar(container.GetEnv()), taskCtx.TaskExecutionMetadata().GetTaskExecutionID())
envVars := flytek8s.DecorateEnvVars(ctx, flytek8s.ToK8sEnvVar(container.GetEnv()),
taskCtx.TaskExecutionMetadata().GetEnvironmentVariables(), taskCtx.TaskExecutionMetadata().GetTaskExecutionID())

sparkEnvVars := make(map[string]string)
for _, envVar := range envVars {
Expand Down
1 change: 1 addition & 0 deletions go/tasks/plugins/k8s/spark/spark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ func dummySparkTaskContext(taskTemplate *core.TaskTemplate, interruptible bool)
})
taskExecutionMetadata.On("IsInterruptible").Return(interruptible)
taskExecutionMetadata.On("GetMaxAttempts").Return(uint32(1))
taskExecutionMetadata.On("GetEnvironmentVariables").Return(nil)
taskCtx.On("TaskExecutionMetadata").Return(taskExecutionMetadata)
return taskCtx
}
Expand Down
1 change: 1 addition & 0 deletions tests/end_to_end.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ func RunPluginEndToEndTest(t *testing.T, executor pluginCore.Plugin, template *i
})
tMeta.OnGetPlatformResources().Return(&v1.ResourceRequirements{})
tMeta.OnGetInterruptibleFailureThreshold().Return(2)
tMeta.OnGetEnvironmentVariables().Return(nil)

catClient := &catalogMocks.Client{}
catData := sync.Map{}
Expand Down

0 comments on commit 1e0faab

Please sign in to comment.