diff --git a/cmd/config/config_test.go b/cmd/config/config_test.go index f11ee07a..8fd0674f 100644 --- a/cmd/config/config_test.go +++ b/cmd/config/config_test.go @@ -1,8 +1,12 @@ package config import ( + "context" + "net/url" + "os" "testing" + "github.com/flyteorg/flyte/flyteidl/clients/go/admin" "github.com/flyteorg/flytectl/pkg/printer" "github.com/stretchr/testify/assert" ) @@ -28,5 +32,21 @@ func TestInvalidOutputFormat(t *testing.T) { } }() result = c.MustOutputFormat() +} + +func TestUpdateConfigWithEnvVar(t *testing.T) { + originalValue := os.Getenv("FLYTE_ADMIN_ENDPOINT") + defer os.Setenv("FLYTE_ADMIN_ENDPOINT", originalValue) + + dummyURL := "dns://dummyHost" + os.Setenv("FLYTE_ADMIN_ENDPOINT", dummyURL) + parsedDummyURL, _ := url.Parse(dummyURL) + + adminCfg := admin.GetConfig(context.Background()) + + assert.NotEqual(t, adminCfg.Endpoint.URL, *parsedDummyURL) + err := UpdateConfigWithEnvVar() + assert.Nil(t, err) + assert.Equal(t, adminCfg.Endpoint.URL, *parsedDummyURL) } diff --git a/cmd/config/env_var_reader.go b/cmd/config/env_var_reader.go new file mode 100644 index 00000000..6ced9035 --- /dev/null +++ b/cmd/config/env_var_reader.go @@ -0,0 +1,46 @@ +package config + +import ( + "context" + "fmt" + "net/url" + "os" + + "github.com/flyteorg/flyte/flyteidl/clients/go/admin" + "github.com/flyteorg/flyte/flytestdlib/config" +) + +const flyteAdminEndpoint = "FLYTE_ADMIN_ENDPOINT" + +type UpdateFunc func(context.Context) error + +var envToUpdateFunc = map[string]UpdateFunc{flyteAdminEndpoint: updateAdminEndpoint} + +func UpdateConfigWithEnvVar() error { + ctx := context.Background() + + for envVar, updateFunc := range envToUpdateFunc { + if os.Getenv(envVar) != "" { + if err := updateFunc(ctx); err != nil { + return fmt.Errorf("error update config with env var: %v", err) + } + } + } + return nil +} + +func updateAdminEndpoint(ctx context.Context) error { + cfg := admin.GetConfig(ctx) + + if len(os.Getenv(flyteAdminEndpoint)) > 0 { + envEndpoint, err := url.Parse(os.Getenv(flyteAdminEndpoint)) + if err != nil { + return fmt.Errorf("error parsing env var %v: %v", flyteAdminEndpoint, err) + } + cfg.Endpoint = config.URL{URL: *envEndpoint} + if err := admin.SetConfig(cfg); err != nil { + return err + } + } + return nil +} diff --git a/cmd/root.go b/cmd/root.go index 418406c0..49ecad01 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -134,6 +134,10 @@ func initConfig(cmd *cobra.Command, _ []string) error { return err } + if err := config.UpdateConfigWithEnvVar(); err != nil { + return err + } + return nil }