diff --git a/CHANGELOG.md b/CHANGELOG.md index 9742f17d..5fcaa0db 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [0.0.10] - 2023-11-26 + +### Added + +- Added `river/rivermigrate` package to enable migrations from Go code as an alternative to using the CLI. + ## [0.0.9] - 2023-11-23 ### Fixed diff --git a/cmd/river/main.go b/cmd/river/main.go index 0007ff22..006af9a5 100644 --- a/cmd/river/main.go +++ b/cmd/river/main.go @@ -3,7 +3,6 @@ package main import ( "context" "fmt" - "log/slog" "os" "strconv" "time" @@ -11,9 +10,8 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/spf13/cobra" - "github.com/riverqueue/river/internal/baseservice" - "github.com/riverqueue/river/internal/dbmigrate" - "github.com/riverqueue/river/internal/util/slogutil" + "github.com/riverqueue/river/riverdriver/riverpgxv5" + "github.com/riverqueue/river/rivermigrate" ) func main() { @@ -87,7 +85,7 @@ restricted with --max-steps. }, } cmd.Flags().StringVar(&opts.DatabaseURL, "database-url", "", "URL of the database to migrate (should look like `postgres://...`") - cmd.Flags().IntVar(&opts.MaxSteps, "max-steps", -1, "Maximum number of steps to migrate") + cmd.Flags().IntVar(&opts.MaxSteps, "max-steps", 0, "Maximum number of steps to migrate") mustMarkFlagRequired(cmd, "database-url") rootCmd.AddCommand(cmd) } @@ -95,13 +93,6 @@ restricted with --max-steps. execHandlingError(rootCmd.Execute) } -func baseServiceArchetype() *baseservice.Archetype { - return &baseservice.Archetype{ - Logger: slog.New(&slogutil.SlogMessageOnlyHandler{Level: slog.LevelInfo}), - TimeNowUTC: func() time.Time { return time.Now().UTC() }, - } -} - func openDBPool(ctx context.Context, databaseURL string) (*pgxpool.Pool, error) { const ( defaultIdleInTransactionSessionTimeout = 11 * time.Second // should be greater than statement timeout because statements count towards idle-in-transaction @@ -159,10 +150,10 @@ func migrateDown(ctx context.Context, opts *migrateDownOpts) error { } defer dbPool.Close() - migrator := dbmigrate.NewMigrator(baseServiceArchetype()) + migrator := rivermigrate.New(riverpgxv5.New(dbPool), nil) - _, err = migrator.Down(ctx, dbPool, &dbmigrate.MigrateOptions{ - MaxSteps: &opts.MaxSteps, + _, err = migrator.Migrate(ctx, rivermigrate.DirectionDown, &rivermigrate.MigrateOpts{ + MaxSteps: opts.MaxSteps, }) return err } @@ -191,10 +182,10 @@ func migrateUp(ctx context.Context, opts *migrateUpOpts) error { } defer dbPool.Close() - migrator := dbmigrate.NewMigrator(baseServiceArchetype()) + migrator := rivermigrate.New(riverpgxv5.New(dbPool), nil) - _, err = migrator.Up(ctx, dbPool, &dbmigrate.MigrateOptions{ - MaxSteps: &opts.MaxSteps, + _, err = migrator.Migrate(ctx, rivermigrate.DirectionUp, &rivermigrate.MigrateOpts{ + MaxSteps: opts.MaxSteps, }) return err } diff --git a/internal/cmd/testdbman/create.go b/internal/cmd/testdbman/create.go index 09425287..afa190fd 100644 --- a/internal/cmd/testdbman/create.go +++ b/internal/cmd/testdbman/create.go @@ -3,7 +3,6 @@ package main import ( "context" "fmt" - "log/slog" "os" "runtime" "time" @@ -12,8 +11,8 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/spf13/cobra" - "github.com/riverqueue/river/internal/baseservice" - "github.com/riverqueue/river/internal/dbmigrate" + "github.com/riverqueue/river/riverdriver/riverpgxv5" + "github.com/riverqueue/river/rivermigrate" ) func init() { //nolint:gochecknoinits @@ -57,16 +56,9 @@ runtime.NumCPU() (a choice that comes from pgx's default connection pool size). } defer dbPool.Close() - archetype := &baseservice.Archetype{ - Logger: slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ - Level: slog.LevelWarn, - })), - TimeNowUTC: func() time.Time { return time.Now().UTC() }, - } - - migrator := dbmigrate.NewMigrator(archetype) + migrator := rivermigrate.New(riverpgxv5.New(dbPool), nil) - if _, err = migrator.Up(ctx, dbPool, &dbmigrate.MigrateOptions{}); err != nil { + if _, err = migrator.Migrate(ctx, rivermigrate.DirectionUp, &rivermigrate.MigrateOpts{}); err != nil { return err } fmt.Printf("Loaded schema in %s.\n", dbName) diff --git a/internal/dbmigrate/db_migrate.go b/internal/dbmigrate/db_migrate.go deleted file mode 100644 index 5941bcd7..00000000 --- a/internal/dbmigrate/db_migrate.go +++ /dev/null @@ -1,279 +0,0 @@ -package dbmigrate - -import ( - "context" - _ "embed" - "errors" - "fmt" - "log/slog" - "maps" - "slices" - "strings" - - "github.com/jackc/pgerrcode" - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgconn" - - "github.com/riverqueue/river/internal/baseservice" - "github.com/riverqueue/river/internal/dbsqlc" - "github.com/riverqueue/river/internal/util/dbutil" - "github.com/riverqueue/river/internal/util/maputil" - "github.com/riverqueue/river/internal/util/sliceutil" -) - -var ( - //go:embed 001_create_river_migration.down.sql - sql001CreateRiverMigrationDown string - - //go:embed 001_create_river_migration.up.sql - sql001CreateRiverMigrationUp string - - //go:embed 002_initial_schema.down.sql - sql002InitialSchemaDown string - - //go:embed 002_initial_schema.up.sql - sql002InitialSchemaUp string - - //go:embed 003_river_job_tags_non_null.down.sql - sql003RiverJobTagsNonNullDown string - - //go:embed 003_river_job_tags_non_null.up.sql - sql003RiverJobTagsNonNullUp string -) - -type migrationBundle struct { - Version int64 - Up string - Down string -} - -//nolint:gochecknoglobals -var ( - riverMigrations = []*migrationBundle{ - {Version: 1, Up: sql001CreateRiverMigrationUp, Down: sql001CreateRiverMigrationDown}, - {Version: 2, Up: sql002InitialSchemaUp, Down: sql002InitialSchemaDown}, - {Version: 3, Up: sql003RiverJobTagsNonNullUp, Down: sql003RiverJobTagsNonNullDown}, - } - - riverMigrationsMap = validateAndInit(riverMigrations) -) - -// Migrator is a database migration tool for River which can run up or down -// migrations in order to establish the schema that the queue needs to run. -type Migrator struct { - baseservice.BaseService - - migrations map[int64]*migrationBundle - queries *dbsqlc.Queries -} - -func NewMigrator(archetype *baseservice.Archetype) *Migrator { - return baseservice.Init(archetype, &Migrator{ - migrations: riverMigrationsMap, - queries: dbsqlc.New(), - }) -} - -// MigrateOptions are options for a migrate operation. -type MigrateOptions struct { - // MaxSteps is the maximum number of migrations to apply either up or down. - // Leave nil or set -1 for an unlimited number. - MaxSteps *int -} - -// MigrateResult is the result of a migrate operation. -type MigrateResult struct { - // Versions are migration versions that were added (for up migrations) or - // removed (for down migrations) for this run. - Versions []int64 -} - -// Down runs down migrations. -func (m *Migrator) Down(ctx context.Context, txBeginner dbutil.TxBeginner, opts *MigrateOptions) (*MigrateResult, error) { - return dbutil.WithTxV(ctx, txBeginner, func(ctx context.Context, tx pgx.Tx) (*MigrateResult, error) { - existingMigrations, err := m.existingMigrations(ctx, tx) - if err != nil { - return nil, err - } - existingMigrationsMap := sliceutil.KeyBy(existingMigrations, - func(m *dbsqlc.RiverMigration) (int64, struct{}) { return m.Version, struct{}{} }) - - targetMigrations := maps.Clone(m.migrations) - for version := range targetMigrations { - if _, ok := existingMigrationsMap[version]; !ok { - delete(targetMigrations, version) - } - } - - sortedTargetMigrations := maputil.Values(targetMigrations) - slices.SortFunc(sortedTargetMigrations, func(a, b *migrationBundle) int { return int(b.Version - a.Version) }) // reverse order - - res, err := m.applyMigrations(ctx, tx, opts, sortedTargetMigrations, true) - if err != nil { - return nil, err - } - - // If we did no work, leave early. This allows a zero-migrated database - // that's being no-op downmigrated again to succeed because otherwise - // the delete below would cause it to error. - if len(res.Versions) < 1 { - return res, nil - } - - // Migration version 1 is special-cased because if it was downmigrated - // it means the `river_migration` table is no longer present so there's - // nothing to delete out of. - if slices.Contains(res.Versions, 1) { - return res, nil - } - - if _, err := m.queries.RiverMigrationDeleteByVersionMany(ctx, tx, res.Versions); err != nil { - return nil, fmt.Errorf("error deleting migration rows for versions %+v: %w", res.Versions, err) - } - - return res, nil - }) -} - -// Up runs up migrations. -func (m *Migrator) Up(ctx context.Context, txBeginner dbutil.TxBeginner, opts *MigrateOptions) (*MigrateResult, error) { - return dbutil.WithTxV(ctx, txBeginner, func(ctx context.Context, tx pgx.Tx) (*MigrateResult, error) { - existingMigrations, err := m.existingMigrations(ctx, tx) - if err != nil { - return nil, err - } - - targetMigrations := maps.Clone(m.migrations) - for _, migrateRow := range existingMigrations { - delete(targetMigrations, migrateRow.Version) - } - - sortedTargetMigrations := maputil.Values(targetMigrations) - slices.SortFunc(sortedTargetMigrations, func(a, b *migrationBundle) int { return int(a.Version - b.Version) }) - - res, err := m.applyMigrations(ctx, tx, opts, sortedTargetMigrations, false) - if err != nil { - return nil, err - } - - if _, err := m.queries.RiverMigrationInsertMany(ctx, tx, res.Versions); err != nil { - return nil, fmt.Errorf("error inserting migration rows for versions %+v: %w", res.Versions, err) - } - - return res, nil - }) -} - -// Common code shared between the up and down migration directions that walks -// through each target migration and applies it, logging appropriately. -func (m *Migrator) applyMigrations(ctx context.Context, tx pgx.Tx, opts *MigrateOptions, sortedTargetMigrations []*migrationBundle, down bool) (*MigrateResult, error) { - if opts.MaxSteps != nil && *opts.MaxSteps >= 0 { - sortedTargetMigrations = sortedTargetMigrations[0:min(*opts.MaxSteps, len(sortedTargetMigrations))] - } - - res := &MigrateResult{Versions: make([]int64, 0, len(sortedTargetMigrations))} - - // Short circuit early if there's nothing to do. - if len(sortedTargetMigrations) < 1 { - m.Logger.InfoContext(ctx, m.Name+": No migrations to apply") - return res, nil - } - - direction := "up" - if down { - direction = "down" - } - - for _, versionBundle := range sortedTargetMigrations { - sql := versionBundle.Up - if down { - sql = versionBundle.Down - } - - m.Logger.InfoContext(ctx, fmt.Sprintf(m.Name+": Applying migration %03d [%s]", versionBundle.Version, strings.ToUpper(direction)), - slog.String("direction", direction), - slog.Int64("version", versionBundle.Version), - ) - - _, err := tx.Exec(ctx, sql) - if err != nil { - return nil, fmt.Errorf("error applying version %03d [%s]: %w", - versionBundle.Version, strings.ToUpper(direction), err) - } - - res.Versions = append(res.Versions, versionBundle.Version) - } - - // Only prints if more steps than available were requested. - if opts.MaxSteps != nil && *opts.MaxSteps >= 0 && len(res.Versions) < *opts.MaxSteps { - m.Logger.InfoContext(ctx, m.Name+": No more migrations to apply") - } - - return res, nil -} - -// Get existing migrations that've already been run in the database. This is -// encapsulated to run a check in a subtransaction and the handle the case of -// the `river_migration` table not existing yet. (The subtransaction is needed -// because otherwise the existing transaction would become aborted on an -// unsuccessful `river_migration` check.) -func (m *Migrator) existingMigrations(ctx context.Context, tx pgx.Tx) ([]*dbsqlc.RiverMigration, error) { - // We start another inner transaction here because in case this is the first - // ever migration run, the transaction may become aborted if `river_migration` - // doesn't exist, a condition which we must handle gracefully. - migrations, err := dbutil.WithTxV(ctx, tx, func(ctx context.Context, tx pgx.Tx) ([]*dbsqlc.RiverMigration, error) { - migrations, err := m.queries.RiverMigrationGetAll(ctx, tx) - if err != nil { - return nil, fmt.Errorf("error getting current migrate rows: %w", err) - } - return migrations, nil - }) - if err != nil { - var pgErr *pgconn.PgError - if errors.As(err, &pgErr) { - if pgErr.Code == pgerrcode.UndefinedTable && strings.Contains(pgErr.Message, "river_migration") { - return nil, nil - } - } - - return nil, err - } - - return migrations, nil -} - -// Validates and fully initializes a set of migrations to reduce the probability -// of configuration problems as new migrations are introduced. e.g. Checks for -// missing fields or accidentally duplicated version numbers from copy/pasta -// problems. -func validateAndInit(versions []*migrationBundle) map[int64]*migrationBundle { - lastVersion := int64(0) - migrations := make(map[int64]*migrationBundle, len(versions)) - - for _, versionBundle := range versions { - if versionBundle.Down == "" { - panic(fmt.Sprintf("version bundle should specify Down: %+v", versionBundle)) - } - if versionBundle.Up == "" { - panic(fmt.Sprintf("version bundle should specify Up: %+v", versionBundle)) - } - if versionBundle.Version == 0 { - panic(fmt.Sprintf("version bundle should specify Version: %+v", versionBundle)) - } - - if _, ok := migrations[versionBundle.Version]; ok { - panic(fmt.Sprintf("duplicate version: %03d", versionBundle.Version)) - } - if versionBundle.Version <= lastVersion { - panic(fmt.Sprintf("versions should be ascending; current: %03d, last: %03d", versionBundle.Version, lastVersion)) - } - if versionBundle.Version > lastVersion+1 { - panic(fmt.Sprintf("versions shouldn't skip a sequence number; current: %03d, last: %03d", versionBundle.Version, lastVersion)) - } - - lastVersion = versionBundle.Version - migrations[versionBundle.Version] = versionBundle - } - - return migrations -} diff --git a/internal/dbmigrate/db_migrate_test.go b/internal/dbmigrate/db_migrate_test.go deleted file mode 100644 index f57184b6..00000000 --- a/internal/dbmigrate/db_migrate_test.go +++ /dev/null @@ -1,250 +0,0 @@ -package dbmigrate - -import ( - "context" - "slices" - "testing" - - "github.com/jackc/pgerrcode" - "github.com/jackc/pgx/v5" - "github.com/jackc/pgx/v5/pgconn" - "github.com/stretchr/testify/require" - - "github.com/riverqueue/river/internal/dbsqlc" - "github.com/riverqueue/river/internal/riverinternaltest" - "github.com/riverqueue/river/internal/util/dbutil" - "github.com/riverqueue/river/internal/util/ptrutil" - "github.com/riverqueue/river/internal/util/sliceutil" -) - -//nolint:gochecknoglobals -var ( - // We base our test migrations on the actual line of migrations, so get - // their maximum version number which we'll use to define test version - // numbers so that the tests don't break anytime we add a new one. - riverMigrationsMaxVersion = riverMigrations[len(riverMigrations)-1].Version - - testVersions = validateAndInit(append(riverMigrations, []*migrationBundle{ - { - Version: riverMigrationsMaxVersion + 1, - Up: "CREATE TABLE test_table(id bigserial PRIMARY KEY);", - Down: "DROP TABLE test_table;", - }, - { - Version: riverMigrationsMaxVersion + 2, - Up: "ALTER TABLE test_table ADD COLUMN name varchar(200); CREATE INDEX idx_test_table_name ON test_table(name);", - Down: "DROP INDEX idx_test_table_name; ALTER TABLE test_table DROP COLUMN name;", - }, - }...)) -) - -func TestMigrator(t *testing.T) { - t.Parallel() - - var ( - ctx = context.Background() - queries = dbsqlc.New() - ) - - type testBundle struct { - tx pgx.Tx - } - - setup := func(t *testing.T) (*Migrator, *testBundle) { - t.Helper() - - // The test suite largely works fine with test transactions, bue due to - // the invasive nature of changing schemas, it's quite easy to have test - // transactions deadlock with each other as they run in parallel. Here - // we use test DBs instead of test transactions, but this could be - // changed to test transactions as long as test cases were made to run - // non-parallel. - testDB := riverinternaltest.TestDB(ctx, t) - - // Despite being in an isolated database, we still start a transaction - // because we don't want schema changes we make to persist. - tx, err := testDB.Begin(ctx) - require.NoError(t, err) - t.Cleanup(func() { _ = tx.Rollback(ctx) }) - - bundle := &testBundle{ - tx: tx, - } - - migrator := NewMigrator(riverinternaltest.BaseServiceArchetype(t)) - migrator.migrations = testVersions - - return migrator, bundle - } - - t.Run("Down", func(t *testing.T) { - t.Parallel() - - migrator, bundle := setup(t) - - // Run an initial time - { - res, err := migrator.Down(ctx, bundle.tx, &MigrateOptions{}) - require.NoError(t, err) - require.Equal(t, seqToOne(3), res.Versions) - - err = dbExecError(ctx, bundle.tx, "SELECT * FROM river_migration") - require.Error(t, err) - } - - // Run once more to verify idempotency - { - res, err := migrator.Down(ctx, bundle.tx, &MigrateOptions{}) - require.NoError(t, err) - require.Equal(t, []int64{}, res.Versions) - - err = dbExecError(ctx, bundle.tx, "SELECT * FROM river_migration") - require.Error(t, err) - } - }) - - t.Run("DownAfterUp", func(t *testing.T) { - t.Parallel() - - migrator, bundle := setup(t) - - _, err := migrator.Up(ctx, bundle.tx, &MigrateOptions{}) - require.NoError(t, err) - - res, err := migrator.Down(ctx, bundle.tx, &MigrateOptions{}) - require.NoError(t, err) - require.Equal(t, seqToOne(riverMigrationsMaxVersion+2), res.Versions) - - err = dbExecError(ctx, bundle.tx, "SELECT * FROM river_migration") - require.Error(t, err) - }) - - t.Run("DownWithMaxSteps", func(t *testing.T) { - t.Parallel() - - migrator, bundle := setup(t) - - _, err := migrator.Up(ctx, bundle.tx, &MigrateOptions{}) - require.NoError(t, err) - - res, err := migrator.Down(ctx, bundle.tx, &MigrateOptions{MaxSteps: ptrutil.Ptr(1)}) - require.NoError(t, err) - require.Equal(t, []int64{riverMigrationsMaxVersion + 2}, res.Versions) - - migrations, err := queries.RiverMigrationGetAll(ctx, bundle.tx) - require.NoError(t, err) - require.Equal(t, seqOneTo(riverMigrationsMaxVersion+1), migrationVersions(migrations)) - - err = dbExecError(ctx, bundle.tx, "SELECT name FROM test_table") - require.Error(t, err) - }) - - t.Run("DownWithMaxStepsZero", func(t *testing.T) { - t.Parallel() - - migrator, bundle := setup(t) - - _, err := migrator.Up(ctx, bundle.tx, &MigrateOptions{}) - require.NoError(t, err) - - res, err := migrator.Down(ctx, bundle.tx, &MigrateOptions{MaxSteps: ptrutil.Ptr(0)}) - require.NoError(t, err) - require.Equal(t, []int64{}, res.Versions) - }) - - t.Run("Up", func(t *testing.T) { - t.Parallel() - - migrator, bundle := setup(t) - - // Run an initial time - { - res, err := migrator.Up(ctx, bundle.tx, &MigrateOptions{}) - require.NoError(t, err) - require.Equal(t, []int64{riverMigrationsMaxVersion + 1, riverMigrationsMaxVersion + 2}, res.Versions) - - migrations, err := queries.RiverMigrationGetAll(ctx, bundle.tx) - require.NoError(t, err) - require.Equal(t, seqOneTo(riverMigrationsMaxVersion+2), migrationVersions(migrations)) - - _, err = bundle.tx.Exec(ctx, "SELECT * FROM test_table") - require.NoError(t, err) - } - - // Run once more to verify idempotency - { - res, err := migrator.Up(ctx, bundle.tx, &MigrateOptions{}) - require.NoError(t, err) - require.Equal(t, []int64{}, res.Versions) - - migrations, err := queries.RiverMigrationGetAll(ctx, bundle.tx) - require.NoError(t, err) - require.Equal(t, seqOneTo(riverMigrationsMaxVersion+2), migrationVersions(migrations)) - - _, err = bundle.tx.Exec(ctx, "SELECT * FROM test_table") - require.NoError(t, err) - } - }) - - t.Run("UpWithMaxSteps", func(t *testing.T) { - t.Parallel() - - migrator, bundle := setup(t) - - res, err := migrator.Up(ctx, bundle.tx, &MigrateOptions{MaxSteps: ptrutil.Ptr(1)}) - require.NoError(t, err) - require.Equal(t, []int64{riverMigrationsMaxVersion + 1}, res.Versions) - - migrations, err := queries.RiverMigrationGetAll(ctx, bundle.tx) - require.NoError(t, err) - require.Equal(t, seqOneTo(riverMigrationsMaxVersion+1), migrationVersions(migrations)) - - // Column `name` is only added in the second test version. - err = dbExecError(ctx, bundle.tx, "SELECT name FROM test_table") - require.Error(t, err) - - var pgErr *pgconn.PgError - require.ErrorAs(t, err, &pgErr) - require.Equal(t, pgerrcode.UndefinedColumn, pgErr.Code) - }) - - t.Run("UpWithMaxStepsZero", func(t *testing.T) { - t.Parallel() - - migrator, bundle := setup(t) - - res, err := migrator.Up(ctx, bundle.tx, &MigrateOptions{MaxSteps: ptrutil.Ptr(0)}) - require.NoError(t, err) - require.Equal(t, []int64{}, res.Versions) - }) -} - -// A command returning an error aborts the transaction. This is a shortcut to -// execute a command in a subtransaction so that we can verify an error, but -// continue to use the original transaction. -func dbExecError(ctx context.Context, executor dbutil.Executor, sql string) error { - return dbutil.WithTx(ctx, executor, func(ctx context.Context, tx pgx.Tx) error { - _, err := tx.Exec(ctx, sql) - return err - }) -} - -func migrationVersions(migrations []*dbsqlc.RiverMigration) []int64 { - return sliceutil.Map(migrations, func(r *dbsqlc.RiverMigration) int64 { return r.Version }) -} - -func seqOneTo(max int64) []int64 { - seq := make([]int64, max) - - for i := 0; i < int(max); i++ { - seq[i] = int64(i + 1) - } - - return seq -} - -func seqToOne(max int64) []int64 { - seq := seqOneTo(max) - slices.Reverse(seq) - return seq -} diff --git a/internal/dbmigrate/main_test.go b/rivermigrate/main_test.go similarity index 85% rename from internal/dbmigrate/main_test.go rename to rivermigrate/main_test.go index eadb217d..d13d74b3 100644 --- a/internal/dbmigrate/main_test.go +++ b/rivermigrate/main_test.go @@ -1,4 +1,4 @@ -package dbmigrate +package rivermigrate_test import ( "testing" diff --git a/internal/dbmigrate/001_create_river_migration.down.sql b/rivermigrate/migration/001_create_river_migration.down.sql similarity index 100% rename from internal/dbmigrate/001_create_river_migration.down.sql rename to rivermigrate/migration/001_create_river_migration.down.sql diff --git a/internal/dbmigrate/001_create_river_migration.up.sql b/rivermigrate/migration/001_create_river_migration.up.sql similarity index 100% rename from internal/dbmigrate/001_create_river_migration.up.sql rename to rivermigrate/migration/001_create_river_migration.up.sql diff --git a/internal/dbmigrate/002_initial_schema.down.sql b/rivermigrate/migration/002_initial_schema.down.sql similarity index 100% rename from internal/dbmigrate/002_initial_schema.down.sql rename to rivermigrate/migration/002_initial_schema.down.sql diff --git a/internal/dbmigrate/002_initial_schema.up.sql b/rivermigrate/migration/002_initial_schema.up.sql similarity index 100% rename from internal/dbmigrate/002_initial_schema.up.sql rename to rivermigrate/migration/002_initial_schema.up.sql diff --git a/internal/dbmigrate/003_river_job_tags_non_null.down.sql b/rivermigrate/migration/003_river_job_tags_non_null.down.sql similarity index 100% rename from internal/dbmigrate/003_river_job_tags_non_null.down.sql rename to rivermigrate/migration/003_river_job_tags_non_null.down.sql diff --git a/internal/dbmigrate/003_river_job_tags_non_null.up.sql b/rivermigrate/migration/003_river_job_tags_non_null.up.sql similarity index 100% rename from internal/dbmigrate/003_river_job_tags_non_null.up.sql rename to rivermigrate/migration/003_river_job_tags_non_null.up.sql diff --git a/rivermigrate/river_migrate.go b/rivermigrate/river_migrate.go new file mode 100644 index 00000000..79bf4eb3 --- /dev/null +++ b/rivermigrate/river_migrate.go @@ -0,0 +1,440 @@ +// Package rivermigrate provides a Go API for running migrations as alternative +// to migrating via the bundled CLI. +package rivermigrate + +import ( + "context" + _ "embed" + "errors" + "fmt" + "log/slog" + "maps" + "os" + "slices" + "strings" + "time" + + "github.com/jackc/pgerrcode" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + + "github.com/riverqueue/river/internal/baseservice" + "github.com/riverqueue/river/internal/dbsqlc" + "github.com/riverqueue/river/internal/util/dbutil" + "github.com/riverqueue/river/internal/util/maputil" + "github.com/riverqueue/river/internal/util/sliceutil" + "github.com/riverqueue/river/riverdriver" +) + +var ( + //go:embed migration/001_create_river_migration.down.sql + sql001CreateRiverMigrationDown string + + //go:embed migration/001_create_river_migration.up.sql + sql001CreateRiverMigrationUp string + + //go:embed migration/002_initial_schema.down.sql + sql002InitialSchemaDown string + + //go:embed migration/002_initial_schema.up.sql + sql002InitialSchemaUp string + + //go:embed migration/003_river_job_tags_non_null.down.sql + sql003RiverJobTagsNonNullDown string + + //go:embed migration/003_river_job_tags_non_null.up.sql + sql003RiverJobTagsNonNullUp string +) + +type migrationBundle struct { + Version int + Up string + Down string +} + +//nolint:gochecknoglobals +var ( + riverMigrations = []*migrationBundle{ + {Version: 1, Up: sql001CreateRiverMigrationUp, Down: sql001CreateRiverMigrationDown}, + {Version: 2, Up: sql002InitialSchemaUp, Down: sql002InitialSchemaDown}, + {Version: 3, Up: sql003RiverJobTagsNonNullUp, Down: sql003RiverJobTagsNonNullDown}, + } + + riverMigrationsMap = validateAndInit(riverMigrations) +) + +// Config contains configuration for Migrator. +type Config struct { + // Logger is the structured logger to use for logging purposes. If none is + // specified, logs will be emitted to STDOUT with messages at warn level + // or higher. + Logger *slog.Logger +} + +// Migrator is a database migration tool for River which can run up or down +// migrations in order to establish the schema that the queue needs to run. +type Migrator[TTx any] struct { + baseservice.BaseService + + driver riverdriver.Driver[TTx] + migrations map[int]*migrationBundle + queries *dbsqlc.Queries +} + +// New returns a new migrator with the given database driver and configuration. +// The config parameter may be omitted as nil. +// +// Currently only one driver is supported, which is Pgx v5. See package +// riverpgxv5. +// +// The function takes a generic parameter TTx representing a transaction type, +// but it can be omitted because it'll generally always be inferred from the +// driver. For example: +// +// import "github.com/riverqueue/river/riverdriver/riverpgxv5" +// import "github.com/riverqueue/rivermigrate" +// +// ... +// +// dbPool, err := pgxpool.New(ctx, os.Getenv("DATABASE_URL")) +// if err != nil { +// // handle error +// } +// defer dbPool.Close() +// +// migrator, err := rivermigrate.New(riverpgxv5.New(dbPool), nil) +// if err != nil { +// // handle error +// } +func New[TTx any](driver riverdriver.Driver[TTx], config *Config) *Migrator[TTx] { + if config == nil { + config = &Config{} + } + + logger := config.Logger + if logger == nil { + logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{ + Level: slog.LevelWarn, + })) + } + + archetype := &baseservice.Archetype{ + Logger: logger, + TimeNowUTC: func() time.Time { return time.Now().UTC() }, + } + + return baseservice.Init(archetype, &Migrator[TTx]{ + driver: driver, + migrations: riverMigrationsMap, + queries: dbsqlc.New(), + }) +} + +// MigrateOpts are options for a migrate operation. +type MigrateOpts struct { + // MaxSteps is the maximum number of migrations to apply either up or down. + // Leave zero for an unlimited number. Set to -1 to apply no migrations (for + // testing/checking purposes). + MaxSteps int + + // TargetVersion is a specific migration version to apply migrations to. The + // version must exist and it must be in the possible list of migrations to + // apply. e.g. If requesting an up migration with version 3, version 3 not + // already be applied. + // + // When applying migrations up, migrations are applied including the target + // version, so when starting at version 0 and requesting version 3, versions + // 1, 2, and 3 would be applied. When applying migrations down, down + // migrations are applied excluding the target version, so when starting at + // version 5 an requesting version 3, down migrations for versions 5 and 4 + // would be applied, leaving the final schema at version 3. + TargetVersion int +} + +// MigrateResult is the result of a migrate operation. +type MigrateResult struct { + // Versions are migration versions that were added (for up migrations) or + // removed (for down migrations) for this run. + Versions []MigrateVersion +} + +// MigrateVersion is the result for a single applied migration. +type MigrateVersion struct { + // Version is the version of the migration applied. + Version int +} + +func migrateVersionToInt(version MigrateVersion) int { return version.Version } +func migrateVersionToInt64(version MigrateVersion) int64 { return int64(version.Version) } + +type Direction string + +const ( + DirectionDown Direction = "down" + DirectionUp Direction = "up" +) + +// Migrate migrates the database in the given direction (up or down). The opts +// parameter may be omitted for convenience. +// +// By default, applies all outstanding migrations possible in either direction. +// When migrating up all outstanding migrations are applied, and when migrating +// down all existing migrations are unapplied. +// +// When migrating down, use with caution. MigrateOpts.MaxSteps should be set to +// 1 to only migrate down one step. +// +// res, err := migrator.Migrate(ctx, rivermigrate.DirectionUp, nil) +// if err != nil { +// // handle error +// } +func (m *Migrator[TTx]) Migrate(ctx context.Context, direction Direction, opts *MigrateOpts) (*MigrateResult, error) { + return dbutil.WithTxV(ctx, m.driver.GetDBPool(), func(ctx context.Context, tx pgx.Tx) (*MigrateResult, error) { + switch direction { + case DirectionDown: + return m.migrateDownTx(ctx, tx, direction, opts) + case DirectionUp: + return m.migrateUpTx(ctx, tx, direction, opts) + } + + panic("invalid direction: " + direction) + }) +} + +// Migrate migrates the database in the given direction (up or down). The opts +// parameter may be omitted for convenience. +// +// By default, applies all outstanding migrations possible in either direction. +// When migrating up all outstanding migrations are applied, and when migrating +// down all existing migrations are unapplied. +// +// When migrating down, use with caution. MigrateOpts.MaxSteps should be set to +// 1 to only migrate down one step. +// +// res, err := migrator.MigrateTx(ctx, tx, rivermigrate.DirectionUp, nil) +// if err != nil { +// // handle error +// } +// +// This variant lets a caller run migrations within a transaction. Postgres DDL +// is transactional, so migration changes aren't visible until the transaction +// commits, and are rolled back if the transaction rolls back. +func (m *Migrator[TTx]) MigrateTx(ctx context.Context, tx TTx, direction Direction, opts *MigrateOpts) (*MigrateResult, error) { + switch direction { + case DirectionDown: + return m.migrateDownTx(ctx, m.driver.UnwrapTx(tx), direction, opts) + case DirectionUp: + return m.migrateUpTx(ctx, m.driver.UnwrapTx(tx), direction, opts) + } + + panic("invalid direction: " + direction) +} + +// migrateDownTx runs down migrations. +func (m *Migrator[TTx]) migrateDownTx(ctx context.Context, tx pgx.Tx, direction Direction, opts *MigrateOpts) (*MigrateResult, error) { + existingMigrations, err := m.existingMigrations(ctx, tx) + if err != nil { + return nil, err + } + existingMigrationsMap := sliceutil.KeyBy(existingMigrations, + func(m *dbsqlc.RiverMigration) (int, struct{}) { return int(m.Version), struct{}{} }) + + targetMigrations := maps.Clone(m.migrations) + for version := range targetMigrations { + if _, ok := existingMigrationsMap[version]; !ok { + delete(targetMigrations, version) + } + } + + sortedTargetMigrations := maputil.Values(targetMigrations) + slices.SortFunc(sortedTargetMigrations, func(a, b *migrationBundle) int { return b.Version - a.Version }) // reverse order + + res, err := m.applyMigrations(ctx, tx, direction, opts, sortedTargetMigrations) + if err != nil { + return nil, err + } + + // If we did no work, leave early. This allows a zero-migrated database + // that's being no-op downmigrated again to succeed because otherwise + // the delete below would cause it to error. + if len(res.Versions) < 1 { + return res, nil + } + + // Migration version 1 is special-cased because if it was downmigrated + // it means the `river_migration` table is no longer present so there's + // nothing to delete out of. + if slices.ContainsFunc(res.Versions, func(v MigrateVersion) bool { return v.Version == 1 }) { + return res, nil + } + + if _, err := m.queries.RiverMigrationDeleteByVersionMany(ctx, tx, sliceutil.Map(res.Versions, migrateVersionToInt64)); err != nil { + return nil, fmt.Errorf("error deleting migration rows for versions %+v: %w", res.Versions, err) + } + + return res, nil +} + +// migrateUpTx runs up migrations. +func (m *Migrator[TTx]) migrateUpTx(ctx context.Context, tx pgx.Tx, direction Direction, opts *MigrateOpts) (*MigrateResult, error) { + existingMigrations, err := m.existingMigrations(ctx, tx) + if err != nil { + return nil, err + } + + targetMigrations := maps.Clone(m.migrations) + for _, migrateRow := range existingMigrations { + delete(targetMigrations, int(migrateRow.Version)) + } + + sortedTargetMigrations := maputil.Values(targetMigrations) + slices.SortFunc(sortedTargetMigrations, func(a, b *migrationBundle) int { return a.Version - b.Version }) + + res, err := m.applyMigrations(ctx, tx, direction, opts, sortedTargetMigrations) + if err != nil { + return nil, err + } + + if _, err := m.queries.RiverMigrationInsertMany(ctx, tx, sliceutil.Map(res.Versions, migrateVersionToInt64)); err != nil { + return nil, fmt.Errorf("error inserting migration rows for versions %+v: %w", res.Versions, err) + } + + return res, nil +} + +// Common code shared between the up and down migration directions that walks +// through each target migration and applies it, logging appropriately. +func (m *Migrator[TTx]) applyMigrations(ctx context.Context, tx pgx.Tx, direction Direction, opts *MigrateOpts, sortedTargetMigrations []*migrationBundle) (*MigrateResult, error) { + if opts == nil { + opts = &MigrateOpts{} + } + + switch { + case opts.MaxSteps < 0: + sortedTargetMigrations = []*migrationBundle{} + case opts.MaxSteps > 0: + sortedTargetMigrations = sortedTargetMigrations[0:min(opts.MaxSteps, len(sortedTargetMigrations))] + } + + if opts.TargetVersion > 0 { + if _, ok := m.migrations[opts.TargetVersion]; !ok { + return nil, fmt.Errorf("version %d is not a valid River migration version", opts.TargetVersion) + } + + targetIndex := slices.IndexFunc(sortedTargetMigrations, func(b *migrationBundle) bool { return b.Version == opts.TargetVersion }) + if targetIndex == -1 { + return nil, fmt.Errorf("version %d is not in target list of valid migrations to apply", opts.TargetVersion) + } + + // Replace target list with list up to target index. Migrations are + // sorted according to the direction we're migrating in, so when down + // migration, the list is already reversed, so this will truncate it so + // it's the most current migration down to the target. + sortedTargetMigrations = sortedTargetMigrations[0 : targetIndex+1] + + if direction == DirectionDown && len(sortedTargetMigrations) > 0 { + sortedTargetMigrations = sortedTargetMigrations[0 : len(sortedTargetMigrations)-1] + } + } + + res := &MigrateResult{Versions: make([]MigrateVersion, 0, len(sortedTargetMigrations))} + + // Short circuit early if there's nothing to do. + if len(sortedTargetMigrations) < 1 { + m.Logger.InfoContext(ctx, m.Name+": No migrations to apply") + return res, nil + } + + for _, versionBundle := range sortedTargetMigrations { + sql := versionBundle.Up + if direction == DirectionDown { + sql = versionBundle.Down + } + + m.Logger.InfoContext(ctx, fmt.Sprintf(m.Name+": Applying migration %03d [%s]", versionBundle.Version, strings.ToUpper(string(direction))), + slog.String("direction", string(direction)), + slog.Int("version", versionBundle.Version), + ) + + _, err := tx.Exec(ctx, sql) + if err != nil { + return nil, fmt.Errorf("error applying version %03d [%s]: %w", + versionBundle.Version, strings.ToUpper(string(direction)), err) + } + + res.Versions = append(res.Versions, MigrateVersion{Version: versionBundle.Version}) + } + + // Only prints if more steps than available were requested. + if opts.MaxSteps > 0 && len(res.Versions) < opts.MaxSteps { + m.Logger.InfoContext(ctx, m.Name+": No more migrations to apply") + } + + return res, nil +} + +// Get existing migrations that've already been run in the database. This is +// encapsulated to run a check in a subtransaction and the handle the case of +// the `river_migration` table not existing yet. (The subtransaction is needed +// because otherwise the existing transaction would become aborted on an +// unsuccessful `river_migration` check.) +func (m *Migrator[TTx]) existingMigrations(ctx context.Context, tx pgx.Tx) ([]*dbsqlc.RiverMigration, error) { + // We start another inner transaction here because in case this is the first + // ever migration run, the transaction may become aborted if `river_migration` + // doesn't exist, a condition which we must handle gracefully. + migrations, err := dbutil.WithTxV(ctx, tx, func(ctx context.Context, tx pgx.Tx) ([]*dbsqlc.RiverMigration, error) { + migrations, err := m.queries.RiverMigrationGetAll(ctx, tx) + if err != nil { + return nil, fmt.Errorf("error getting current migrate rows: %w", err) + } + return migrations, nil + }) + if err != nil { + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) { + if pgErr.Code == pgerrcode.UndefinedTable && strings.Contains(pgErr.Message, "river_migration") { + return nil, nil + } + } + + return nil, err + } + + return migrations, nil +} + +// Validates and fully initializes a set of migrations to reduce the probability +// of configuration problems as new migrations are introduced. e.g. Checks for +// missing fields or accidentally duplicated version numbers from copy/pasta +// problems. +func validateAndInit(versions []*migrationBundle) map[int]*migrationBundle { + lastVersion := 0 + migrations := make(map[int]*migrationBundle, len(versions)) + + for _, versionBundle := range versions { + if versionBundle.Down == "" { + panic(fmt.Sprintf("version bundle should specify Down: %+v", versionBundle)) + } + if versionBundle.Up == "" { + panic(fmt.Sprintf("version bundle should specify Up: %+v", versionBundle)) + } + if versionBundle.Version == 0 { + panic(fmt.Sprintf("version bundle should specify Version: %+v", versionBundle)) + } + + if _, ok := migrations[versionBundle.Version]; ok { + panic(fmt.Sprintf("duplicate version: %03d", versionBundle.Version)) + } + if versionBundle.Version <= lastVersion { + panic(fmt.Sprintf("versions should be ascending; current: %03d, last: %03d", versionBundle.Version, lastVersion)) + } + if versionBundle.Version > lastVersion+1 { + panic(fmt.Sprintf("versions shouldn't skip a sequence number; current: %03d, last: %03d", versionBundle.Version, lastVersion)) + } + + lastVersion = versionBundle.Version + migrations[versionBundle.Version] = versionBundle + } + + return migrations +} diff --git a/rivermigrate/river_migrate_test.go b/rivermigrate/river_migrate_test.go new file mode 100644 index 00000000..65c2aa9d --- /dev/null +++ b/rivermigrate/river_migrate_test.go @@ -0,0 +1,354 @@ +package rivermigrate + +import ( + "context" + "slices" + "testing" + + "github.com/jackc/pgerrcode" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/stretchr/testify/require" + + "github.com/riverqueue/river/internal/dbsqlc" + "github.com/riverqueue/river/internal/riverinternaltest" + "github.com/riverqueue/river/internal/util/dbutil" + "github.com/riverqueue/river/internal/util/sliceutil" + "github.com/riverqueue/river/riverdriver/riverpgxv5" +) + +//nolint:gochecknoglobals +var ( + // We base our test migrations on the actual line of migrations, so get + // their maximum version number which we'll use to define test version + // numbers so that the tests don't break anytime we add a new one. + riverMigrationsMaxVersion = riverMigrations[len(riverMigrations)-1].Version + + testVersions = []*migrationBundle{ + { + Version: riverMigrationsMaxVersion + 1, + Up: "CREATE TABLE test_table(id bigserial PRIMARY KEY);", + Down: "DROP TABLE test_table;", + }, + { + Version: riverMigrationsMaxVersion + 2, + Up: "ALTER TABLE test_table ADD COLUMN name varchar(200); CREATE INDEX idx_test_table_name ON test_table(name);", + Down: "DROP INDEX idx_test_table_name; ALTER TABLE test_table DROP COLUMN name;", + }, + } + + riverMigrationsWithtestVersionsMap = validateAndInit(append(riverMigrations, testVersions...)) + riverMigrationsWithTestVersionsMaxVersion = riverMigrationsMaxVersion + len(testVersions) +) + +func TestMigrator(t *testing.T) { + t.Parallel() + + var ( + ctx = context.Background() + queries = dbsqlc.New() + ) + + type testBundle struct { + tx pgx.Tx + } + + setup := func(t *testing.T) (*Migrator[pgx.Tx], *testBundle) { + t.Helper() + + // The test suite largely works fine with test transactions, bue due to + // the invasive nature of changing schemas, it's quite easy to have test + // transactions deadlock with each other as they run in parallel. Here + // we use test DBs instead of test transactions, but this could be + // changed to test transactions as long as test cases were made to run + // non-parallel. + testDB := riverinternaltest.TestDB(ctx, t) + + // Despite being in an isolated database, we still start a transaction + // because we don't want schema changes we make to persist. + tx, err := testDB.Begin(ctx) + require.NoError(t, err) + t.Cleanup(func() { _ = tx.Rollback(ctx) }) + + bundle := &testBundle{ + tx: tx, + } + + migrator := New(riverpgxv5.New(testDB), nil) + migrator.migrations = riverMigrationsWithtestVersionsMap + + return migrator, bundle + } + + t.Run("MigrateDown", func(t *testing.T) { + t.Parallel() + + migrator, bundle := setup(t) + + // Run an initial time + { + res, err := migrator.MigrateTx(ctx, bundle.tx, DirectionDown, &MigrateOpts{}) + require.NoError(t, err) + require.Equal(t, seqToOne(3), sliceutil.Map(res.Versions, migrateVersionToInt)) + + err = dbExecError(ctx, bundle.tx, "SELECT * FROM river_migration") + require.Error(t, err) + } + + // Run once more to verify idempotency + { + res, err := migrator.MigrateTx(ctx, bundle.tx, DirectionDown, &MigrateOpts{}) + require.NoError(t, err) + require.Equal(t, []int{}, sliceutil.Map(res.Versions, migrateVersionToInt)) + + err = dbExecError(ctx, bundle.tx, "SELECT * FROM river_migration") + require.Error(t, err) + } + }) + + t.Run("MigrateDownAfterUp", func(t *testing.T) { + t.Parallel() + + migrator, bundle := setup(t) + + _, err := migrator.MigrateTx(ctx, bundle.tx, DirectionUp, &MigrateOpts{}) + require.NoError(t, err) + + res, err := migrator.MigrateTx(ctx, bundle.tx, DirectionDown, &MigrateOpts{}) + require.NoError(t, err) + require.Equal(t, seqToOne(riverMigrationsWithTestVersionsMaxVersion), sliceutil.Map(res.Versions, migrateVersionToInt)) + + err = dbExecError(ctx, bundle.tx, "SELECT * FROM river_migration") + require.Error(t, err) + }) + + t.Run("MigrateDownWithMaxSteps", func(t *testing.T) { + t.Parallel() + + migrator, bundle := setup(t) + + _, err := migrator.MigrateTx(ctx, bundle.tx, DirectionUp, &MigrateOpts{}) + require.NoError(t, err) + + res, err := migrator.MigrateTx(ctx, bundle.tx, DirectionDown, &MigrateOpts{MaxSteps: 1}) + require.NoError(t, err) + require.Equal(t, []int{riverMigrationsWithTestVersionsMaxVersion}, + sliceutil.Map(res.Versions, migrateVersionToInt)) + + migrations, err := queries.RiverMigrationGetAll(ctx, bundle.tx) + require.NoError(t, err) + require.Equal(t, seqOneTo(riverMigrationsWithTestVersionsMaxVersion-1), + sliceutil.Map(migrations, riverMigrationToInt)) + + err = dbExecError(ctx, bundle.tx, "SELECT name FROM test_table") + require.Error(t, err) + }) + + t.Run("MigrateDownWithPool", func(t *testing.T) { + t.Parallel() + + migrator, bundle := setup(t) + + // We don't actually migrate anything (max steps = -1) because doing so + // would mess with the test database, but this still runs most code to + // check that the function generally works. + res, err := migrator.Migrate(ctx, DirectionDown, &MigrateOpts{MaxSteps: -1}) + require.NoError(t, err) + require.Equal(t, []int{}, sliceutil.Map(res.Versions, migrateVersionToInt)) + + migrations, err := queries.RiverMigrationGetAll(ctx, bundle.tx) + require.NoError(t, err) + require.Equal(t, seqOneTo(3), + sliceutil.Map(migrations, riverMigrationToInt)) + }) + + t.Run("MigrateDownWithTargetVersion", func(t *testing.T) { + t.Parallel() + + migrator, bundle := setup(t) + + _, err := migrator.MigrateTx(ctx, bundle.tx, DirectionUp, &MigrateOpts{}) + require.NoError(t, err) + + res, err := migrator.MigrateTx(ctx, bundle.tx, DirectionDown, &MigrateOpts{TargetVersion: 3}) + require.NoError(t, err) + require.Equal(t, []int{5, 4}, + sliceutil.Map(res.Versions, migrateVersionToInt)) + + migrations, err := queries.RiverMigrationGetAll(ctx, bundle.tx) + require.NoError(t, err) + require.Equal(t, seqOneTo(3), + sliceutil.Map(migrations, riverMigrationToInt)) + + err = dbExecError(ctx, bundle.tx, "SELECT name FROM test_table") + require.Error(t, err) + }) + + t.Run("MigrateDownWithTargetVersionInvalid", func(t *testing.T) { + t.Parallel() + + migrator, bundle := setup(t) + + // migration doesn't exist + { + _, err := migrator.MigrateTx(ctx, bundle.tx, DirectionDown, &MigrateOpts{TargetVersion: 77}) + require.EqualError(t, err, "version 77 is not a valid River migration version") + } + + // migration exists but not one that's applied + { + _, err := migrator.MigrateTx(ctx, bundle.tx, DirectionDown, &MigrateOpts{TargetVersion: 4}) + require.EqualError(t, err, "version 4 is not in target list of valid migrations to apply") + } + }) + + t.Run("MigrateNilOpts", func(t *testing.T) { + t.Parallel() + + migrator, bundle := setup(t) + + res, err := migrator.MigrateTx(ctx, bundle.tx, DirectionUp, nil) + require.NoError(t, err) + require.Equal(t, []int{4, 5}, sliceutil.Map(res.Versions, migrateVersionToInt)) + }) + + t.Run("MigrateUp", func(t *testing.T) { + t.Parallel() + + migrator, bundle := setup(t) + + // Run an initial time + { + res, err := migrator.MigrateTx(ctx, bundle.tx, DirectionUp, &MigrateOpts{}) + require.NoError(t, err) + require.Equal(t, []int{riverMigrationsWithTestVersionsMaxVersion - 1, riverMigrationsWithTestVersionsMaxVersion}, + sliceutil.Map(res.Versions, migrateVersionToInt)) + + migrations, err := queries.RiverMigrationGetAll(ctx, bundle.tx) + require.NoError(t, err) + require.Equal(t, seqOneTo(riverMigrationsWithTestVersionsMaxVersion), + sliceutil.Map(migrations, riverMigrationToInt)) + + _, err = bundle.tx.Exec(ctx, "SELECT * FROM test_table") + require.NoError(t, err) + } + + // Run once more to verify idempotency + { + res, err := migrator.MigrateTx(ctx, bundle.tx, DirectionUp, &MigrateOpts{}) + require.NoError(t, err) + require.Equal(t, []int{}, sliceutil.Map(res.Versions, migrateVersionToInt)) + + migrations, err := queries.RiverMigrationGetAll(ctx, bundle.tx) + require.NoError(t, err) + require.Equal(t, seqOneTo(riverMigrationsWithTestVersionsMaxVersion), + sliceutil.Map(migrations, riverMigrationToInt)) + + _, err = bundle.tx.Exec(ctx, "SELECT * FROM test_table") + require.NoError(t, err) + } + }) + + t.Run("MigrateUpWithMaxSteps", func(t *testing.T) { + t.Parallel() + + migrator, bundle := setup(t) + + res, err := migrator.MigrateTx(ctx, bundle.tx, DirectionUp, &MigrateOpts{MaxSteps: 1}) + require.NoError(t, err) + require.Equal(t, []int{riverMigrationsWithTestVersionsMaxVersion - 1}, + sliceutil.Map(res.Versions, migrateVersionToInt)) + + migrations, err := queries.RiverMigrationGetAll(ctx, bundle.tx) + require.NoError(t, err) + require.Equal(t, seqOneTo(riverMigrationsWithTestVersionsMaxVersion-1), + sliceutil.Map(migrations, riverMigrationToInt)) + + // Column `name` is only added in the second test version. + err = dbExecError(ctx, bundle.tx, "SELECT name FROM test_table") + require.Error(t, err) + + var pgErr *pgconn.PgError + require.ErrorAs(t, err, &pgErr) + require.Equal(t, pgerrcode.UndefinedColumn, pgErr.Code) + }) + + t.Run("MigrateUpWithPool", func(t *testing.T) { + t.Parallel() + + migrator, bundle := setup(t) + + // We don't actually migrate anything (max steps = -1) because doing so + // would mess with the test database, but this still runs most code to + // check that the function generally works. + res, err := migrator.Migrate(ctx, DirectionUp, &MigrateOpts{MaxSteps: -1}) + require.NoError(t, err) + require.Equal(t, []int{}, sliceutil.Map(res.Versions, migrateVersionToInt)) + + migrations, err := queries.RiverMigrationGetAll(ctx, bundle.tx) + require.NoError(t, err) + require.Equal(t, seqOneTo(3), + sliceutil.Map(migrations, riverMigrationToInt)) + }) + + t.Run("MigrateUpWithTargetVersion", func(t *testing.T) { + t.Parallel() + + migrator, bundle := setup(t) + + res, err := migrator.MigrateTx(ctx, bundle.tx, DirectionUp, &MigrateOpts{TargetVersion: 5}) + require.NoError(t, err) + require.Equal(t, []int{4, 5}, + sliceutil.Map(res.Versions, migrateVersionToInt)) + + migrations, err := queries.RiverMigrationGetAll(ctx, bundle.tx) + require.NoError(t, err) + require.Equal(t, seqOneTo(5), sliceutil.Map(migrations, riverMigrationToInt)) + }) + + t.Run("MigrateUpWithTargetVersionInvalid", func(t *testing.T) { + t.Parallel() + + migrator, bundle := setup(t) + + // migration doesn't exist + { + _, err := migrator.MigrateTx(ctx, bundle.tx, DirectionUp, &MigrateOpts{TargetVersion: 77}) + require.EqualError(t, err, "version 77 is not a valid River migration version") + } + + // migration exists but already applied + { + _, err := migrator.MigrateTx(ctx, bundle.tx, DirectionUp, &MigrateOpts{TargetVersion: 3}) + require.EqualError(t, err, "version 3 is not in target list of valid migrations to apply") + } + }) +} + +// A command returning an error aborts the transaction. This is a shortcut to +// execute a command in a subtransaction so that we can verify an error, but +// continue to use the original transaction. +func dbExecError(ctx context.Context, executor dbutil.Executor, sql string) error { + return dbutil.WithTx(ctx, executor, func(ctx context.Context, tx pgx.Tx) error { + _, err := tx.Exec(ctx, sql) + return err + }) +} + +func riverMigrationToInt(r *dbsqlc.RiverMigration) int { return int(r.Version) } + +func seqOneTo(max int) []int { + seq := make([]int, max) + + for i := 0; i < max; i++ { + seq[i] = i + 1 + } + + return seq +} + +func seqToOne(max int) []int { + seq := seqOneTo(max) + slices.Reverse(seq) + return seq +}