diff --git a/internal/testing/integration/postgres_locking_test.go b/internal/testing/integration/postgres_locking_test.go index bfced40e0..465bac9f4 100644 --- a/internal/testing/integration/postgres_locking_test.go +++ b/internal/testing/integration/postgres_locking_test.go @@ -410,30 +410,37 @@ func TestPostgresProviderLocking(t *testing.T) { }) } -func TestPostgresHasPending(t *testing.T) { +func TestPostgresPending(t *testing.T) { t.Parallel() if testing.Short() { t.Skip("skipping test in short mode.") } + const testDir = "testdata/migrations/postgres" db, cleanup, err := testdb.NewPostgres() require.NoError(t, err) t.Cleanup(cleanup) + files, err := os.ReadDir(testDir) + require.NoError(t, err) + workers := 15 - run := func(want bool) { + run := func(t *testing.T, want bool, wantCurrent, wantTarget int) { + t.Helper() var g errgroup.Group boolCh := make(chan bool, workers) for i := 0; i < workers; i++ { g.Go(func() error { - p, err := goose.NewProvider(goose.DialectPostgres, db, os.DirFS("testdata/migrations/postgres")) + p, err := goose.NewProvider(goose.DialectPostgres, db, os.DirFS(testDir)) check.NoError(t, err) hasPending, err := p.HasPending(context.Background()) - if err != nil { - return err - } + check.NoError(t, err) boolCh <- hasPending + current, target, err := p.CheckPending(context.Background()) + check.NoError(t, err) + check.Number(t, current, int64(wantCurrent)) + check.Number(t, target, int64(wantTarget)) return nil }) @@ -446,7 +453,7 @@ func TestPostgresHasPending(t *testing.T) { } } t.Run("concurrent_has_pending", func(t *testing.T) { - run(true) + run(t, true, 0, len(files)) }) // apply all migrations @@ -456,12 +463,12 @@ func TestPostgresHasPending(t *testing.T) { check.NoError(t, err) t.Run("concurrent_no_pending", func(t *testing.T) { - run(false) + run(t, false, len(files), len(files)) }) // Add a new migration file - last := p.ListSources()[len(p.ListSources())-1] - newVersion := fmt.Sprintf("%d_new_migration.sql", last.Version+1) + lastVersion := len(files) + newVersion := fmt.Sprintf("%d_new_migration.sql", lastVersion+1) fsys := fstest.MapFS{ newVersion: &fstest.MapFile{Data: []byte(` -- +goose Up @@ -476,7 +483,7 @@ SELECT pg_sleep_for('4 seconds'); check.NoError(t, err) check.Number(t, len(newProvider.ListSources()), 1) oldProvider := p - check.Number(t, len(oldProvider.ListSources()), 6) + check.Number(t, len(oldProvider.ListSources()), len(files)) var g errgroup.Group g.Go(func() error { @@ -485,6 +492,12 @@ SELECT pg_sleep_for('4 seconds'); return err } check.Bool(t, hasPending, true) + current, target, err := newProvider.CheckPending(context.Background()) + if err != nil { + return err + } + check.Number(t, current, lastVersion) + check.Number(t, target, lastVersion+1) return nil }) g.Go(func() error { @@ -493,6 +506,12 @@ SELECT pg_sleep_for('4 seconds'); return err } check.Bool(t, hasPending, false) + current, target, err := oldProvider.CheckPending(context.Background()) + if err != nil { + return err + } + check.Number(t, current, lastVersion) + check.Number(t, target, lastVersion) return nil }) check.NoError(t, g.Wait()) @@ -512,16 +531,24 @@ SELECT pg_sleep_for('4 seconds'); hasPending, err := oldProvider.HasPending(context.Background()) check.NoError(t, err) check.Bool(t, hasPending, false) + current, target, err := oldProvider.CheckPending(context.Background()) + check.NoError(t, err) + check.Number(t, current, lastVersion) + check.Number(t, target, lastVersion) // Wait for the long running migration to finish check.NoError(t, g.Wait()) // Check that the new migration was applied hasPending, err = newProvider.HasPending(context.Background()) check.NoError(t, err) check.Bool(t, hasPending, false) + current, target, err = newProvider.CheckPending(context.Background()) + check.NoError(t, err) + check.Number(t, current, lastVersion+1) + check.Number(t, target, lastVersion+1) // The max version should be the new migration currentVersion, err := newProvider.GetDBVersion(context.Background()) check.NoError(t, err) - check.Number(t, currentVersion, last.Version+1) + check.Number(t, currentVersion, lastVersion+1) } func existsPgLock(ctx context.Context, db *sql.DB, lockID int64) (bool, error) { diff --git a/provider.go b/provider.go index 65674d566..5099789a1 100644 --- a/provider.go +++ b/provider.go @@ -162,6 +162,15 @@ func (p *Provider) HasPending(ctx context.Context) (bool, error) { return p.hasPending(ctx) } +// CheckPending returns the current database version and the target version to migrate to. If there +// are no pending migrations, the target version will be the same as the current version. +// +// Note, this method will not use a SessionLocker if one is configured. This allows callers to check +// for pending migrations without blocking or being blocked by other operations. +func (p *Provider) CheckPending(ctx context.Context) (current, target int64, err error) { + return p.checkPending(ctx) +} + // GetDBVersion returns the highest version recorded in the database, regardless of the order in // which migrations were applied. For example, if migrations were applied out of order (1,4,2,3), // this method returns 4. If no migrations have been applied, it returns 0. @@ -465,6 +474,41 @@ func (p *Provider) apply( return p.runMigrations(ctx, conn, []*Migration{m}, d, true) } +func (p *Provider) checkPending(ctx context.Context) (current, target int64, retErr error) { + conn, cleanup, err := p.initialize(ctx, false) + if err != nil { + return -1, -1, fmt.Errorf("failed to initialize: %w", err) + } + defer func() { + retErr = multierr.Append(retErr, cleanup()) + }() + + // If versioning is disabled, we always have pending migrations and the target version is the + // last migration. + if p.cfg.disableVersioning { + return -1, p.migrations[len(p.migrations)-1].Version, nil + } + // optimize(mf): we should only fetch the max version from the database, no need to fetch all + // migrations only to get the max version when we're not using out-of-order migrations. + res, err := p.store.ListMigrations(ctx, conn) + if err != nil { + return -1, -1, err + } + dbVersions := make([]int64, 0, len(res)) + for _, m := range res { + dbVersions = append(dbVersions, m.Version) + } + sort.Slice(dbVersions, func(i, j int) bool { + return dbVersions[i] < dbVersions[j] + }) + if len(dbVersions) == 0 { + return -1, -1, errMissingZeroVersion + } else { + current = dbVersions[len(dbVersions)-1] + } + return current, p.migrations[len(p.migrations)-1].Version, nil +} + func (p *Provider) hasPending(ctx context.Context) (_ bool, retErr error) { conn, cleanup, err := p.initialize(ctx, false) if err != nil { diff --git a/provider_run_test.go b/provider_run_test.go index 59ec62323..5a7dd63ed 100644 --- a/provider_run_test.go +++ b/provider_run_test.go @@ -5,7 +5,6 @@ import ( "database/sql" "errors" "fmt" - "io/fs" "math" "math/rand" "os" @@ -775,11 +774,12 @@ func TestProviderApply(t *testing.T) { check.Bool(t, errors.Is(err, goose.ErrNotApplied), true) } -func TestHasPending(t *testing.T) { +func TestPending(t *testing.T) { t.Parallel() t.Run("allow_out_of_order", func(t *testing.T) { ctx := context.Background() - p, err := goose.NewProvider(goose.DialectSQLite3, newDB(t), newFsys(), + fsys := newFsys() + p, err := goose.NewProvider(goose.DialectSQLite3, newDB(t), fsys, goose.WithAllowOutofOrder(true), ) check.NoError(t, err) @@ -791,6 +791,10 @@ func TestHasPending(t *testing.T) { hasPending, err := p.HasPending(ctx) check.NoError(t, err) check.Bool(t, hasPending, true) + current, target, err := p.CheckPending(ctx) + check.NoError(t, err) + check.Number(t, current, 3) + check.Number(t, target, len(fsys)) // Apply the missing migrations. _, err = p.Up(ctx) check.NoError(t, err) @@ -798,10 +802,14 @@ func TestHasPending(t *testing.T) { hasPending, err = p.HasPending(ctx) check.NoError(t, err) check.Bool(t, hasPending, false) + current, target, err = p.CheckPending(ctx) + check.NoError(t, err) + check.Number(t, current, target) }) t.Run("disallow_out_of_order", func(t *testing.T) { ctx := context.Background() - p, err := goose.NewProvider(goose.DialectSQLite3, newDB(t), newFsys(), + fsys := newFsys() + p, err := goose.NewProvider(goose.DialectSQLite3, newDB(t), fsys, goose.WithAllowOutofOrder(false), ) check.NoError(t, err) @@ -813,12 +821,19 @@ func TestHasPending(t *testing.T) { hasPending, err := p.HasPending(ctx) check.NoError(t, err) check.Bool(t, hasPending, true) + current, target, err := p.CheckPending(ctx) + check.NoError(t, err) + check.Number(t, current, 2) + check.Number(t, target, len(fsys)) _, err = p.Up(ctx) check.NoError(t, err) // All migrations have been applied. hasPending, err = p.HasPending(ctx) check.NoError(t, err) check.Bool(t, hasPending, false) + current, target, err = p.CheckPending(ctx) + check.NoError(t, err) + check.Number(t, current, target) }) } @@ -1089,7 +1104,7 @@ func newMapFile(data string) *fstest.MapFile { } } -func newFsys() fs.FS { +func newFsys() fstest.MapFS { return fstest.MapFS{ "00001_users_table.sql": newMapFile(runMigration1), "00002_posts_table.sql": newMapFile(runMigration2),