Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup changes made from $HOME fixes from before #1832

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions cmd/config_printer/config_summary.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ package config_printer
import (
"reflect"

log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"github.com/spf13/viper"

Expand All @@ -32,9 +31,7 @@ func configSummary(cmd *cobra.Command, args []string) {
defaultConfig := viper.New()
config.SetBaseDefaultsInConfig(defaultConfig)

if err := config.InitConfigDir(defaultConfig); err != nil {
log.Errorf("Error initializing config directory: %v", err)
}
config.InitConfigDir(defaultConfig)

defaultConfigMap := initClientAndServerConfig(defaultConfig)
currentConfigMap := initClientAndServerConfig(viper.GetViper())
Expand Down
18 changes: 6 additions & 12 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -498,13 +498,14 @@ func CleanupTempResources() (err error) {
return
}

func getConfigBase() (string, error) {
func getConfigBase() string {
home, err := os.UserHomeDir()
if err != nil {
log.Warningln("No home directory found for user -- will check for configuration yaml in /etc/pelican/")
return filepath.Join("/etc", "pelican")
}

return filepath.Join(home, ".config", "pelican"), nil
return filepath.Join(home, ".config", "pelican")
}

func setupTransport() {
Expand Down Expand Up @@ -789,23 +790,19 @@ func SetBaseDefaultsInConfig(v *viper.Viper) {
}

// For the given Viper instance, set the default config directory.
func InitConfigDir(v *viper.Viper) error {
func InitConfigDir(v *viper.Viper) {

configDir := v.GetString("ConfigDir")
if configDir == "" {
if IsRootExecution() {
configDir = "/etc/pelican"
} else {
configTmp, err := getConfigBase()
if err != nil {
return err
}
configTmp := getConfigBase()
configDir = configTmp
}
v.SetDefault("ConfigDir", configDir)
}
v.SetConfigName("pelican")
return nil
}

// InitConfig sets up the global Viper instance by loading defaults and
Expand All @@ -818,10 +815,7 @@ func InitConfig() {
// Set default values in the global Viper instance
SetBaseDefaultsInConfig(viper.GetViper())

if err := InitConfigDir(viper.GetViper()); err != nil {
log.Errorf("Failed to initialize the config directory, Error: %v", err)
os.Exit(1)
}
InitConfigDir(viper.GetViper())

if configFile := viper.GetString("config"); configFile != "" {
viper.SetConfigFile(configFile)
Expand Down
118 changes: 109 additions & 9 deletions config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"net/http/httptest"
"os"
"path/filepath"
"runtime"
"sort"
"strings"
"testing"
Expand Down Expand Up @@ -180,6 +181,108 @@ func TestInitConfig(t *testing.T) {
assert.Equal(t, "", param.Federation_DiscoveryUrl.GetString())
}

// Helper function to set and reset the environment variable HOME/USERPROFILE
func setHomeDirEnv(t *testing.T, mockHomeDir string) func() {
// Save the original environment variables for restoration
originalHome := os.Getenv("HOME")
originalUserProfile := os.Getenv("USERPROFILE")

// Set the appropriate environment variable based on the platform
switch runtime.GOOS {
case "windows":
if err := os.Setenv("USERPROFILE", mockHomeDir); err != nil {
t.Fatalf("Failed to set USERPROFILE: %v", err)
}
case "darwin", "linux":
if err := os.Setenv("HOME", mockHomeDir); err != nil {
t.Fatalf("Failed to set HOME: %v", err)
}
}

// Return a function to restore the environment variables
return func() {
if runtime.GOOS == "windows" {
// Restore USERPROFILE
if originalUserProfile != "" {
os.Setenv("USERPROFILE", originalUserProfile)
} else {
os.Unsetenv("USERPROFILE")
}
} else {
// Restore HOME
if originalHome != "" {
os.Setenv("HOME", originalHome)
} else {
os.Unsetenv("HOME")
}
}
}
}

func TestHomeDir(t *testing.T) {
mockHomeDir := filepath.Join("/test", "configDir")
ResetConfig()
t.Cleanup(func() {
ResetConfig()
})

// Save the original environment variables
oldConfigRoot := isRootExec
resetEnv := setHomeDirEnv(t, mockHomeDir)

defer func() {
resetEnv()
isRootExec = oldConfigRoot
}()

t.Run("RootUserNoConfigDir", func(t *testing.T) {
isRootExec = true

InitConfigDir(viper.GetViper())

cDir := viper.GetString("ConfigDir")
require.Equal(t, "/etc/pelican", cDir)
})

t.Run("WithConfigDir", func(t *testing.T) {
viper.Set("ConfigDir", "/test/configDir")

InitConfigDir(viper.GetViper())

cDir := viper.GetString("ConfigDir")
require.Equal(t, "/test/configDir", cDir)
})

t.Run("NonRootNoConfigDirWithHomeSet", func(t *testing.T) {
isRootExec = false
viper.Reset()

InitConfigDir(viper.GetViper())

cDir := viper.GetString("ConfigDir")
require.Equal(t, "/test/configDir/.config/pelican", cDir)
})

t.Run("NonRootNoConfigDirWithNoHome", func(t *testing.T) {
isRootExec = false
viper.Reset()
os.Unsetenv("HOME")
os.Unsetenv("USERPROFILE")

InitConfigDir(viper.GetViper())

cDir := viper.GetString("ConfigDir")
require.Equal(t, "/etc/pelican", cDir)
})

}

// HOME directory test
// Root and set
// Root and unset
// Non root and set
// Non root and unset

// Helper func for TestExtraCfg
//
// Sets up the root config file and adds the ConfigLocations key to point to a test's tempdir
Expand Down Expand Up @@ -603,9 +706,8 @@ func TestInitServerUrl(t *testing.T) {
ResetConfig()
viper.Set("Server.Hostname", mockHostname)
viper.Set("Server.WebPort", mockNon443Port)
err := InitConfigDir(viper.GetViper())
require.NoError(t, err)
err = InitServer(context.Background(), 0)
InitConfigDir(viper.GetViper())
err := InitServer(context.Background(), 0)
require.NoError(t, err)
assert.Equal(t, mockWebUrlWNon443Port, param.Server_ExternalWebUrl.GetString())
})
Expand All @@ -614,9 +716,8 @@ func TestInitServerUrl(t *testing.T) {
ResetConfig()
viper.Set("Server.Hostname", mockHostname)
viper.Set("Server.WebPort", mock443Port)
err := InitConfigDir(viper.GetViper())
require.NoError(t, err)
err = InitServer(context.Background(), 0)
InitConfigDir(viper.GetViper())
err := InitServer(context.Background(), 0)
require.NoError(t, err)
assert.Equal(t, mockWebUrlWoPort, param.Server_ExternalWebUrl.GetString())
})
Expand All @@ -625,9 +726,8 @@ func TestInitServerUrl(t *testing.T) {
// We respect the URL value set directly by others. Won't remove 443 port
ResetConfig()
viper.Set("Server.ExternalWebUrl", mockWebUrlW443Port)
err := InitConfigDir(viper.GetViper())
require.NoError(t, err)
err = InitServer(context.Background(), 0)
InitConfigDir(viper.GetViper())
err := InitServer(context.Background(), 0)
require.NoError(t, err)
assert.Equal(t, mockWebUrlWoPort, param.Server_ExternalWebUrl.GetString())
})
Expand Down
5 changes: 2 additions & 3 deletions registry/client_commands_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,9 @@ func registryMockup(ctx context.Context, t *testing.T, testName string) *httptes
viper.Set("IssuerKey", ikey)
viper.Set("Registry.DbLocation", filepath.Join(issuerTempDir, "test.sql"))
viper.Set("Server.WebPort", 8444)
err := config.InitConfigDir(viper.GetViper())
require.NoError(t, err)
config.InitConfigDir(viper.GetViper())

err = config.InitServer(ctx, server_structs.RegistryType)
err := config.InitServer(ctx, server_structs.RegistryType)
require.NoError(t, err)

setupMockRegistryDB(t)
Expand Down
9 changes: 3 additions & 6 deletions xrootd/authorization_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -502,8 +502,7 @@ func TestMergeConfig(t *testing.T) {
err := os.WriteFile(scitokensConfigFile, []byte(configInput), fs.FileMode(0600))
require.NoError(t, err)

err = config.InitConfigDir(viper.GetViper())
require.NoError(t, err)
config.InitConfigDir(viper.GetViper())
err = config.InitServer(ctx, server_structs.OriginType)
require.NoError(t, err)

Expand Down Expand Up @@ -540,8 +539,7 @@ func TestGenerateConfig(t *testing.T) {
viper.Set("Origin.Port", 8443)
viper.Set("Server.WebPort", 8443)
viper.Set(param.Origin_StorageType.GetName(), string(server_structs.OriginStoragePosix))
err = config.InitConfigDir(viper.GetViper())
require.NoError(t, err)
config.InitConfigDir(viper.GetViper())
err = config.InitServer(ctx, server_structs.OriginType)
require.NoError(t, err)
issuer, err = GenerateMonitoringIssuer()
Expand All @@ -558,8 +556,7 @@ func TestGenerateConfig(t *testing.T) {
viper.Set("Origin.ScitokensMapSubject", true)
viper.Set("Origin.Port", 8443)
viper.Set("Server.WebPort", 8443)
err = config.InitConfigDir(viper.GetViper())
require.NoError(t, err)
config.InitConfigDir(viper.GetViper())
err = config.InitServer(ctx, server_structs.OriginType)
require.NoError(t, err)
issuer, err = GenerateOriginIssuer([]string{"/foo/bar/baz", "/another/exported/path"})
Expand Down
Loading