diff --git a/README.md b/README.md index f8fe234..a3c70c6 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,7 @@ This library aims to require as little configuration as possible, favouring over | Password | postgres | | Database | postgres | | Version | 12.1.0 | +| CachePath | $USER_HOME/.embedded-postgres-go/ | | RuntimePath | $USER_HOME/.embedded-postgres-go/extracted | | DataPath | $USER_HOME/.embedded-postgres-go/extracted/data | | BinariesPath | $USER_HOME/.embedded-postgres-go/extracted | diff --git a/cache_locator.go b/cache_locator.go index c1baa50..4169f1b 100644 --- a/cache_locator.go +++ b/cache_locator.go @@ -10,11 +10,13 @@ import ( // The result of whether this cache is present will be returned to exists. type CacheLocator func() (location string, exists bool) -func defaultCacheLocator(versionStrategy VersionStrategy) CacheLocator { +func defaultCacheLocator(cacheDirectory string, versionStrategy VersionStrategy) CacheLocator { return func() (string, bool) { - cacheDirectory := ".embedded-postgres-go" - if userHome, err := os.UserHomeDir(); err == nil { - cacheDirectory = filepath.Join(userHome, ".embedded-postgres-go") + if cacheDirectory == "" { + cacheDirectory = ".embedded-postgres-go" + if userHome, err := os.UserHomeDir(); err == nil { + cacheDirectory = filepath.Join(userHome, ".embedded-postgres-go") + } } operatingSystem, architecture, version := versionStrategy() diff --git a/cache_locator_test.go b/cache_locator_test.go index b40061b..628666f 100644 --- a/cache_locator_test.go +++ b/cache_locator_test.go @@ -7,7 +7,7 @@ import ( ) func Test_defaultCacheLocator_NotExists(t *testing.T) { - locator := defaultCacheLocator(func() (string, string, PostgresVersion) { + locator := defaultCacheLocator("", func() (string, string, PostgresVersion) { return "a", "b", "1.2.3" }) @@ -16,3 +16,14 @@ func Test_defaultCacheLocator_NotExists(t *testing.T) { assert.Contains(t, cacheLocation, ".embedded-postgres-go/embedded-postgres-binaries-a-b-1.2.3.txz") assert.False(t, exists) } + +func Test_defaultCacheLocator_CustomPath(t *testing.T) { + locator := defaultCacheLocator("/custom/path", func() (string, string, PostgresVersion) { + return "a", "b", "1.2.3" + }) + + cacheLocation, exists := locator() + + assert.Equal(t, cacheLocation, "/custom/path/embedded-postgres-binaries-a-b-1.2.3.txz") + assert.False(t, exists) +} diff --git a/config.go b/config.go index 0c8742a..8f73a33 100644 --- a/config.go +++ b/config.go @@ -14,6 +14,7 @@ type Config struct { database string username string password string + cachePath string runtimePath string dataPath string binariesPath string @@ -82,6 +83,13 @@ func (c Config) RuntimePath(path string) Config { return c } +// CachePath sets the path that will be used for storing Postgres binaries archive. +// If this option is not set, ~/.go-embedded-postgres will be used. +func (c Config) CachePath(path string) Config { + c.cachePath = path + return c +} + // DataPath sets the path that will be used for the Postgres data directory. // If this option is set, a previously initialized data directory will be reused if possible. func (c Config) DataPath(path string) Config { diff --git a/embedded_postgres.go b/embedded_postgres.go index 9fa31a8..d3af5f6 100644 --- a/embedded_postgres.go +++ b/embedded_postgres.go @@ -44,7 +44,7 @@ func newDatabaseWithConfig(config Config) *EmbeddedPostgres { linuxMachineName, shouldUseAlpineLinuxBuild, ) - cacheLocator := defaultCacheLocator(versionStrategy) + cacheLocator := defaultCacheLocator(config.cachePath, versionStrategy) remoteFetchStrategy := defaultRemoteFetchStrategy(config.binaryRepositoryURL, versionStrategy, cacheLocator) return &EmbeddedPostgres{ diff --git a/embedded_postgres_test.go b/embedded_postgres_test.go index 549fa13..474cd2a 100644 --- a/embedded_postgres_test.go +++ b/embedded_postgres_test.go @@ -620,6 +620,30 @@ func Test_CustomBinariesRepo(t *testing.T) { } } +func Test_CachePath(t *testing.T) { + cacheTempDir, err := os.MkdirTemp("", "prepare_database_test_cache") + if err != nil { + panic(err) + } + + defer func() { + if err := os.RemoveAll(cacheTempDir); err != nil { + panic(err) + } + }() + + database := NewDatabase(DefaultConfig(). + CachePath(cacheTempDir)) + + if err := database.Start(); err != nil { + shutdownDBAndFail(t, err, database) + } + + if err := database.Stop(); err != nil { + shutdownDBAndFail(t, err, database) + } +} + func Test_CustomBinariesLocation(t *testing.T) { tempDir, err := os.MkdirTemp("", "prepare_database_test") if err != nil {