diff --git a/cli-definition.json b/cli-definition.json index 2efd91d27..9ee7b53bf 100644 --- a/cli-definition.json +++ b/cli-definition.json @@ -56,6 +56,16 @@ "use": "migrate ", "example": "migrate ./migrations", "flags": [ + { + "name": "backfill-batch-delay", + "description": "Duration of delay between batch backfills (eg. 1s, 1000ms)", + "default": "0s" + }, + { + "name": "backfill-batch-size", + "description": "Number of rows backfilled in each batch", + "default": "1000" + }, { "name": "complete", "shorthand": "c", @@ -112,6 +122,16 @@ "use": "start ", "example": "", "flags": [ + { + "name": "backfill-batch-delay", + "description": "Duration of delay between batch backfills (eg. 1s, 1000ms)", + "default": "0s" + }, + { + "name": "backfill-batch-size", + "description": "Number of rows backfilled in each batch", + "default": "1000" + }, { "name": "complete", "shorthand": "c", @@ -141,16 +161,6 @@ } ], "flags": [ - { - "name": "backfill-batch-delay", - "description": "Duration of delay between batch backfills (eg. 1s, 1000ms)", - "default": "0s" - }, - { - "name": "backfill-batch-size", - "description": "Number of rows backfilled in each batch", - "default": "1000" - }, { "name": "lock-timeout", "description": "Postgres lock timeout in milliseconds for pgroll DDL operations", diff --git a/cmd/flags/flags.go b/cmd/flags/flags.go index b50d4bb95..50b1f8320 100644 --- a/cmd/flags/flags.go +++ b/cmd/flags/flags.go @@ -3,8 +3,6 @@ package flags import ( - "time" - "github.com/spf13/viper" ) @@ -24,10 +22,6 @@ func LockTimeout() int { return viper.GetInt("LOCK_TIMEOUT") } -func BackfillBatchSize() int { return viper.GetInt("BACKFILL_BATCH_SIZE") } - -func BackfillBatchDelay() time.Duration { return viper.GetDuration("BACKFILL_BATCH_DELAY") } - func SkipValidation() bool { return viper.GetBool("SKIP_VALIDATION") } func Role() string { diff --git a/cmd/migrate.go b/cmd/migrate.go index 307073423..cd2468397 100644 --- a/cmd/migrate.go +++ b/cmd/migrate.go @@ -5,12 +5,17 @@ package cmd import ( "fmt" "os" + "time" "github.com/spf13/cobra" + + "github.com/xataio/pgroll/pkg/backfill" ) func migrateCmd() *cobra.Command { var complete bool + var batchSize int + var batchDelay time.Duration migrateCmd := &cobra.Command{ Use: "migrate ", @@ -59,19 +64,26 @@ func migrateCmd() *cobra.Command { return nil } + backfillConfig := backfill.NewConfig( + backfill.WithBatchSize(batchSize), + backfill.WithBatchDelay(batchDelay), + ) + // Run all migrations after the latest version up to the final migration, // completing each one. for _, mig := range migs[:len(migs)-1] { - if err := runMigration(ctx, m, mig, true); err != nil { + if err := runMigration(ctx, m, mig, true, backfillConfig); err != nil { return fmt.Errorf("failed to run migration file %q: %w", mig.Name, err) } } // Run the final migration, completing it only if requested. - return runMigration(ctx, m, migs[len(migs)-1], complete) + return runMigration(ctx, m, migs[len(migs)-1], complete, backfillConfig) }, } + migrateCmd.Flags().IntVar(&batchSize, "backfill-batch-size", backfill.DefaultBatchSize, "Number of rows backfilled in each batch") + migrateCmd.Flags().DurationVar(&batchDelay, "backfill-batch-delay", backfill.DefaultDelay, "Duration of delay between batch backfills (eg. 1s, 1000ms)") migrateCmd.Flags().BoolVarP(&complete, "complete", "c", false, "complete the final migration rather than leaving it active") return migrateCmd diff --git a/cmd/root.go b/cmd/root.go index f7e23bca3..e53e04a4b 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -22,8 +22,6 @@ func NewRoll(ctx context.Context) (*roll.Roll, error) { stateSchema := flags.StateSchema() lockTimeout := flags.LockTimeout() role := flags.Role() - backfillBatchSize := flags.BackfillBatchSize() - backfillBatchDelay := flags.BackfillBatchDelay() skipValidation := flags.SkipValidation() state, err := state.New(ctx, pgURL, stateSchema) @@ -34,8 +32,6 @@ func NewRoll(ctx context.Context) (*roll.Roll, error) { return roll.New(ctx, pgURL, schema, state, roll.WithLockTimeoutMs(lockTimeout), roll.WithRole(role), - roll.WithBackfillBatchSize(backfillBatchSize), - roll.WithBackfillBatchDelay(backfillBatchDelay), roll.WithSkipValidation(skipValidation), ) } @@ -54,16 +50,12 @@ func Prepare() *cobra.Command { rootCmd.PersistentFlags().String("schema", "public", "Postgres schema to use for the migration") rootCmd.PersistentFlags().String("pgroll-schema", "pgroll", "Postgres schema to use for pgroll internal state") rootCmd.PersistentFlags().Int("lock-timeout", 500, "Postgres lock timeout in milliseconds for pgroll DDL operations") - rootCmd.PersistentFlags().Int("backfill-batch-size", roll.DefaultBackfillBatchSize, "Number of rows backfilled in each batch") - rootCmd.PersistentFlags().Duration("backfill-batch-delay", roll.DefaultBackfillDelay, "Duration of delay between batch backfills (eg. 1s, 1000ms)") rootCmd.PersistentFlags().String("role", "", "Optional postgres role to set when executing migrations") viper.BindPFlag("PG_URL", rootCmd.PersistentFlags().Lookup("postgres-url")) viper.BindPFlag("SCHEMA", rootCmd.PersistentFlags().Lookup("schema")) viper.BindPFlag("STATE_SCHEMA", rootCmd.PersistentFlags().Lookup("pgroll-schema")) viper.BindPFlag("LOCK_TIMEOUT", rootCmd.PersistentFlags().Lookup("lock-timeout")) - viper.BindPFlag("BACKFILL_BATCH_SIZE", rootCmd.PersistentFlags().Lookup("backfill-batch-size")) - viper.BindPFlag("BACKFILL_BATCH_DELAY", rootCmd.PersistentFlags().Lookup("backfill-batch-delay")) viper.BindPFlag("ROLE", rootCmd.PersistentFlags().Lookup("role")) // register subcommands diff --git a/cmd/start.go b/cmd/start.go index a06d02212..f8c47e15b 100644 --- a/cmd/start.go +++ b/cmd/start.go @@ -7,18 +7,22 @@ import ( "fmt" "math" "os" + "time" "github.com/pterm/pterm" "github.com/spf13/cobra" "github.com/spf13/viper" "github.com/xataio/pgroll/cmd/flags" + "github.com/xataio/pgroll/pkg/backfill" "github.com/xataio/pgroll/pkg/migrations" "github.com/xataio/pgroll/pkg/roll" ) func startCmd() *cobra.Command { var complete bool + var batchSize int + var batchDelay time.Duration startCmd := &cobra.Command{ Use: "start ", @@ -34,30 +38,37 @@ func startCmd() *cobra.Command { } defer m.Close() - return runMigrationFromFile(cmd.Context(), m, fileName, complete) + c := backfill.NewConfig( + backfill.WithBatchSize(batchSize), + backfill.WithBatchDelay(batchDelay), + ) + + return runMigrationFromFile(cmd.Context(), m, fileName, complete, c) }, } + startCmd.Flags().IntVar(&batchSize, "backfill-batch-size", backfill.DefaultBatchSize, "Number of rows backfilled in each batch") + startCmd.Flags().DurationVar(&batchDelay, "backfill-batch-delay", backfill.DefaultDelay, "Duration of delay between batch backfills (eg. 1s, 1000ms)") startCmd.Flags().BoolVarP(&complete, "complete", "c", false, "Mark the migration as complete") - startCmd.Flags().BoolP("skip-validation", "s", false, "skip migration validation") + viper.BindPFlag("SKIP_VALIDATION", startCmd.Flags().Lookup("skip-validation")) return startCmd } -func runMigrationFromFile(ctx context.Context, m *roll.Roll, fileName string, complete bool) error { +func runMigrationFromFile(ctx context.Context, m *roll.Roll, fileName string, complete bool, c *backfill.Config) error { migration, err := readMigration(fileName) if err != nil { return err } - return runMigration(ctx, m, migration, complete) + return runMigration(ctx, m, migration, complete, c) } -func runMigration(ctx context.Context, m *roll.Roll, migration *migrations.Migration, complete bool) error { +func runMigration(ctx context.Context, m *roll.Roll, migration *migrations.Migration, complete bool, c *backfill.Config) error { sp, _ := pterm.DefaultSpinner.WithText("Starting migration...").Start() - cb := func(n int64, total int64) { + c.AddCallback(func(n int64, total int64) { if total > 0 { percent := float64(n) / float64(total) * 100 // Percent can be > 100 if we're on the last batch in which case we still want to display 100. @@ -66,9 +77,9 @@ func runMigration(ctx context.Context, m *roll.Roll, migration *migrations.Migra } else { sp.UpdateText(fmt.Sprintf("%d records complete...", n)) } - } + }) - err := m.Start(ctx, migration, cb) + err := m.Start(ctx, migration, c) if err != nil { sp.Fail(fmt.Sprintf("Failed to start migration: %s", err)) return err diff --git a/internal/benchmarks/benchmarks_test.go b/internal/benchmarks/benchmarks_test.go index 66b1c0052..f610ef218 100644 --- a/internal/benchmarks/benchmarks_test.go +++ b/internal/benchmarks/benchmarks_test.go @@ -17,6 +17,7 @@ import ( "github.com/stretchr/testify/require" "github.com/xataio/pgroll/internal/testutils" + "github.com/xataio/pgroll/pkg/backfill" "github.com/xataio/pgroll/pkg/migrations" "github.com/xataio/pgroll/pkg/roll" ) @@ -71,7 +72,7 @@ func BenchmarkBackfill(b *testing.B) { // Backfill b.StartTimer() - require.NoError(b, mig.Start(ctx, &migAlterColumn)) + require.NoError(b, mig.Start(ctx, &migAlterColumn, backfill.NewConfig())) require.NoError(b, mig.Complete(ctx)) b.StopTimer() b.Logf("Backfilled %d rows in %s", rowCount, b.Elapsed()) @@ -132,7 +133,7 @@ func BenchmarkWriteAmplification(b *testing.B) { setupInitialTable(b, ctx, testSchema, mig, db, rowCount) // Start the migration - require.NoError(b, mig.Start(ctx, &migAlterColumn)) + require.NoError(b, mig.Start(ctx, &migAlterColumn, backfill.NewConfig())) b.Cleanup(func() { // Finish the migration require.NoError(b, mig.Complete(ctx)) @@ -211,7 +212,7 @@ func setupInitialTable(tb testing.TB, ctx context.Context, testSchema string, mi } // Setup - require.NoError(tb, mig.Start(ctx, &migCreateTable)) + require.NoError(tb, mig.Start(ctx, &migCreateTable, backfill.NewConfig())) require.NoError(tb, mig.Complete(ctx)) seed(tb, rowCount, db) } diff --git a/pkg/backfill/backfill.go b/pkg/backfill/backfill.go index ea9923d05..b98e83a61 100644 --- a/pkg/backfill/backfill.go +++ b/pkg/backfill/backfill.go @@ -15,24 +15,18 @@ import ( ) type Backfill struct { - conn db.DB - batchSize int - batchDelay time.Duration - callbacks []CallbackFn + conn db.DB + *Config } type CallbackFn func(done int64, total int64) // New creates a new backfill operation with the given options. The backfill is // not started until `Start` is invoked. -func New(conn db.DB, opts ...OptionFn) *Backfill { +func New(conn db.DB, c *Config) *Backfill { b := &Backfill{ - conn: conn, - batchSize: 1000, - } - - for _, opt := range opts { - opt(b) + conn: conn, + Config: c, } return b diff --git a/pkg/backfill/config.go b/pkg/backfill/config.go new file mode 100644 index 000000000..6a5394810 --- /dev/null +++ b/pkg/backfill/config.go @@ -0,0 +1,53 @@ +// SPDX-License-Identifier: Apache-2.0 + +package backfill + +import ( + "time" +) + +type Config struct { + batchSize int + batchDelay time.Duration + callbacks []CallbackFn +} + +const ( + DefaultBatchSize int = 1000 + DefaultDelay time.Duration = 0 +) + +type OptionFn func(*Config) + +func NewConfig(opts ...OptionFn) *Config { + c := &Config{ + batchSize: DefaultBatchSize, + batchDelay: DefaultDelay, + callbacks: make([]CallbackFn, 0), + } + + for _, opt := range opts { + opt(c) + } + return c +} + +// WithBatchSize sets the batch size for the backfill operation. +func WithBatchSize(batchSize int) OptionFn { + return func(o *Config) { + o.batchSize = batchSize + } +} + +// WithBatchDelay sets the delay between batches for the backfill operation. +func WithBatchDelay(delay time.Duration) OptionFn { + return func(o *Config) { + o.batchDelay = delay + } +} + +// AddCallback adds a callback to the backfill operation. +// Callbacks are invoked after each batch is processed. +func (c *Config) AddCallback(fn CallbackFn) { + c.callbacks = append(c.callbacks, fn) +} diff --git a/pkg/backfill/options.go b/pkg/backfill/options.go deleted file mode 100644 index 67848c433..000000000 --- a/pkg/backfill/options.go +++ /dev/null @@ -1,29 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 - -package backfill - -import "time" - -type OptionFn func(*Backfill) - -// WithBatchSize sets the batch size for the backfill operation. -func WithBatchSize(batchSize int) OptionFn { - return func(o *Backfill) { - o.batchSize = batchSize - } -} - -// WithBatchDelay sets the delay between batches for the backfill operation. -func WithBatchDelay(delay time.Duration) OptionFn { - return func(o *Backfill) { - o.batchDelay = delay - } -} - -// WithCallbacks sets the callbacks for the backfill operation. -// Callbacks are invoked after each batch is processed. -func WithCallbacks(cbs ...CallbackFn) OptionFn { - return func(o *Backfill) { - o.callbacks = cbs - } -} diff --git a/pkg/migrations/op_common_test.go b/pkg/migrations/op_common_test.go index 889e743b5..da36fcca8 100644 --- a/pkg/migrations/op_common_test.go +++ b/pkg/migrations/op_common_test.go @@ -18,6 +18,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/xataio/pgroll/internal/testutils" + "github.com/xataio/pgroll/pkg/backfill" "github.com/xataio/pgroll/pkg/migrations" "github.com/xataio/pgroll/pkg/roll" ) @@ -51,10 +52,11 @@ func ExecuteTests(t *testing.T, tests TestCases, opts ...roll.Option) { t.Run(tt.name, func(t *testing.T) { testutils.WithMigratorInSchemaAndConnectionToContainerWithOptions(t, testSchema, opts, func(mig *roll.Roll, db *sql.DB) { ctx := context.Background() + config := backfill.NewConfig() // run all migrations except the last one for i := 0; i < len(tt.migrations)-1; i++ { - if err := mig.Start(ctx, &tt.migrations[i]); err != nil { + if err := mig.Start(ctx, &tt.migrations[i], config); err != nil { t.Fatalf("Failed to start migration: %v", err) } @@ -64,7 +66,7 @@ func ExecuteTests(t *testing.T, tests TestCases, opts ...roll.Option) { } // start the last migration - err := mig.Start(ctx, &tt.migrations[len(tt.migrations)-1]) + err := mig.Start(ctx, &tt.migrations[len(tt.migrations)-1], config) if tt.wantStartErr != nil { if !errors.Is(err, tt.wantStartErr) { t.Fatalf("Expected error %q, got %q", tt.wantStartErr, err) @@ -98,7 +100,7 @@ func ExecuteTests(t *testing.T, tests TestCases, opts ...roll.Option) { } // re-start the last migration - if err := mig.Start(ctx, &tt.migrations[len(tt.migrations)-1]); err != nil { + if err := mig.Start(ctx, &tt.migrations[len(tt.migrations)-1], config); err != nil { t.Fatalf("Failed to start migration: %v", err) } diff --git a/pkg/roll/execute.go b/pkg/roll/execute.go index ad7537036..19c08589d 100644 --- a/pkg/roll/execute.go +++ b/pkg/roll/execute.go @@ -16,14 +16,14 @@ import ( ) // Start will apply the required changes to enable supporting the new schema version -func (m *Roll) Start(ctx context.Context, migration *migrations.Migration, cbs ...backfill.CallbackFn) error { +func (m *Roll) Start(ctx context.Context, migration *migrations.Migration, cfg *backfill.Config) error { tablesToBackfill, err := m.StartDDLOperations(ctx, migration) if err != nil { return err } // perform backfills for the tables that require it - return m.performBackfills(ctx, tablesToBackfill, cbs...) + return m.performBackfills(ctx, tablesToBackfill, cfg) } // StartDDLOperations performs the DDL operations for the migration. This does @@ -309,11 +309,8 @@ func (m *Roll) ensureView(ctx context.Context, version, name string, table *sche return nil } -func (m *Roll) performBackfills(ctx context.Context, tables []*schema.Table, cbs ...backfill.CallbackFn) error { - bf := backfill.New(m.pgConn, - backfill.WithBatchSize(m.backfillBatchSize), - backfill.WithBatchDelay(m.backfillBatchDelay), - backfill.WithCallbacks(cbs...)) +func (m *Roll) performBackfills(ctx context.Context, tables []*schema.Table, cfg *backfill.Config) error { + bf := backfill.New(m.pgConn, cfg) for _, table := range tables { if err := bf.Start(ctx, table); err != nil { diff --git a/pkg/roll/execute_test.go b/pkg/roll/execute_test.go index 4cf4df78f..4775f0d89 100644 --- a/pkg/roll/execute_test.go +++ b/pkg/roll/execute_test.go @@ -13,6 +13,7 @@ import ( "github.com/stretchr/testify/require" "github.com/xataio/pgroll/internal/testutils" + "github.com/xataio/pgroll/pkg/backfill" "github.com/xataio/pgroll/pkg/migrations" "github.com/xataio/pgroll/pkg/roll" "github.com/xataio/pgroll/pkg/state" @@ -33,7 +34,7 @@ func TestSchemaIsCreatedAfterMigrationStart(t *testing.T) { ctx := context.Background() version := "1_create_table" - if err := mig.Start(ctx, &migrations.Migration{Name: version, Operations: migrations.Operations{createTableOp("table1")}}); err != nil { + if err := mig.Start(ctx, &migrations.Migration{Name: version, Operations: migrations.Operations{createTableOp("table1")}}, backfill.NewConfig()); err != nil { t.Fatalf("Failed to start migration: %v", err) } @@ -53,7 +54,7 @@ func TestDisabledSchemaManagement(t *testing.T) { ctx := context.Background() version := "1_create_table" - if err := mig.Start(ctx, &migrations.Migration{Name: version, Operations: migrations.Operations{createTableOp("table1")}}); err != nil { + if err := mig.Start(ctx, &migrations.Migration{Name: version, Operations: migrations.Operations{createTableOp("table1")}}, backfill.NewConfig()); err != nil { t.Fatalf("Failed to start migration: %v", err) } @@ -68,7 +69,7 @@ func TestDisabledSchemaManagement(t *testing.T) { t.Fatalf("Failed to rollback migration: %v", err) } - if err := mig.Start(ctx, &migrations.Migration{Name: version, Operations: migrations.Operations{createTableOp("table1")}}); err != nil { + if err := mig.Start(ctx, &migrations.Migration{Name: version, Operations: migrations.Operations{createTableOp("table1")}}, backfill.NewConfig()); err != nil { t.Fatalf("Failed to start migration again: %v", err) } @@ -94,13 +95,13 @@ func TestPreviousVersionIsDroppedAfterMigrationCompletion(t *testing.T) { secondVersion = "2_create_table" ) - if err := mig.Start(ctx, &migrations.Migration{Name: firstVersion, Operations: migrations.Operations{createTableOp("table1")}}); err != nil { + if err := mig.Start(ctx, &migrations.Migration{Name: firstVersion, Operations: migrations.Operations{createTableOp("table1")}}, backfill.NewConfig()); err != nil { t.Fatalf("Failed to start first migration: %v", err) } if err := mig.Complete(ctx); err != nil { t.Fatalf("Failed to complete first migration: %v", err) } - if err := mig.Start(ctx, &migrations.Migration{Name: secondVersion, Operations: migrations.Operations{createTableOp("table2")}}); err != nil { + if err := mig.Start(ctx, &migrations.Migration{Name: secondVersion, Operations: migrations.Operations{createTableOp("table2")}}, backfill.NewConfig()); err != nil { t.Fatalf("Failed to start second migration: %v", err) } if err := mig.Complete(ctx); err != nil { @@ -125,7 +126,7 @@ func TestPreviousVersionIsDroppedAfterMigrationCompletion(t *testing.T) { ) // Run the first pgroll migration - if err := mig.Start(ctx, &migrations.Migration{Name: firstVersion, Operations: migrations.Operations{createTableOp("table1")}}); err != nil { + if err := mig.Start(ctx, &migrations.Migration{Name: firstVersion, Operations: migrations.Operations{createTableOp("table1")}}, backfill.NewConfig()); err != nil { t.Fatalf("Failed to start first migration: %v", err) } if err := mig.Complete(ctx); err != nil { @@ -139,7 +140,7 @@ func TestPreviousVersionIsDroppedAfterMigrationCompletion(t *testing.T) { } // Run the second pgroll migration - if err := mig.Start(ctx, &migrations.Migration{Name: secondVersion, Operations: migrations.Operations{createTableOp("table2")}}); err != nil { + if err := mig.Start(ctx, &migrations.Migration{Name: secondVersion, Operations: migrations.Operations{createTableOp("table2")}}, backfill.NewConfig()); err != nil { t.Fatalf("Failed to start second migration: %v", err) } if err := mig.Complete(ctx); err != nil { @@ -165,7 +166,7 @@ func TestNoVersionSchemaForRawSQLMigrationsOptionIsRespected(t *testing.T) { ctx := context.Background() // Apply a create table migration - err := mig.Start(ctx, &migrations.Migration{Name: "01_create_table", Operations: migrations.Operations{createTableOp("table1")}}) + err := mig.Start(ctx, &migrations.Migration{Name: "01_create_table", Operations: migrations.Operations{createTableOp("table1")}}, backfill.NewConfig()) require.NoError(t, err) err = mig.Complete(ctx) require.NoError(t, err) @@ -176,13 +177,13 @@ func TestNoVersionSchemaForRawSQLMigrationsOptionIsRespected(t *testing.T) { Operations: migrations.Operations{&migrations.OpRawSQL{ Up: "CREATE TABLE table2(a int)", }}, - }) + }, backfill.NewConfig()) require.NoError(t, err) err = mig.Complete(ctx) require.NoError(t, err) // Start a third create table migration - err = mig.Start(ctx, &migrations.Migration{Name: "03_create_table", Operations: migrations.Operations{createTableOp("table3")}}) + err = mig.Start(ctx, &migrations.Migration{Name: "03_create_table", Operations: migrations.Operations{createTableOp("table3")}}, backfill.NewConfig()) require.NoError(t, err) // The previous version is migration 01 if raw SQL migrations are ignored @@ -218,7 +219,7 @@ func TestSchemaIsDroppedAfterMigrationRollback(t *testing.T) { ctx := context.Background() version := "1_create_table" - if err := mig.Start(ctx, &migrations.Migration{Name: version, Operations: migrations.Operations{createTableOp("table1")}}); err != nil { + if err := mig.Start(ctx, &migrations.Migration{Name: version, Operations: migrations.Operations{createTableOp("table1")}}, backfill.NewConfig()); err != nil { t.Fatalf("Failed to start migration: %v", err) } if err := mig.Rollback(ctx); err != nil { @@ -257,7 +258,7 @@ func TestRollbackOnMigrationStartFailure(t *testing.T) { }, }, }, - }) + }, backfill.NewConfig()) assert.Error(t, err) // ensure that there is no active migration @@ -277,7 +278,7 @@ func TestRollbackOnMigrationStartFailure(t *testing.T) { err := mig.Start(ctx, &migrations.Migration{ Name: "01_create_table", Operations: migrations.Operations{createTableOp("table1")}, - }) + }, backfill.NewConfig()) assert.NoError(t, err) // complete the migration @@ -301,7 +302,7 @@ func TestRollbackOnMigrationStartFailure(t *testing.T) { Down: "invalid", }, }, - }) + }, backfill.NewConfig()) assert.Error(t, err) // Ensure that there is no active migration @@ -321,10 +322,14 @@ func TestSchemaOptionIsRespected(t *testing.T) { const version1 = "1_create_table" const version2 = "2_create_another_table" - if err := mig.Start(ctx, &migrations.Migration{ - Name: version1, - Operations: migrations.Operations{createTableOp("table1")}, - }); err != nil { + if err := mig.Start( + ctx, + &migrations.Migration{ + Name: version1, + Operations: migrations.Operations{createTableOp("table1")}, + }, + backfill.NewConfig(), + ); err != nil { t.Fatalf("Failed to start migration: %v", err) } if err := mig.Complete(ctx); err != nil { @@ -354,7 +359,9 @@ func TestSchemaOptionIsRespected(t *testing.T) { if err := mig.Start(ctx, &migrations.Migration{ Name: version2, Operations: migrations.Operations{createTableOp("table2")}, - }); err != nil { + }, + backfill.NewConfig(), + ); err != nil { t.Fatalf("Failed to start migration: %v", err) } if err := mig.Complete(ctx); err != nil { @@ -408,7 +415,7 @@ func TestMigrationDDLIsRetriedOnLockTimeouts(t *testing.T) { err = mig.Start(ctx, &migrations.Migration{ Name: "01_add_column", Operations: migrations.Operations{addColumnOp("table1")}, - }) + }, backfill.NewConfig()) require.NoError(t, err) }) } @@ -425,7 +432,7 @@ func TestViewsAreCreatedWithSecurityInvokerTrue(t *testing.T) { } // Start and complete a migration to create a simple `users` table - if err := mig.Start(ctx, &migrations.Migration{Name: version, Operations: migrations.Operations{createTableOp("users")}}); err != nil { + if err := mig.Start(ctx, &migrations.Migration{Name: version, Operations: migrations.Operations{createTableOp("users")}}, backfill.NewConfig()); err != nil { t.Fatalf("Failed to start migration: %v", err) } if err := mig.Complete(ctx); err != nil { @@ -516,7 +523,7 @@ func TestStatusMethodReturnsCorrectStatus(t *testing.T) { err = mig.Start(ctx, &migrations.Migration{ Name: "01_create_table", Operations: []migrations.Operation{createTableOp("table1")}, - }) + }, backfill.NewConfig()) assert.NoError(t, err) // Get the migration status @@ -549,7 +556,7 @@ func TestStatusMethodReturnsCorrectStatus(t *testing.T) { err = mig.Start(ctx, &migrations.Migration{ Name: "01_create_table", Operations: []migrations.Operation{createTableOp("table1")}, - }) + }, backfill.NewConfig()) assert.NoError(t, err) err = mig.Complete(ctx) assert.NoError(t, err) @@ -577,7 +584,7 @@ func TestRoleIsRespected(t *testing.T) { err := mig.Start(ctx, &migrations.Migration{ Name: "01_create_table", Operations: migrations.Operations{createTableOp("table1")}, - }) + }, backfill.NewConfig()) assert.NoError(t, err) // Complete the create table migration @@ -628,7 +635,7 @@ func TestMigrationHooksAreInvoked(t *testing.T) { err := mig.Start(ctx, &migrations.Migration{ Name: "01_create_table", Operations: migrations.Operations{createTableOp("table1")}, - }) + }, backfill.NewConfig()) assert.NoError(t, err) // Ensure that both the before_start_ddl and after_start_ddl tables were created @@ -664,6 +671,9 @@ func TestCallbacksAreInvokedOnMigrationStart(t *testing.T) { invoked := false cb := func(n, total int64) { invoked = true } + backfillConfig := backfill.NewConfig() + backfillConfig.AddCallback(cb) + // Start a migration that requires a backfill err = mig.Start(ctx, &migrations.Migration{ Name: "02_change_type", @@ -676,7 +686,7 @@ func TestCallbacksAreInvokedOnMigrationStart(t *testing.T) { Down: "name", }, }, - }, cb) + }, backfillConfig) require.NoError(t, err) // Ensure that the callback was invoked @@ -718,7 +728,7 @@ func TestSQLTransformerOptionIsUsedWhenCreatingTriggers(t *testing.T) { err := mig.Start(ctx, &migrations.Migration{ Name: "01_create_table", Operations: migrations.Operations{createTableOp("table1")}, - }) + }, backfill.NewConfig()) require.NoError(t, err) // Complete the migration @@ -743,7 +753,7 @@ func TestSQLTransformerOptionIsUsedWhenCreatingTriggers(t *testing.T) { }, }, }, - }) + }, backfill.NewConfig()) require.NoError(t, err) // Complete the migration @@ -774,7 +784,7 @@ func TestSQLTransformerOptionIsUsedWhenCreatingTriggers(t *testing.T) { err := mig.Start(ctx, &migrations.Migration{ Name: "01_create_table", Operations: migrations.Operations{createTableOp("table1")}, - }) + }, backfill.NewConfig()) require.NoError(t, err) // Complete the migration @@ -799,7 +809,7 @@ func TestSQLTransformerOptionIsUsedWhenCreatingTriggers(t *testing.T) { }, }, }, - }) + }, backfill.NewConfig()) // Ensure that the start phase has failed with a SQL transformer error require.ErrorIs(t, err, testutils.ErrMockSQLTransformer) }) @@ -830,7 +840,7 @@ func TestWithSearchPathOptionIsRespected(t *testing.T) { Up: "SELECT say_hello()", }, }, - }) + }, backfill.NewConfig()) require.NoError(t, err) // Complete the migration diff --git a/pkg/roll/latest_test.go b/pkg/roll/latest_test.go index 52c9630ac..0930cae31 100644 --- a/pkg/roll/latest_test.go +++ b/pkg/roll/latest_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/xataio/pgroll/internal/testutils" + "github.com/xataio/pgroll/pkg/backfill" "github.com/xataio/pgroll/pkg/migrations" "github.com/xataio/pgroll/pkg/roll" ) @@ -65,7 +66,7 @@ func TestLatestVersionRemote(t *testing.T) { Operations: migrations.Operations{ &migrations.OpRawSQL{Up: "SELECT 1"}, }, - }) + }, backfill.NewConfig()) require.NoError(t, err) err = m.Complete(ctx) require.NoError(t, err) diff --git a/pkg/roll/options.go b/pkg/roll/options.go index ea76c3576..bc5a0fedc 100644 --- a/pkg/roll/options.go +++ b/pkg/roll/options.go @@ -3,8 +3,6 @@ package roll import ( - "time" - "github.com/xataio/pgroll/pkg/migrations" ) @@ -27,12 +25,6 @@ type options struct { // additional entries to add to the search_path during migration execution searchPath []string - // the number of rows to backfill in each batch - backfillBatchSize int - - // the duration to delay after each batch is run - backfillBatchDelay time.Duration - // whether to skip validation skipValidation bool @@ -111,20 +103,6 @@ func WithSearchPath(schemas ...string) Option { } } -// WithBackfillBatchSize sets the number of rows backfilled in each batch. -func WithBackfillBatchSize(batchSize int) Option { - return func(o *options) { - o.backfillBatchSize = batchSize - } -} - -// WithBackfillBatchDelay sets the delay after each batch is run. -func WithBackfillBatchDelay(delay time.Duration) Option { - return func(o *options) { - o.backfillBatchDelay = delay - } -} - // WithSkipValidation controls whether or not to perform validation on // migrations. If set to true, validation will be skipped. func WithSkipValidation(skip bool) Option { diff --git a/pkg/roll/roll.go b/pkg/roll/roll.go index 8fce2d5d4..06897d293 100644 --- a/pkg/roll/roll.go +++ b/pkg/roll/roll.go @@ -7,7 +7,6 @@ import ( "database/sql" "fmt" "strings" - "time" "github.com/lib/pq" @@ -19,9 +18,7 @@ import ( type PGVersion int const ( - PGVersion15 PGVersion = 15 - DefaultBackfillBatchSize int = 1000 - DefaultBackfillDelay time.Duration = 0 + PGVersion15 PGVersion = 15 ) var ErrMismatchedMigration = fmt.Errorf("remote migration does not match local migration") @@ -42,10 +39,7 @@ type Roll struct { state *state.State pgVersion PGVersion sqlTransformer migrations.SQLTransformer - - backfillBatchSize int - backfillBatchDelay time.Duration - skipValidation bool + skipValidation bool } // New creates a new Roll instance @@ -54,9 +48,6 @@ func New(ctx context.Context, pgURL, schema string, state *state.State, opts ... for _, o := range opts { o(rollOpts) } - if rollOpts.backfillBatchSize <= 0 { - rollOpts.backfillBatchSize = DefaultBackfillBatchSize - } conn, err := setupConn(ctx, pgURL, schema, *rollOpts) if err != nil { @@ -85,8 +76,6 @@ func New(ctx context.Context, pgURL, schema string, state *state.State, opts ... disableVersionSchemas: rollOpts.disableVersionSchemas, noVersionSchemaForRawSQL: rollOpts.noVersionSchemaForRawSQL, migrationHooks: rollOpts.migrationHooks, - backfillBatchSize: rollOpts.backfillBatchSize, - backfillBatchDelay: rollOpts.backfillBatchDelay, skipValidation: rollOpts.skipValidation, }, nil } diff --git a/pkg/roll/unapplied_test.go b/pkg/roll/unapplied_test.go index 51ee17225..64d60f1d5 100644 --- a/pkg/roll/unapplied_test.go +++ b/pkg/roll/unapplied_test.go @@ -14,6 +14,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/xataio/pgroll/internal/testutils" + "github.com/xataio/pgroll/pkg/backfill" "github.com/xataio/pgroll/pkg/migrations" "github.com/xataio/pgroll/pkg/roll" ) @@ -57,7 +58,7 @@ func TestUnappliedMigrations(t *testing.T) { require.NoError(t, err) // Apply the first migration - err = roll.Start(ctx, &migration) + err = roll.Start(ctx, &migration, backfill.NewConfig()) require.NoError(t, err) err = roll.Complete(ctx) require.NoError(t, err) @@ -89,7 +90,7 @@ func TestUnappliedMigrations(t *testing.T) { err := json.Unmarshal(fs[filename].Data, &migration) require.NoError(t, err) - err = roll.Start(ctx, &migration) + err = roll.Start(ctx, &migration, backfill.NewConfig()) require.NoError(t, err) err = roll.Complete(ctx) require.NoError(t, err) @@ -119,7 +120,7 @@ func TestUnappliedMigrations(t *testing.T) { Operations: migrations.Operations{ &migrations.OpRawSQL{Up: "SELECT 1"}, }, - }) + }, backfill.NewConfig()) require.NoError(t, err) err = m.Complete(ctx) require.NoError(t, err)