diff --git a/cmd/config_printer/config_summary.go b/cmd/config_printer/config_summary.go index 0a76a8939..26e16a433 100644 --- a/cmd/config_printer/config_summary.go +++ b/cmd/config_printer/config_summary.go @@ -21,7 +21,6 @@ package config_printer import ( "reflect" - log "github.com/sirupsen/logrus" "github.com/spf13/cobra" "github.com/spf13/viper" @@ -32,9 +31,7 @@ func configSummary(cmd *cobra.Command, args []string) { defaultConfig := viper.New() config.SetBaseDefaultsInConfig(defaultConfig) - if err := config.InitConfigDir(defaultConfig); err != nil { - log.Errorf("Error initializing config directory: %v", err) - } + config.InitConfigDir(defaultConfig) defaultConfigMap := initClientAndServerConfig(defaultConfig) currentConfigMap := initClientAndServerConfig(viper.GetViper()) diff --git a/config/config.go b/config/config.go index 353cdec28..8a2fb14c6 100644 --- a/config/config.go +++ b/config/config.go @@ -498,13 +498,14 @@ func CleanupTempResources() (err error) { return } -func getConfigBase() (string, error) { +func getConfigBase() string { home, err := os.UserHomeDir() if err != nil { log.Warningln("No home directory found for user -- will check for configuration yaml in /etc/pelican/") + return filepath.Join("/etc", "pelican") } - return filepath.Join(home, ".config", "pelican"), nil + return filepath.Join(home, ".config", "pelican") } func setupTransport() { @@ -789,23 +790,19 @@ func SetBaseDefaultsInConfig(v *viper.Viper) { } // For the given Viper instance, set the default config directory. -func InitConfigDir(v *viper.Viper) error { +func InitConfigDir(v *viper.Viper) { configDir := v.GetString("ConfigDir") if configDir == "" { if IsRootExecution() { configDir = "/etc/pelican" } else { - configTmp, err := getConfigBase() - if err != nil { - return err - } + configTmp := getConfigBase() configDir = configTmp } v.SetDefault("ConfigDir", configDir) } v.SetConfigName("pelican") - return nil } // InitConfig sets up the global Viper instance by loading defaults and @@ -818,10 +815,7 @@ func InitConfig() { // Set default values in the global Viper instance SetBaseDefaultsInConfig(viper.GetViper()) - if err := InitConfigDir(viper.GetViper()); err != nil { - log.Errorf("Failed to initialize the config directory, Error: %v", err) - os.Exit(1) - } + InitConfigDir(viper.GetViper()) if configFile := viper.GetString("config"); configFile != "" { viper.SetConfigFile(configFile) diff --git a/config/config_test.go b/config/config_test.go index d270dfff4..a7422f230 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -25,6 +25,7 @@ import ( "net/http/httptest" "os" "path/filepath" + "runtime" "sort" "strings" "testing" @@ -180,6 +181,92 @@ func TestInitConfig(t *testing.T) { assert.Equal(t, "", param.Federation_DiscoveryUrl.GetString()) } +// Helper function to set and reset the environment variable HOME/USERPROFILE +func setHomeDirEnv(t *testing.T, mockHomeDir string) func() { + // Save the original environment variables for restoration + originalHome := os.Getenv("HOME") + + // Set the appropriate environment variable based on the platform + if err := os.Setenv("HOME", mockHomeDir); err != nil { + t.Fatalf("Failed to set HOME: %v", err) + } + + // Return a function to restore the environment variables + return func() { + // Restore HOME + if originalHome != "" { + os.Setenv("HOME", originalHome) + } else { + os.Unsetenv("HOME") + } + } +} + +func TestHomeDir(t *testing.T) { + if runtime.GOOS != "windows" { + mockHomeDir := filepath.Join("test", "configDir") + ResetConfig() + t.Cleanup(func() { + ResetConfig() + }) + + // Save the original environment variables + oldConfigRoot := isRootExec + resetEnv := setHomeDirEnv(t, mockHomeDir) + + defer func() { + resetEnv() + isRootExec = oldConfigRoot + }() + + t.Run("RootUserNoConfigDir", func(t *testing.T) { + isRootExec = true + + InitConfigDir(viper.GetViper()) + + cDir := viper.GetString("ConfigDir") + require.Equal(t, "/etc/pelican", cDir) + }) + + t.Run("WithConfigDir", func(t *testing.T) { + viper.Set("ConfigDir", "/test/configDir") + + InitConfigDir(viper.GetViper()) + + cDir := viper.GetString("ConfigDir") + require.Equal(t, "/test/configDir", cDir) + }) + + t.Run("NonRootNoConfigDirWithHomeSet", func(t *testing.T) { + isRootExec = false + viper.Reset() + + InitConfigDir(viper.GetViper()) + + cDir := viper.GetString("ConfigDir") + require.Equal(t, filepath.Join("test", "configDir", ".config", "pelican"), cDir) + }) + + t.Run("NonRootNoConfigDirWithNoHome", func(t *testing.T) { + isRootExec = false + viper.Reset() + os.Unsetenv("HOME") + + InitConfigDir(viper.GetViper()) + + cDir := viper.GetString("ConfigDir") + require.Equal(t, filepath.Join("/etc", "pelican"), cDir) + }) + } + +} + +// HOME directory test +// Root and set +// Root and unset +// Non root and set +// Non root and unset + // Helper func for TestExtraCfg // // Sets up the root config file and adds the ConfigLocations key to point to a test's tempdir @@ -603,9 +690,8 @@ func TestInitServerUrl(t *testing.T) { ResetConfig() viper.Set("Server.Hostname", mockHostname) viper.Set("Server.WebPort", mockNon443Port) - err := InitConfigDir(viper.GetViper()) - require.NoError(t, err) - err = InitServer(context.Background(), 0) + InitConfigDir(viper.GetViper()) + err := InitServer(context.Background(), 0) require.NoError(t, err) assert.Equal(t, mockWebUrlWNon443Port, param.Server_ExternalWebUrl.GetString()) }) @@ -614,9 +700,8 @@ func TestInitServerUrl(t *testing.T) { ResetConfig() viper.Set("Server.Hostname", mockHostname) viper.Set("Server.WebPort", mock443Port) - err := InitConfigDir(viper.GetViper()) - require.NoError(t, err) - err = InitServer(context.Background(), 0) + InitConfigDir(viper.GetViper()) + err := InitServer(context.Background(), 0) require.NoError(t, err) assert.Equal(t, mockWebUrlWoPort, param.Server_ExternalWebUrl.GetString()) }) @@ -625,9 +710,8 @@ func TestInitServerUrl(t *testing.T) { // We respect the URL value set directly by others. Won't remove 443 port ResetConfig() viper.Set("Server.ExternalWebUrl", mockWebUrlW443Port) - err := InitConfigDir(viper.GetViper()) - require.NoError(t, err) - err = InitServer(context.Background(), 0) + InitConfigDir(viper.GetViper()) + err := InitServer(context.Background(), 0) require.NoError(t, err) assert.Equal(t, mockWebUrlWoPort, param.Server_ExternalWebUrl.GetString()) }) diff --git a/registry/client_commands_test.go b/registry/client_commands_test.go index e9d96fa95..7adc1c900 100644 --- a/registry/client_commands_test.go +++ b/registry/client_commands_test.go @@ -47,10 +47,9 @@ func registryMockup(ctx context.Context, t *testing.T, testName string) *httptes viper.Set("IssuerKey", ikey) viper.Set("Registry.DbLocation", filepath.Join(issuerTempDir, "test.sql")) viper.Set("Server.WebPort", 8444) - err := config.InitConfigDir(viper.GetViper()) - require.NoError(t, err) + config.InitConfigDir(viper.GetViper()) - err = config.InitServer(ctx, server_structs.RegistryType) + err := config.InitServer(ctx, server_structs.RegistryType) require.NoError(t, err) setupMockRegistryDB(t) diff --git a/xrootd/authorization_test.go b/xrootd/authorization_test.go index d420eb924..940d4014a 100644 --- a/xrootd/authorization_test.go +++ b/xrootd/authorization_test.go @@ -502,8 +502,7 @@ func TestMergeConfig(t *testing.T) { err := os.WriteFile(scitokensConfigFile, []byte(configInput), fs.FileMode(0600)) require.NoError(t, err) - err = config.InitConfigDir(viper.GetViper()) - require.NoError(t, err) + config.InitConfigDir(viper.GetViper()) err = config.InitServer(ctx, server_structs.OriginType) require.NoError(t, err) @@ -540,8 +539,7 @@ func TestGenerateConfig(t *testing.T) { viper.Set("Origin.Port", 8443) viper.Set("Server.WebPort", 8443) viper.Set(param.Origin_StorageType.GetName(), string(server_structs.OriginStoragePosix)) - err = config.InitConfigDir(viper.GetViper()) - require.NoError(t, err) + config.InitConfigDir(viper.GetViper()) err = config.InitServer(ctx, server_structs.OriginType) require.NoError(t, err) issuer, err = GenerateMonitoringIssuer() @@ -558,8 +556,7 @@ func TestGenerateConfig(t *testing.T) { viper.Set("Origin.ScitokensMapSubject", true) viper.Set("Origin.Port", 8443) viper.Set("Server.WebPort", 8443) - err = config.InitConfigDir(viper.GetViper()) - require.NoError(t, err) + config.InitConfigDir(viper.GetViper()) err = config.InitServer(ctx, server_structs.OriginType) require.NoError(t, err) issuer, err = GenerateOriginIssuer([]string{"/foo/bar/baz", "/another/exported/path"})