diff --git a/go/tasks/pluginmachinery/core/exec_metadata.go b/go/tasks/pluginmachinery/core/exec_metadata.go index 2e39dda14..b3115a7a2 100644 --- a/go/tasks/pluginmachinery/core/exec_metadata.go +++ b/go/tasks/pluginmachinery/core/exec_metadata.go @@ -45,4 +45,5 @@ type TaskExecutionMetadata interface { IsInterruptible() bool GetPlatformResources() *v1.ResourceRequirements GetInterruptibleFailureThreshold() uint32 + GetEnvironmentVariables() map[string]string } diff --git a/go/tasks/pluginmachinery/core/mocks/task_execution_metadata.go b/go/tasks/pluginmachinery/core/mocks/task_execution_metadata.go index 41af602a6..29f0055dc 100644 --- a/go/tasks/pluginmachinery/core/mocks/task_execution_metadata.go +++ b/go/tasks/pluginmachinery/core/mocks/task_execution_metadata.go @@ -54,6 +54,40 @@ func (_m *TaskExecutionMetadata) GetAnnotations() map[string]string { return r0 } +type TaskExecutionMetadata_GetEnvironmentVariables struct { + *mock.Call +} + +func (_m TaskExecutionMetadata_GetEnvironmentVariables) Return(_a0 map[string]string) *TaskExecutionMetadata_GetEnvironmentVariables { + return &TaskExecutionMetadata_GetEnvironmentVariables{Call: _m.Call.Return(_a0)} +} + +func (_m *TaskExecutionMetadata) OnGetEnvironmentVariables() *TaskExecutionMetadata_GetEnvironmentVariables { + c_call := _m.On("GetEnvironmentVariables") + return &TaskExecutionMetadata_GetEnvironmentVariables{Call: c_call} +} + +func (_m *TaskExecutionMetadata) OnGetEnvironmentVariablesMatch(matchers ...interface{}) *TaskExecutionMetadata_GetEnvironmentVariables { + c_call := _m.On("GetEnvironmentVariables", matchers...) + return &TaskExecutionMetadata_GetEnvironmentVariables{Call: c_call} +} + +// GetEnvironmentVariables provides a mock function with given fields: +func (_m *TaskExecutionMetadata) GetEnvironmentVariables() map[string]string { + ret := _m.Called() + + var r0 map[string]string + if rf, ok := ret.Get(0).(func() map[string]string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[string]string) + } + } + + return r0 +} + type TaskExecutionMetadata_GetInterruptibleFailureThreshold struct { *mock.Call } diff --git a/go/tasks/pluginmachinery/flytek8s/container_helper.go b/go/tasks/pluginmachinery/flytek8s/container_helper.go index 193bf2d8a..460d33b71 100755 --- a/go/tasks/pluginmachinery/flytek8s/container_helper.go +++ b/go/tasks/pluginmachinery/flytek8s/container_helper.go @@ -296,7 +296,7 @@ func AddFlyteCustomizationsToContainer(ctx context.Context, parameters template. } container.Args = modifiedArgs - container.Env = DecorateEnvVars(ctx, container.Env, parameters.TaskExecMetadata.GetTaskExecutionID()) + container.Env = DecorateEnvVars(ctx, container.Env, parameters.TaskExecMetadata.GetEnvironmentVariables(), parameters.TaskExecMetadata.GetTaskExecutionID()) if parameters.TaskExecMetadata.GetOverrides() != nil && parameters.TaskExecMetadata.GetOverrides().GetResources() != nil { res := parameters.TaskExecMetadata.GetOverrides().GetResources() diff --git a/go/tasks/pluginmachinery/flytek8s/container_helper_test.go b/go/tasks/pluginmachinery/flytek8s/container_helper_test.go index f0ef5feaf..be575d48c 100755 --- a/go/tasks/pluginmachinery/flytek8s/container_helper_test.go +++ b/go/tasks/pluginmachinery/flytek8s/container_helper_test.go @@ -387,6 +387,9 @@ func TestToK8sContainer(t *testing.T) { mockTaskExecutionID.OnGetGeneratedName().Return("gen_name") mockTaskExecMetadata.OnGetTaskExecutionID().Return(&mockTaskExecutionID) mockTaskExecMetadata.OnGetPlatformResources().Return(&v1.ResourceRequirements{}) + mockTaskExecMetadata.OnGetEnvironmentVariables().Return(map[string]string{ + "foo": "bar", + }) tCtx := &mocks.TaskExecutionContext{} tCtx.OnTaskExecutionMetadata().Return(&mockTaskExecMetadata) @@ -419,6 +422,10 @@ func TestToK8sContainer(t *testing.T) { Name: "k", Value: "v", }, + { + Name: "foo", + Value: "bar", + }, }, container.Env) errs := validation.IsDNS1123Label(container.Name) assert.Nil(t, errs) @@ -454,6 +461,7 @@ func getTemplateParametersForTest(resourceRequirements, platformResources *v1.Re mockOverrides.OnGetResources().Return(resourceRequirements) mockTaskExecMetadata.OnGetOverrides().Return(&mockOverrides) mockTaskExecMetadata.OnGetPlatformResources().Return(platformResources) + mockTaskExecMetadata.OnGetEnvironmentVariables().Return(nil) mockInputReader := mocks2.InputReader{} mockInputPath := storage.DataReference("s3://input/path") diff --git a/go/tasks/pluginmachinery/flytek8s/k8s_resource_adds.go b/go/tasks/pluginmachinery/flytek8s/k8s_resource_adds.go index bb2b9db23..49efef455 100755 --- a/go/tasks/pluginmachinery/flytek8s/k8s_resource_adds.go +++ b/go/tasks/pluginmachinery/flytek8s/k8s_resource_adds.go @@ -115,10 +115,13 @@ func GetExecutionEnvVars(id pluginsCore.TaskExecutionID) []v1.EnvVar { return envVars } -func DecorateEnvVars(ctx context.Context, envVars []v1.EnvVar, id pluginsCore.TaskExecutionID) []v1.EnvVar { +func DecorateEnvVars(ctx context.Context, envVars []v1.EnvVar, taskEnvironmentVariables map[string]string, id pluginsCore.TaskExecutionID) []v1.EnvVar { envVars = append(envVars, GetContextEnvVars(ctx)...) envVars = append(envVars, GetExecutionEnvVars(id)...) + for k, v := range taskEnvironmentVariables { + envVars = append(envVars, v1.EnvVar{Name: k, Value: v}) + } for k, v := range config.GetK8sPluginConfig().DefaultEnvVars { envVars = append(envVars, v1.EnvVar{Name: k, Value: v}) } diff --git a/go/tasks/pluginmachinery/flytek8s/k8s_resource_adds_test.go b/go/tasks/pluginmachinery/flytek8s/k8s_resource_adds_test.go index e560c2b43..866b04205 100755 --- a/go/tasks/pluginmachinery/flytek8s/k8s_resource_adds_test.go +++ b/go/tasks/pluginmachinery/flytek8s/k8s_resource_adds_test.go @@ -266,11 +266,13 @@ func TestDecorateEnvVars(t *testing.T) { args args additionEnvVar map[string]string additionEnvVarFromEnv map[string]string + executionEnvVar map[string]string want []v12.EnvVar }{ - {"no-additional", args{envVars: defaultEnv, id: mockTaskExecutionIdentifier{}}, emptyEnvVar, emptyEnvVar, expected}, - {"with-additional", args{envVars: defaultEnv, id: mockTaskExecutionIdentifier{}}, additionalEnv, emptyEnvVar, aggregated}, - {"from-env", args{envVars: defaultEnv, id: mockTaskExecutionIdentifier{}}, emptyEnvVar, envVarsFromEnv, aggregated}, + {"no-additional", args{envVars: defaultEnv, id: mockTaskExecutionIdentifier{}}, emptyEnvVar, emptyEnvVar, emptyEnvVar, expected}, + {"with-additional", args{envVars: defaultEnv, id: mockTaskExecutionIdentifier{}}, additionalEnv, emptyEnvVar, emptyEnvVar, aggregated}, + {"from-env", args{envVars: defaultEnv, id: mockTaskExecutionIdentifier{}}, emptyEnvVar, envVarsFromEnv, emptyEnvVar, aggregated}, + {"from-execution-metadata", args{envVars: defaultEnv, id: mockTaskExecutionIdentifier{}}, emptyEnvVar, emptyEnvVar, additionalEnv, aggregated}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -278,7 +280,7 @@ func TestDecorateEnvVars(t *testing.T) { DefaultEnvVars: tt.additionEnvVar, DefaultEnvVarsFromEnv: tt.additionEnvVarFromEnv, })) - if got := DecorateEnvVars(ctx, tt.args.envVars, tt.args.id); !reflect.DeepEqual(got, tt.want) { + if got := DecorateEnvVars(ctx, tt.args.envVars, tt.executionEnvVar, tt.args.id); !reflect.DeepEqual(got, tt.want) { t.Errorf("DecorateEnvVars() = %v, want %v", got, tt.want) } }) diff --git a/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go b/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go index bb99bafb5..f469a2227 100755 --- a/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go +++ b/go/tasks/pluginmachinery/flytek8s/pod_helper_test.go @@ -57,6 +57,7 @@ func dummyTaskExecutionMetadata(resources *v1.ResourceRequirements) pluginsCore. taskExecutionMetadata.On("GetOverrides").Return(to) taskExecutionMetadata.On("IsInterruptible").Return(true) taskExecutionMetadata.OnGetPlatformResources().Return(&v1.ResourceRequirements{}) + taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil) return taskExecutionMetadata } diff --git a/go/tasks/plugins/array/awsbatch/transformer.go b/go/tasks/plugins/array/awsbatch/transformer.go index 6b6da84fb..e3cd7041f 100644 --- a/go/tasks/plugins/array/awsbatch/transformer.go +++ b/go/tasks/plugins/array/awsbatch/transformer.go @@ -127,7 +127,7 @@ func UpdateBatchInputForArray(_ context.Context, batchInput *batch.SubmitJobInpu func getEnvVarsForTask(ctx context.Context, execID pluginCore.TaskExecutionID, containerEnvVars []*core.KeyValuePair, defaultEnvVars map[string]string) []v1.EnvVar { - envVars := flytek8s.DecorateEnvVars(ctx, flytek8s.ToK8sEnvVar(containerEnvVars), execID) + envVars := flytek8s.DecorateEnvVars(ctx, flytek8s.ToK8sEnvVar(containerEnvVars), nil, execID) m := make(map[string]string, len(envVars)) for _, envVar := range envVars { m[envVar.Name] = envVar.Value diff --git a/go/tasks/plugins/array/k8s/management_test.go b/go/tasks/plugins/array/k8s/management_test.go index a2dd715fa..9c17331b4 100644 --- a/go/tasks/plugins/array/k8s/management_test.go +++ b/go/tasks/plugins/array/k8s/management_test.go @@ -107,6 +107,7 @@ func getMockTaskExecutionContext(ctx context.Context, parallelism int) *mocks.Ta tMeta.OnGetOwnerReference().Return(metav1.OwnerReference{}) tMeta.OnGetPlatformResources().Return(&v1.ResourceRequirements{}) tMeta.OnGetInterruptibleFailureThreshold().Return(2) + tMeta.OnGetEnvironmentVariables().Return(nil) ow := &mocks2.OutputWriter{} ow.OnGetOutputPrefixPath().Return("/prefix/") diff --git a/go/tasks/plugins/k8s/dask/dask_test.go b/go/tasks/plugins/k8s/dask/dask_test.go index 2eb36ad3b..e60400219 100644 --- a/go/tasks/plugins/k8s/dask/dask_test.go +++ b/go/tasks/plugins/k8s/dask/dask_test.go @@ -164,6 +164,7 @@ func dummyDaskTaskContext(taskTemplate *core.TaskTemplate, resources *v1.Resourc taskExecutionMetadata.OnGetPlatformResources().Return(&testPlatformResources) taskExecutionMetadata.OnGetMaxAttempts().Return(uint32(1)) taskExecutionMetadata.OnIsInterruptible().Return(isInterruptible) + taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil) overrides := &mocks.TaskOverrides{} overrides.OnGetResources().Return(resources) taskExecutionMetadata.OnGetOverrides().Return(overrides) diff --git a/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go b/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go index 29fefd9ca..2e9e9283a 100644 --- a/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go +++ b/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go @@ -144,6 +144,7 @@ func dummyMPITaskContext(taskTemplate *core.TaskTemplate) pluginsCore.TaskExecut taskExecutionMetadata.OnGetOverrides().Return(resources) taskExecutionMetadata.OnGetK8sServiceAccount().Return(serviceAccount) taskExecutionMetadata.OnGetPlatformResources().Return(&corev1.ResourceRequirements{}) + taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil) taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata) return taskCtx } diff --git a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index 150bdb59a..74bd3fe92 100644 --- a/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -163,6 +163,7 @@ func dummyPytorchTaskContext(taskTemplate *core.TaskTemplate) pluginsCore.TaskEx taskExecutionMetadata.OnGetOverrides().Return(resources) taskExecutionMetadata.OnGetK8sServiceAccount().Return(serviceAccount) taskExecutionMetadata.OnGetPlatformResources().Return(&corev1.ResourceRequirements{}) + taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil) taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata) return taskCtx } diff --git a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go index dc8d5f240..37f22bf34 100644 --- a/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go +++ b/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go @@ -146,6 +146,7 @@ func dummyTensorFlowTaskContext(taskTemplate *core.TaskTemplate) pluginsCore.Tas taskExecutionMetadata.OnGetOverrides().Return(resources) taskExecutionMetadata.OnGetK8sServiceAccount().Return(serviceAccount) taskExecutionMetadata.OnGetPlatformResources().Return(&corev1.ResourceRequirements{}) + taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil) taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata) return taskCtx } diff --git a/go/tasks/plugins/k8s/pod/container_test.go b/go/tasks/plugins/k8s/pod/container_test.go index 0624c817c..13600168a 100644 --- a/go/tasks/plugins/k8s/pod/container_test.go +++ b/go/tasks/plugins/k8s/pod/container_test.go @@ -66,6 +66,7 @@ func dummyContainerTaskMetadata(resources *v1.ResourceRequirements) pluginsCore. to.On("GetResources").Return(resources) taskMetadata.On("GetOverrides").Return(to) taskMetadata.On("IsInterruptible").Return(true) + taskMetadata.On("GetEnvironmentVariables").Return(nil) return taskMetadata } diff --git a/go/tasks/plugins/k8s/pod/sidecar_test.go b/go/tasks/plugins/k8s/pod/sidecar_test.go index 1dc3ac9f1..b11fd623b 100644 --- a/go/tasks/plugins/k8s/pod/sidecar_test.go +++ b/go/tasks/plugins/k8s/pod/sidecar_test.go @@ -93,6 +93,7 @@ func dummySidecarTaskMetadata(resources *v1.ResourceRequirements) pluginsCore.Ta to := &pluginsCoreMock.TaskOverrides{} to.On("GetResources").Return(resources) taskMetadata.On("GetOverrides").Return(to) + taskMetadata.On("GetEnvironmentVariables").Return(nil) return taskMetadata } diff --git a/go/tasks/plugins/k8s/ray/config_flags.go b/go/tasks/plugins/k8s/ray/config_flags.go index 6f651a3d2..f8e983056 100755 --- a/go/tasks/plugins/k8s/ray/config_flags.go +++ b/go/tasks/plugins/k8s/ray/config_flags.go @@ -56,5 +56,8 @@ func (cfg Config) GetPFlagSet(prefix string) *pflag.FlagSet { cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "includeDashboard"), defaultConfig.IncludeDashboard, "") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "dashboardHost"), defaultConfig.DashboardHost, "") cmdFlags.String(fmt.Sprintf("%v%v", prefix, "nodeIPAddress"), defaultConfig.NodeIPAddress, "") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "remoteClusterConfig.name"), defaultConfig.RemoteClusterConfig.Name, "Friendly name of the remote cluster") + cmdFlags.String(fmt.Sprintf("%v%v", prefix, "remoteClusterConfig.endpoint"), defaultConfig.RemoteClusterConfig.Endpoint, " Remote K8s cluster endpoint") + cmdFlags.Bool(fmt.Sprintf("%v%v", prefix, "remoteClusterConfig.enabled"), defaultConfig.RemoteClusterConfig.Enabled, " Boolean flag to enable or disable") return cmdFlags } diff --git a/go/tasks/plugins/k8s/ray/config_flags_test.go b/go/tasks/plugins/k8s/ray/config_flags_test.go index d5a59757c..60761b900 100755 --- a/go/tasks/plugins/k8s/ray/config_flags_test.go +++ b/go/tasks/plugins/k8s/ray/config_flags_test.go @@ -183,4 +183,46 @@ func TestConfig_SetFlags(t *testing.T) { } }) }) + t.Run("Test_remoteClusterConfig.name", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("remoteClusterConfig.name", testValue) + if vString, err := cmdFlags.GetString("remoteClusterConfig.name"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.RemoteClusterConfig.Name) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_remoteClusterConfig.endpoint", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("remoteClusterConfig.endpoint", testValue) + if vString, err := cmdFlags.GetString("remoteClusterConfig.endpoint"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vString), &actual.RemoteClusterConfig.Endpoint) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) + t.Run("Test_remoteClusterConfig.enabled", func(t *testing.T) { + + t.Run("Override", func(t *testing.T) { + testValue := "1" + + cmdFlags.Set("remoteClusterConfig.enabled", testValue) + if vBool, err := cmdFlags.GetBool("remoteClusterConfig.enabled"); err == nil { + testDecodeJson_Config(t, fmt.Sprintf("%v", vBool), &actual.RemoteClusterConfig.Enabled) + + } else { + assert.FailNow(t, err.Error()) + } + }) + }) } diff --git a/go/tasks/plugins/k8s/ray/ray_test.go b/go/tasks/plugins/k8s/ray/ray_test.go index 99da77b6e..e3e4b5585 100644 --- a/go/tasks/plugins/k8s/ray/ray_test.go +++ b/go/tasks/plugins/k8s/ray/ray_test.go @@ -141,6 +141,7 @@ func dummyRayTaskContext(taskTemplate *core.TaskTemplate) pluginsCore.TaskExecut taskExecutionMetadata.OnGetSecurityContext().Return(core.SecurityContext{ RunAs: &core.Identity{K8SServiceAccount: serviceAccount}, }) + taskExecutionMetadata.OnGetEnvironmentVariables().Return(nil) taskCtx.OnTaskExecutionMetadata().Return(taskExecutionMetadata) return taskCtx } diff --git a/go/tasks/plugins/k8s/spark/spark.go b/go/tasks/plugins/k8s/spark/spark.go index 66600709a..474ce819c 100755 --- a/go/tasks/plugins/k8s/spark/spark.go +++ b/go/tasks/plugins/k8s/spark/spark.go @@ -84,7 +84,8 @@ func (sparkResourceHandler) BuildResource(ctx context.Context, taskCtx pluginsCo labels := utils.UnionMaps(config.GetK8sPluginConfig().DefaultLabels, utils.CopyMap(taskCtx.TaskExecutionMetadata().GetLabels())) container := taskTemplate.GetContainer() - envVars := flytek8s.DecorateEnvVars(ctx, flytek8s.ToK8sEnvVar(container.GetEnv()), taskCtx.TaskExecutionMetadata().GetTaskExecutionID()) + envVars := flytek8s.DecorateEnvVars(ctx, flytek8s.ToK8sEnvVar(container.GetEnv()), + taskCtx.TaskExecutionMetadata().GetEnvironmentVariables(), taskCtx.TaskExecutionMetadata().GetTaskExecutionID()) sparkEnvVars := make(map[string]string) for _, envVar := range envVars { diff --git a/go/tasks/plugins/k8s/spark/spark_test.go b/go/tasks/plugins/k8s/spark/spark_test.go index 6b6ed18d8..aef61c0f7 100755 --- a/go/tasks/plugins/k8s/spark/spark_test.go +++ b/go/tasks/plugins/k8s/spark/spark_test.go @@ -325,6 +325,7 @@ func dummySparkTaskContext(taskTemplate *core.TaskTemplate, interruptible bool) }) taskExecutionMetadata.On("IsInterruptible").Return(interruptible) taskExecutionMetadata.On("GetMaxAttempts").Return(uint32(1)) + taskExecutionMetadata.On("GetEnvironmentVariables").Return(nil) taskCtx.On("TaskExecutionMetadata").Return(taskExecutionMetadata) return taskCtx } diff --git a/tests/end_to_end.go b/tests/end_to_end.go index eacffb038..dac473dfd 100644 --- a/tests/end_to_end.go +++ b/tests/end_to_end.go @@ -177,6 +177,7 @@ func RunPluginEndToEndTest(t *testing.T, executor pluginCore.Plugin, template *i }) tMeta.OnGetPlatformResources().Return(&v1.ResourceRequirements{}) tMeta.OnGetInterruptibleFailureThreshold().Return(2) + tMeta.OnGetEnvironmentVariables().Return(nil) catClient := &catalogMocks.Client{} catData := sync.Map{}