Skip to content

Commit

Permalink
Use on-disk SQLite databases to allow regular access methods, see mat…
Browse files Browse the repository at this point in the history
  • Loading branch information
jmalloc committed Mar 26, 2020
1 parent dd5e06e commit 471fc28
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 29 deletions.
4 changes: 2 additions & 2 deletions cmd/bank/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ func run(ctx context.Context) error {
}
appName := os.Args[1]

db := sqltest.Open("postgres")
defer db.Close()
db, _, close := sqltest.Open("postgres")
defer close()

// TODO: return to using SQL
// if err := postgres.CreateSchema(ctx, db); err != nil {
Expand Down
6 changes: 3 additions & 3 deletions internal/testing/boltdbtest/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
// The returned function must be used to close the database, instead of
// DB.Close().
func Open() (*bbolt.DB, func()) {
filename, remove := TempFile()
filename, remove := tempFile()

db, err := bbolt.Open(filename, 0600, nil)
if err != nil {
Expand All @@ -26,11 +26,11 @@ func Open() (*bbolt.DB, func()) {
}
}

// TempFile returns the name of a temporary file to be used for a BoltDB
// tempFile returns the name of a temporary file to be used for a BoltDB
// database.
//
// It returns a function that deletes the temporary file.
func TempFile() (string, func()) {
func tempFile() (string, func()) {
f, err := ioutil.TempFile("", "*.boltdb")
if err != nil {
panic(err)
Expand Down
96 changes: 72 additions & 24 deletions internal/testing/sqltest/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,50 +3,98 @@ package sqltest
import (
"database/sql"
"fmt"
"io/ioutil"
"os"
"sync"

_ "github.com/go-sql-driver/mysql" // keep driver import near code that uses it
"github.com/google/uuid"
_ "github.com/lib/pq" // keep driver import near code that uses it
_ "github.com/mattn/go-sqlite3" // keep driver import near code that uses it
_ "github.com/lib/pq" // keep driver import near code that uses it
_ "github.com/mattn/go-sqlite3" // keep driver import near code that uses it
)

// DSN returns the DSN for the test database to use with the given SQL driver.
func DSN(driver string) string {
var env, dsn string
//
// The returned function must be used to cleanup any data created for the DSN,
// such as temporary on-disk databases.
func DSN(driver string) (string, func()) {
dsn := dsnFromEnv(driver)
if dsn != "" {
return dsn, func() {}
}

switch driver {
case "mysql":
env = "DOGMATIQ_TEST_MYSQL_DSN"
dsn = "root:rootpass@tcp(127.0.0.1:3306)/dogmatiq"
case "sqlite3":
env = "DOGMATIQ_TEST_SQLITE_DSN"
dsn = fmt.Sprintf(
"file:sqlite-%s.db?cache=shared&mode=memory",
uuid.New().String(),
)
return "root:rootpass@tcp(127.0.0.1:3306)/dogmatiq", func() {}
case "postgres":
env = "DOGMATIQ_TEST_POSTGRES_DSN"
dsn = "user=postgres password=rootpass sslmode=disable"
return "user=postgres password=rootpass sslmode=disable", func() {}
default:
panic("unsupported driver: " + driver)
file, close := tempFile()
return fmt.Sprintf("file:%s?mode=rwc", file), close
}
}

if v := os.Getenv(env); v != "" {
return v
// Open returns the test database to use with the given driver.
//
// The returned function must be used to close the database, instead of
// DB.Close().
func Open(
driver string,
) (
db *sql.DB,
dsn string,
close func(),
) {
dsn, closeDSN := DSN(driver)

db, err := sql.Open(driver, dsn)
if err != nil {
panic(err)
}

return dsn
return db, dsn, func() {
db.Close()
closeDSN()
}
}

// Open returns the test database to use with the given driver.
func Open(driver string) *sql.DB {
dsn := DSN(driver)
// dsnFromEnv returns a DSN for the given driver from an environment variable.
func dsnFromEnv(driver string) string {
switch driver {
case "mysql":
return os.Getenv("DOGMATIQ_TEST_MYSQL_DSN")
case "postgres":
return os.Getenv("DOGMATIQ_TEST_POSTGRES_DSN")
case "sqlite3":
return os.Getenv("DOGMATIQ_TEST_SQLITE_DSN")
default:
panic("unsupported driver: " + driver)
}
}

db, err := sql.Open(driver, dsn)
// tempFile returns the name of a temporary file to be used for an SQLite
// database.
//
// It returns a function that deletes the temporary file.
func tempFile() (string, func()) {
f, err := ioutil.TempFile("", "*.sqlite3")
if err != nil {
panic(err)
}

return db
if err := f.Close(); err != nil {
panic(err)
}

file := f.Name()

if err := os.Remove(file); err != nil {
panic(err)
}

var once sync.Once
return file, func() {
once.Do(func() {
os.Remove(file)
})
}
}

0 comments on commit 471fc28

Please sign in to comment.