diff --git a/e2e/testdata/custom-image-test/build.envd b/e2e/testdata/custom-image-test/build.envd index 418dc759b..b56492fe6 100644 --- a/e2e/testdata/custom-image-test/build.envd +++ b/e2e/testdata/custom-image-test/build.envd @@ -1,5 +1,5 @@ def build(): base(language="python", image="python:3.8") install.python_packages(name=[ - "via", + "via", ]) diff --git a/go.mod b/go.mod index 6c92d9823..fff5f91bd 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,6 @@ module github.com/tensorchord/envd go 1.18 require ( - github.com/adrg/xdg v0.4.0 github.com/alessio/shellescape v1.4.1 github.com/cockroachdb/errors v1.9.0 github.com/containerd/console v1.0.3 diff --git a/go.sum b/go.sum index 80f4c737a..83fcf1ee7 100644 --- a/go.sum +++ b/go.sum @@ -62,8 +62,6 @@ github.com/ProtonMail/go-crypto v0.0.0-20210428141323-04723f9f07d7/go.mod h1:z4/ github.com/Shopify/goreferrer v0.0.0-20181106222321-ec9c9a553398/go.mod h1:a1uqRtAwp2Xwc6WNPJEufxJ7fx3npB4UV/JOLmbu5I0= github.com/acomagu/bufpipe v1.0.3 h1:fxAGrHZTgQ9w5QqVItgzwj235/uYZYgbXitB+dLupOk= github.com/acomagu/bufpipe v1.0.3/go.mod h1:mxdxdup/WdsKVreO5GpW4+M/1CE2sMG4jeGJ2sYmHc4= -github.com/adrg/xdg v0.4.0 h1:RzRqFcjH4nE5C6oTAxhBtoE2IRyjBSa62SCbyPidvls= -github.com/adrg/xdg v0.4.0/go.mod h1:N6ag73EX4wyxeaoeHctc1mas01KZgsj5tYiAIwqJE/E= github.com/ajg/form v1.5.1/go.mod h1:uL1WgH+h2mgNtvBq0339dVnzXdBETtL2LeUXaIv25UY= github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc= github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= diff --git a/pkg/home/cache.go b/pkg/home/cache.go index cb5ea5eaa..d11daed25 100644 --- a/pkg/home/cache.go +++ b/pkg/home/cache.go @@ -17,11 +17,11 @@ package home import ( "encoding/gob" "os" - "path/filepath" - "github.com/adrg/xdg" "github.com/cockroachdb/errors" "github.com/sirupsen/logrus" + + "github.com/tensorchord/envd/pkg/util/fileutil" ) type cacheManager interface { @@ -32,14 +32,14 @@ type cacheManager interface { } func (m *generalManager) initCache() error { - // Create $XDG_CACHE_HOME/envd - _, err := xdg.CacheFile("envd/cache") + // Create $HOME/.cache/envd/ + m.cacheDir = fileutil.DefaultCacheDir + + cacheStatusFile, err := fileutil.CacheFile("cache.status") if err != nil { - return errors.Wrap(err, "failed to get cache") + return errors.Wrap(err, "failed to get cache.status file path") } - m.cacheDir = filepath.Join(xdg.CacheHome, "envd") - - m.cacheStatusFile = filepath.Join(m.cacheDir, "cache.status") + m.cacheStatusFile = cacheStatusFile _, err = os.Stat(m.cacheStatusFile) if err != nil { if os.IsNotExist(err) { diff --git a/pkg/home/config.go b/pkg/home/config.go index c277c6bf5..cacf4c4db 100644 --- a/pkg/home/config.go +++ b/pkg/home/config.go @@ -15,7 +15,6 @@ package home import ( - "github.com/adrg/xdg" "github.com/cockroachdb/errors" "github.com/tensorchord/envd/pkg/util/fileutil" @@ -26,8 +25,8 @@ type configManager interface { } func (m *generalManager) initConfig() error { - // Create $XDG_CONFIG_HOME/envd/config.envd - config, err := xdg.ConfigFile("envd/config.envd") + // Create $HOME/.config/envd/config.envd + config, err := fileutil.ConfigFile("config.envd") if err != nil { return errors.Wrap(err, "failed to get config file") } diff --git a/pkg/home/context.go b/pkg/home/context.go index d0f3fa42f..4df702c1c 100644 --- a/pkg/home/context.go +++ b/pkg/home/context.go @@ -18,11 +18,11 @@ import ( "encoding/gob" "os" - "github.com/adrg/xdg" "github.com/cockroachdb/errors" "github.com/sirupsen/logrus" "github.com/tensorchord/envd/pkg/types" + "github.com/tensorchord/envd/pkg/util/fileutil" ) type contextManager interface { @@ -36,11 +36,11 @@ type contextManager interface { } func (m *generalManager) initContext() error { - contextfile, err := xdg.ConfigFile("envd/contexts") + contextFile, err := fileutil.ConfigFile("contexts") if err != nil { return errors.Wrap(err, "failed to get context file") } - m.contextFile = contextfile + m.contextFile = contextFile // Create default context. diff --git a/pkg/home/manager_test.go b/pkg/home/manager_test.go index f469f1a2c..6393eb075 100644 --- a/pkg/home/manager_test.go +++ b/pkg/home/manager_test.go @@ -18,11 +18,11 @@ import ( "os" "path/filepath" - "github.com/adrg/xdg" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "github.com/tensorchord/envd/pkg/types" + "github.com/tensorchord/envd/pkg/util/fileutil" ) var _ = Describe("home manager", func() { @@ -43,16 +43,16 @@ var _ = Describe("home manager", func() { } Expect(defaultManager.init()).NotTo(HaveOccurred()) m := GetManager() - Expect(m.CacheDir()).To(Equal(filepath.Join(xdg.CacheHome, "envd"))) - Expect(m.ConfigFile()).To(Equal(filepath.Join(xdg.ConfigHome, "envd/config.envd"))) - Expect(m.ContextFile()).To(Equal(filepath.Join(xdg.ConfigHome, "envd/contexts"))) + Expect(m.CacheDir()).To(Equal(filepath.Join(fileutil.DefaultCacheDir))) + Expect(m.ConfigFile()).To(Equal(filepath.Join(fileutil.DefaultConfigDir, "config.envd"))) + Expect(m.ContextFile()).To(Equal(filepath.Join(fileutil.DefaultConfigDir, "contexts"))) driver, socket, err := m.ContextGetCurrent() Expect(err).NotTo(HaveOccurred()) Expect(driver).To(Equal(types.BuilderTypeDocker)) Expect(socket).To(Equal("envd_buildkitd")) }) It("should return the cache status", func() { - Expect(os.RemoveAll(filepath.Join(xdg.CacheHome, "envd/cache.status"))).NotTo(HaveOccurred()) + Expect(os.RemoveAll(filepath.Join(fileutil.DefaultCacheDir, "cache.status"))).NotTo(HaveOccurred()) Expect(Initialize()).NotTo(HaveOccurred()) m := GetManager() m.(*generalManager).cacheMap = make(map[string]bool) diff --git a/pkg/shell/zsh_test.go b/pkg/shell/zsh_test.go index 84ffcc8d0..7f1b297d1 100644 --- a/pkg/shell/zsh_test.go +++ b/pkg/shell/zsh_test.go @@ -18,7 +18,6 @@ import ( "os" "path/filepath" - "github.com/adrg/xdg" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" @@ -34,7 +33,7 @@ var _ = Describe("zsh manager", Serial, func() { AfterEach(func() { // Cleanup the home cache. Expect(home.Initialize()).NotTo(HaveOccurred()) - Expect(os.RemoveAll(filepath.Join(xdg.CacheHome, "envd/cache.status"))).NotTo(HaveOccurred()) + Expect(os.RemoveAll(filepath.Join(fileutil.DefaultCacheDir, "cache.status"))).NotTo(HaveOccurred()) }) When("cached", func() { It("should skip", func() { diff --git a/pkg/ssh/config/key.go b/pkg/ssh/config/key.go index cab58ef39..bc6a3627d 100644 --- a/pkg/ssh/config/key.go +++ b/pkg/ssh/config/key.go @@ -21,7 +21,6 @@ import ( "encoding/pem" "os" - "github.com/adrg/xdg" "github.com/cockroachdb/errors" "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" @@ -135,12 +134,12 @@ func generatePublicKey(privatekey *rsa.PublicKey) ([]byte, error) { } func getDefaultKeyPaths() (string, string, error) { - public, err := xdg.ConfigFile("envd/" + config.PublicKeyFile) + public, err := fileutil.ConfigFile(config.PublicKeyFile) if err != nil { return "", "", errors.Wrap(err, "Cannot get public key path") } - private, err := xdg.ConfigFile("envd/" + config.PrivateKeyFile) + private, err := fileutil.ConfigFile(config.PrivateKeyFile) if err != nil { return "", "", errors.Wrap(err, "Cannot get private key path") } diff --git a/pkg/util/fileutil/file.go b/pkg/util/fileutil/file.go index 86dbb351c..b7e7ba7a6 100644 --- a/pkg/util/fileutil/file.go +++ b/pkg/util/fileutil/file.go @@ -17,13 +17,30 @@ package fileutil import ( + "fmt" "os" + "path" "path/filepath" + "strings" "github.com/cockroachdb/errors" "github.com/sirupsen/logrus" ) +var ( + DefaultConfigDir string + DefaultCacheDir string +) + +func init() { + home, err := os.UserHomeDir() + if err != nil { + panic(err) + } + DefaultConfigDir = path.Join(home, ".config/envd") + DefaultCacheDir = path.Join(home, ".cache/envd") +} + // FileExists returns true if the file exists func FileExists(filename string) (bool, error) { info, err := os.Stat(filename) @@ -82,3 +99,39 @@ func RootDir() (string, error) { func Base(dir string) string { return filepath.Base(dir) } + +func MkdirIfNotExist(filepath string) error { + exist, err := DirExists(filepath) + if err != nil { + return err + } + if !exist { + err = os.MkdirAll(filepath, os.ModeDir|0700) + if err != nil { + return errors.Wrap(err, "failed to create the dir") + } + } + return nil +} + +// ConfigFile returns the location for the specified envd config file +func ConfigFile(filename string) (string, error) { + if strings.ContainsRune(filename, os.PathSeparator) { + return "", fmt.Errorf("filename %s should not contain any path separator", filename) + } + if err := MkdirIfNotExist(DefaultConfigDir); err != nil { + return "", err + } + return path.Join(DefaultConfigDir, filename), nil +} + +// CacheFile returns the location for the specified envd cache file +func CacheFile(filename string) (string, error) { + if strings.ContainsRune(filename, os.PathSeparator) { + return "", fmt.Errorf("filename %s should not contain any path separator", filename) + } + if err := MkdirIfNotExist(DefaultCacheDir); err != nil { + return "", err + } + return path.Join(DefaultCacheDir, filename), nil +}