diff --git a/internal/ceb/ceb.go b/internal/ceb/ceb.go index 1566e69e33f..c9e8dd4deac 100644 --- a/internal/ceb/ceb.go +++ b/internal/ceb/ceb.go @@ -301,24 +301,24 @@ func WithEnvDefaults() Option { cfg.URLServicePort = port cfg.ServerAddr = os.Getenv(envServerAddr) - cfg.ServerRequired, err = env.GetEnvBool(envCEBServerRequired, false) + cfg.ServerRequired, err = env.GetBool(envCEBServerRequired, false) if err != nil { return err } - cfg.ServerTls, err = env.GetEnvBool(envServerTls, false) + cfg.ServerTls, err = env.GetBool(envServerTls, false) if err != nil { return err } - cfg.ServerTlsSkipVerify, err = env.GetEnvBool(envServerTlsSkipVerify, false) + cfg.ServerTlsSkipVerify, err = env.GetBool(envServerTlsSkipVerify, false) if err != nil { return err } cfg.InviteToken = os.Getenv(envCEBToken) - cfg.disable, err = env.GetEnvBool(envCEBDisable, false) + cfg.disable, err = env.GetBool(envCEBDisable, false) if err != nil { return err } diff --git a/internal/cli/main.go b/internal/cli/main.go index 4dc47c62792..c20a119514f 100644 --- a/internal/cli/main.go +++ b/internal/cli/main.go @@ -140,7 +140,7 @@ func Commands( } // Set plain mode if set - outputModeBool, err := env.GetEnvBool(EnvPlain, false) + outputModeBool, err := env.GetBool(EnvPlain, false) if err != nil { log.Warn(err.Error()) } diff --git a/internal/env/env.go b/internal/env/env.go index 3db240a8f43..f53f5a3654f 100644 --- a/internal/env/env.go +++ b/internal/env/env.go @@ -7,9 +7,9 @@ import ( "strings" ) -// GetEnvBool Extracts a boolean from an env var. Falls back to the default +// GetBool Extracts a boolean from an env var. Falls back to the default // if the key is unset or not a valid boolean. -func GetEnvBool(key string, defaultValue bool) (bool, error) { +func GetBool(key string, defaultValue bool) (bool, error) { envVal := os.Getenv(key) if envVal == "" { return defaultValue, nil diff --git a/internal/env/env_test.go b/internal/env/env_test.go index b4ab4bb9781..14ea8858494 100644 --- a/internal/env/env_test.go +++ b/internal/env/env_test.go @@ -1,65 +1,88 @@ package env import ( - "github.com/stretchr/testify/require" "os" "testing" ) -func TestGetEnvBool(t *testing.T) { +func TestGetBool(t *testing.T) { envVarTestKey := "WAYPOINT_GET_ENV_BOOL_TEST" - require := require.New(t) - t.Run("Unset env var returns default", func(t *testing.T) { - b, err := GetEnvBool(envVarTestKey, true) - require.NoError(err) - require.True(b) - - b, err = GetEnvBool(envVarTestKey, false) - require.NoError(err) - require.False(b) - }) - - t.Run("Empty env var returns default", func(t *testing.T) { - os.Setenv(envVarTestKey, "") - b, err := GetEnvBool(envVarTestKey, true) - require.NoError(err) - require.True(b) - - b, err = GetEnvBool(envVarTestKey, false) - require.NoError(err) - require.False(b) - }) - - t.Run("Non-truthy env var returns an error", func(t *testing.T) { - os.Setenv(envVarTestKey, "unparseable") - _, err := GetEnvBool(envVarTestKey, true) - require.Error(err) - }) - - t.Run("true/false env vars return non-default", func(t *testing.T) { - os.Setenv(envVarTestKey, "true") - b, err := GetEnvBool(envVarTestKey, false) - require.NoError(err) - require.True(b) - - os.Setenv(envVarTestKey, "false") - b, err = GetEnvBool(envVarTestKey, true) - require.NoError(err) - require.False(b) - }) - - t.Run("boolean parsing is generous with capitalization", func(t *testing.T) { - os.Setenv(envVarTestKey, "tRuE") - b, err := GetEnvBool(envVarTestKey, false) - require.NoError(err) - require.True(b) - }) - - t.Run("1 evaluates as true", func(t *testing.T) { - os.Setenv(envVarTestKey, "1") - b, err := GetEnvBool(envVarTestKey, false) - require.NoError(err) - require.True(b) - }) + tests := []struct { + name string + defaultVal bool + envVal string + want bool + wantErr bool + }{ + { + name: "Empty env var returns default 1", + defaultVal: true, + envVal: "", + want: true, + wantErr: false, + }, + { + name: "Empty env var returns default 2", + defaultVal: false, + envVal: "", + want: false, + wantErr: false, + }, + { + name: "Non-truthy env var returns err", + defaultVal: false, + envVal: "unparseable", + want: false, + wantErr: true, + }, + { + name: "'true' is true", + defaultVal: false, + envVal: "true", + want: true, + wantErr: false, + }, + { + name: "'false' is true", + defaultVal: true, + envVal: "false", + want: false, + wantErr: false, + }, + { + name: "1 is true", + defaultVal: false, + envVal: "1", + want: true, + wantErr: false, + }, + { + name: "0 is false", + defaultVal: true, + envVal: "0", + want: false, + wantErr: false, + }, + { + name: "Boolean parsing ignores capitalization", + defaultVal: false, + envVal: "tRuE", + want: true, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + os.Setenv(envVarTestKey, tt.envVal) + got, err := GetBool(envVarTestKey, tt.defaultVal) + if (err != nil) != tt.wantErr { + t.Errorf("GetBool() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("GetBool() got = %v, want %v", got, tt.want) + } + }) + } } diff --git a/internal/serverclient/client.go b/internal/serverclient/client.go index d186eb7e805..43943f9b058 100644 --- a/internal/serverclient/client.go +++ b/internal/serverclient/client.go @@ -128,12 +128,12 @@ func FromEnv() ConnectOption { if v := os.Getenv(EnvServerAddr); v != "" { c.Addr = v - c.Tls, err = env.GetEnvBool(EnvServerTls, false) + c.Tls, err = env.GetBool(EnvServerTls, false) if err != nil { return err } - c.TlsSkipVerify, err = env.GetEnvBool(EnvServerTlsSkipVerify, false) + c.TlsSkipVerify, err = env.GetBool(EnvServerTlsSkipVerify, false) if err != nil { return err }