Skip to content

Commit

Permalink
Merge pull request #707 from lightninglabs/migrations-tests
Browse files Browse the repository at this point in the history
tapdb: add minimalistic test framework for testing DB migrations
  • Loading branch information
Roasbeef authored Dec 9, 2023
2 parents d763d68 + efeb01c commit bb38de7
Show file tree
Hide file tree
Showing 12 changed files with 369 additions and 69 deletions.
12 changes: 12 additions & 0 deletions internal/test/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package test
import (
"bytes"
"encoding/hex"
"os"
"path/filepath"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -369,3 +371,13 @@ func ScriptSchnorrSig(t *testing.T, pubKey *btcec.PublicKey) txscript.TapLeaf {
require.NoError(t, err)
return txscript.NewBaseTapLeaf(script2)
}

// ReadTestDataFile reads a file from the testdata directory and returns its
// content as a string.
func ReadTestDataFile(t *testing.T, fileName string) string {
path := filepath.Join("testdata", fileName)
fileBytes, err := os.ReadFile(path)
require.NoError(t, err)

return string(fileBytes)
}
9 changes: 8 additions & 1 deletion tapdb/asset_minting_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ func newAssetStore(t *testing.T) (*AssetMintingStore, *AssetStore,
// First, Make a new test database.
db := NewTestDB(t)

mintStore, assetStore := newAssetStoreFromDB(db.BaseDB)
return mintStore, assetStore, db
}

// newAssetStoreFromDB makes a new instance of the AssetMintingStore backed by
// the passed database.
func newAssetStoreFromDB(db *BaseDB) (*AssetMintingStore, *AssetStore) {
// TODO(roasbeef): can use another layer of type params since
// duplicated?
txCreator := func(tx *sql.Tx) PendingAssetStore {
Expand All @@ -50,7 +57,7 @@ func newAssetStore(t *testing.T) (*AssetMintingStore, *AssetStore,
testClock := clock.NewTestClock(time.Now())

return NewAssetMintingStore(assetMintingDB),
NewAssetStore(assetsDB, testClock), db
NewAssetStore(assetsDB, testClock)
}

func assertBatchState(t *testing.T, batch *tapgarden.MintingBatch,
Expand Down
2 changes: 1 addition & 1 deletion tapdb/assets_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -861,7 +861,7 @@ func fetchAssetsWithWitness(ctx context.Context, q ActiveAssetsStore,
// First, we'll fetch all the assets we know of on disk.
dbAssets, err := q.QueryAssets(ctx, assetFilter)
if err != nil {
return nil, nil, fmt.Errorf("unable to read db assets: %v", err)
return nil, nil, fmt.Errorf("unable to read db assets: %w", err)
}

assetIDs := fMap(dbAssets, func(a ConfirmedAsset) int64 {
Expand Down
35 changes: 29 additions & 6 deletions tapdb/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package tapdb

import (
"bytes"
"errors"
"io"
"io/fs"
"net/http"
Expand All @@ -12,11 +13,31 @@ import (
"github.com/golang-migrate/migrate/v4/source/httpfs"
)

// applyMigrations executes all database migration files found in the given file
// MigrationTarget is a functional option that can be passed to applyMigrations
// to specify a target version to migrate to.
type MigrationTarget func(mig *migrate.Migrate) error

var (
// TargetLatest is a MigrationTarget that migrates to the latest
// version available.
TargetLatest = func(mig *migrate.Migrate) error {
return mig.Up()
}

// TargetVersion is a MigrationTarget that migrates to the given
// version.
TargetVersion = func(version uint) MigrationTarget {
return func(mig *migrate.Migrate) error {
return mig.Migrate(version)
}
}
)

// applyMigrations executes database migration files found in the given file
// system under the given path, using the passed database driver and database
// name.
func applyMigrations(fs fs.FS, driver database.Driver, path,
dbName string) error {
// name, up to or down to the given target version.
func applyMigrations(fs fs.FS, driver database.Driver, path, dbName string,
targetVersion MigrationTarget) error {

// With the migrate instance open, we'll create a new migration source
// using the embedded file system stored in sqlSchemas. The library
Expand All @@ -36,8 +57,10 @@ func applyMigrations(fs fs.FS, driver database.Driver, path,
if err != nil {
return err
}
err = sqlMigrate.Up()
if err != nil && err != migrate.ErrNoChange {

// Execute the migration based on the target given.
err = targetVersion(sqlMigrate)
if err != nil && !errors.Is(err, migrate.ErrNoChange) {
return err
}

Expand Down
53 changes: 53 additions & 0 deletions tapdb/migrations_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package tapdb

import (
"context"
"testing"

"github.com/stretchr/testify/require"
)

// TestMigrationSteps is an example test that illustrates how to test database
// migrations by selectively applying only some migrations, inserting dummy data
// and then applying the remaining migrations.
func TestMigrationSteps(t *testing.T) {
ctx := context.Background()

// As a first step, we create a new database but only migrate to
// version 1, which only contains the macaroon tables.
db := NewTestDBWithVersion(t, 1)

// If we create an assets store now, there should be no tables for the
// assets yet.
_, assetStore := newAssetStoreFromDB(db.BaseDB)
_, err := assetStore.FetchAllAssets(ctx, true, true, nil)
require.True(t, IsSchemaError(MapSQLError(err)))

// We now migrate to a later but not yet latest version.
err = db.ExecuteMigrations(TargetVersion(11))
require.NoError(t, err)

// Now there should be an asset table.
_, err = assetStore.FetchAllAssets(ctx, true, true, nil)
require.NoError(t, err)

// Assuming the next version does some changes to the data within the
// asset table, we now add some dummy data to the assets related tables,
// so we could then test that migration.
InsertTestdata(t, db.BaseDB, "migrations_test_00011_dummy_data.sql")

// Make sure we now have actual assets in the database.
dbAssets, err := assetStore.FetchAllAssets(ctx, true, true, nil)
require.NoError(t, err)
require.Len(t, dbAssets, 4)

// And now that we have test data inserted, we can migrate to the latest
// version.
err = db.ExecuteMigrations(TargetLatest)
require.NoError(t, err)

// Here we would now test that the migration to the latest version did
// what we expected it to do. But this is just an example, illustrating
// the steps that can be taken to test migrations, so we are done for
// this test.
}
95 changes: 63 additions & 32 deletions tapdb/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,16 @@ var (
// fully executed yet. So this time needs to be chosen correctly to be
// longer than the longest expected individual test run time.
DefaultPostgresFixtureLifetime = 60 * time.Minute

// postgresSchemaReplacements is a map of schema strings that need to be
// replaced for postgres. This is needed because we write the schemas
// to work with sqlite primarily, and postgres has some differences.
postgresSchemaReplacements = map[string]string{
"BLOB": "BYTEA",
"INTEGER PRIMARY KEY": "SERIAL PRIMARY KEY",
"BIGINT PRIMARY KEY": "BIGSERIAL PRIMARY KEY",
"TIMESTAMP": "TIMESTAMP WITHOUT TIME ZONE",
}
)

// PostgresConfig holds the postgres database configuration.
Expand Down Expand Up @@ -107,44 +117,41 @@ func NewPostgresStore(cfg *PostgresConfig) (*PostgresStore, error) {
rawDb.SetConnMaxLifetime(connMaxLifetime)
rawDb.SetConnMaxIdleTime(connMaxIdleTime)

if !cfg.SkipMigrations {
// Now that the database is open, populate the database with
// our set of schemas based on our embedded in-memory file
// system.
//
// First, we'll need to open up a new migration instance for
// our current target database: sqlite.
driver, err := postgres_migrate.WithInstance(
rawDb, &postgres_migrate.Config{},
)
if err != nil {
return nil, err
}

postgresFS := newReplacerFS(sqlSchemas, map[string]string{
"BLOB": "BYTEA",
"INTEGER PRIMARY KEY": "SERIAL PRIMARY KEY",
"BIGINT PRIMARY KEY": "BIGSERIAL PRIMARY KEY",
"TIMESTAMP": "TIMESTAMP WITHOUT TIME ZONE",
})

err = applyMigrations(
postgresFS, driver, "sqlc/migrations", cfg.DBName,
)
if err != nil {
return nil, err
}
}

queries := sqlc.NewPostgres(rawDb)

return &PostgresStore{
s := &PostgresStore{
cfg: cfg,
BaseDB: &BaseDB{
DB: rawDb,
Queries: queries,
},
}, nil
}

// Now that the database is open, populate the database with our set of
// schemas based on our embedded in-memory file system.
if !cfg.SkipMigrations {
if err := s.ExecuteMigrations(TargetLatest); err != nil {
return nil, fmt.Errorf("error executing migrations: "+
"%w", err)
}
}

return s, nil
}

// ExecuteMigrations runs migrations for the Postgres database, depending on the
// target given, either all migrations or up to a given version.
func (s *PostgresStore) ExecuteMigrations(target MigrationTarget) error {
driver, err := postgres_migrate.WithInstance(
s.DB, &postgres_migrate.Config{},
)
if err != nil {
return fmt.Errorf("error creating postgres migration: %w", err)
}

postgresFS := newReplacerFS(sqlSchemas, postgresSchemaReplacements)
return applyMigrations(
postgresFS, driver, "sqlc/migrations", s.cfg.DBName, target,
)
}

// NewTestPostgresDB is a helper function that creates a Postgres database for
Expand All @@ -164,3 +171,27 @@ func NewTestPostgresDB(t *testing.T) *PostgresStore {

return store
}

// NewTestPostgresDBWithVersion is a helper function that creates a Postgres
// database for testing and migrates it to the given version.
func NewTestPostgresDBWithVersion(t *testing.T, version uint) *PostgresStore {
t.Helper()

t.Logf("Creating new Postgres DB for testing, migrating to version %d",
version)

sqlFixture := NewTestPgFixture(t, DefaultPostgresFixtureLifetime, true)
storeCfg := sqlFixture.GetConfig()
storeCfg.SkipMigrations = true
store, err := NewPostgresStore(storeCfg)
require.NoError(t, err)

err = store.ExecuteMigrations(TargetVersion(version))
require.NoError(t, err)

t.Cleanup(func() {
sqlFixture.TearDown(t)
})

return store
}
43 changes: 43 additions & 0 deletions tapdb/sqlerrors.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package tapdb
import (
"errors"
"fmt"
"strings"

"github.com/jackc/pgconn"
"github.com/jackc/pgerrcode"
Expand Down Expand Up @@ -52,6 +53,20 @@ func parseSqliteError(sqliteErr *sqlite.Error) error {
DbError: sqliteErr,
}

// Generic error, need to parse the message further.
case sqlite3.SQLITE_ERROR:
errMsg := sqliteErr.Error()

switch {
case strings.Contains(errMsg, "no such table"):
return &ErrSchemaError{
DbError: sqliteErr,
}

default:
return fmt.Errorf("unknown sqlite error: %w", sqliteErr)
}

default:
return fmt.Errorf("unknown sqlite error: %w", sqliteErr)
}
Expand All @@ -73,6 +88,12 @@ func parsePostgresError(pqErr *pgconn.PgError) error {
DbError: pqErr,
}

// Handle schema error.
case pgerrcode.UndefinedColumn, pgerrcode.UndefinedTable:
return &ErrSchemaError{
DbError: pqErr,
}

default:
return fmt.Errorf("unknown postgres error: %w", pqErr)
}
Expand Down Expand Up @@ -111,3 +132,25 @@ func IsSerializationError(err error) bool {
var serializationError *ErrSerializationError
return errors.As(err, &serializationError)
}

// ErrSchemaError is an error type which represents a database agnostic error
// that the schema of the database is incorrect for the given query.
type ErrSchemaError struct {
DbError error
}

// Unwrap returns the wrapped error.
func (e ErrSchemaError) Unwrap() error {
return e.DbError
}

// Error returns the error message.
func (e ErrSchemaError) Error() string {
return e.DbError.Error()
}

// IsSchemaError returns true if the given error is a schema error.
func IsSchemaError(err error) bool {
var schemaError *ErrSchemaError
return errors.As(err, &schemaError)
}
Loading

0 comments on commit bb38de7

Please sign in to comment.