diff --git a/cmd/root.go b/cmd/root.go index eef5f5b9..8bbf8d0b 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -28,6 +28,7 @@ import ( pkgutil "github.com/GoogleContainerTools/container-diff/pkg/util" "github.com/GoogleContainerTools/container-diff/util" homedir "github.com/mitchellh/go-homedir" + "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/spf13/cobra" "github.com/spf13/pflag" @@ -39,9 +40,12 @@ var save bool var types diffTypes var noCache bool +var cacheDir string var LogLevel string var format string +const containerDiffEnvCacheDir = "CONTAINER_DIFF_CACHEDIR" + type validatefxn func(args []string) error var RootCmd = &cobra.Command{ @@ -129,7 +133,7 @@ func getImage(imageName string) (pkgutil.Image, error) { var cachePath string var err error if !noCache { - cachePath, err = cacheDir(imageName) + cachePath, err = getCacheDir(imageName) if err != nil { return pkgutil.Image{}, err } @@ -137,12 +141,23 @@ func getImage(imageName string) (pkgutil.Image, error) { return pkgutil.GetImage(imageName, includeLayers(), cachePath) } -func cacheDir(imageName string) (string, error) { - dir, err := homedir.Dir() - if err != nil { - return "", err +func getCacheDir(imageName string) (string, error) { + // First preference for cache is set at command line + if cacheDir == "" { + // second preference is environment + cacheDir = os.Getenv(containerDiffEnvCacheDir) } - rootDir := filepath.Join(dir, ".container-diff", "cache") + + // Third preference (default) is set at $HOME + if cacheDir == "" { + dir, err := homedir.Dir() + if err != nil { + return "", errors.Wrap(err, "retrieving home dir") + } else { + cacheDir = dir + } + } + rootDir := filepath.Join(cacheDir, ".container-diff", "cache") imageName = strings.Replace(imageName, string(os.PathSeparator), "", -1) return filepath.Join(rootDir, filepath.Clean(imageName)), nil } @@ -185,4 +200,6 @@ func addSharedFlags(cmd *cobra.Command) { cmd.Flags().BoolVarP(&save, "save", "s", false, "Set this flag to save rather than remove the final image filesystems on exit.") cmd.Flags().BoolVarP(&util.SortSize, "order", "o", false, "Set this flag to sort any file/package results by descending size. Otherwise, they will be sorted by name.") cmd.Flags().BoolVarP(&noCache, "no-cache", "n", false, "Set this to force retrieval of image filesystem on each run.") + cmd.Flags().StringVarP(&cacheDir, "cache-dir", "c", "", "cache directory base to create .container-diff (default is $HOME).") + } diff --git a/cmd/root_test.go b/cmd/root_test.go index 319f7315..5ffa9a18 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -16,7 +16,81 @@ limitations under the License. package cmd +import ( + "os" + "path" + "path/filepath" + "testing" + + homedir "github.com/mitchellh/go-homedir" +) + type testpair struct { input []string shouldError bool } + +func TestCacheDir(t *testing.T) { + homeDir, err := homedir.Dir() + if err != nil { + t.Errorf("error getting home dir: %s", err.Error()) + } + tests := []struct { + name string + cliFlag string + envVar string + expectedDir string + imageName string + }{ + { + name: "default cache is at $HOME", + cliFlag: "", + envVar: "", + expectedDir: filepath.Join(homeDir, ".container-diff", "cache"), + imageName: "pancakes", + }, + { + name: "setting cache via --cache-dir", + cliFlag: "/tmp", + envVar: "", + expectedDir: "/tmp/.container-diff/cache", + imageName: "pancakes", + }, + { + name: "setting cache via CONTAINER_DIFF_CACHEDIR", + cliFlag: "", + envVar: "/tmp", + expectedDir: "/tmp/.container-diff/cache", + imageName: "pancakes", + }, + { + name: "command line --cache-dir takes preference to CONTAINER_DIFF_CACHEDIR", + cliFlag: "/tmp", + envVar: "/opt", + expectedDir: "/tmp/.container-diff/cache", + imageName: "pancakes", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // set any environment variables + if tt.envVar != "" { + os.Setenv("CONTAINER_DIFF_CACHEDIR", tt.envVar) + } + // Set global flag for cache based on --cache-dir + cacheDir = tt.cliFlag + + // call getCacheDir and make sure return is equal to expected + actualDir, err := getCacheDir(tt.imageName) + if err != nil { + t.Errorf("Error getting cache dir %s: %s", tt.name, err.Error()) + } + + if path.Dir(actualDir) != tt.expectedDir { + t.Errorf("%s\nexpected: %v\ngot: %v", tt.name, tt.expectedDir, actualDir) + } + }, + ) + } +}