Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Propagate context in migrations #596

Merged
merged 2 commits into from
May 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/datastore/crdb/crdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ func (cds *crdbDatastore) IsReady(ctx context.Context) (bool, error) {
}
defer currentRevision.Close()

version, err := currentRevision.Version()
version, err := currentRevision.Version(ctx)
if err != nil {
return false, err
}
Expand Down
8 changes: 4 additions & 4 deletions internal/datastore/crdb/migrations/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ func NewCRDBDriver(url string) (*CRDBDriver, error) {

// Version returns the version of the schema to which the connected database
// has been migrated.
func (apd *CRDBDriver) Version() (string, error) {
func (apd *CRDBDriver) Version(ctx context.Context) (string, error) {
var loaded string

if err := apd.db.QueryRow(context.Background(), queryLoadVersion).Scan(&loaded); err != nil {
if err := apd.db.QueryRow(ctx, queryLoadVersion).Scan(&loaded); err != nil {
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) && pgErr.Code == postgresMissingTableErrorCode {
return "", nil
Expand All @@ -62,8 +62,8 @@ func (apd *CRDBDriver) Version() (string, error) {

// WriteVersion overwrites the value stored to track the version of the
// database schema.
func (apd *CRDBDriver) WriteVersion(version, replaced string) error {
result, err := apd.db.Exec(context.Background(), queryWriteVersion, version, replaced)
func (apd *CRDBDriver) WriteVersion(ctx context.Context, version, replaced string) error {
result, err := apd.db.Exec(ctx, queryWriteVersion, version, replaced)
if err != nil {
return fmt.Errorf("unable to update version row: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/datastore/mysql/datastore.go
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,7 @@ func (mds *Datastore) IsReady(ctx context.Context) (bool, error) {
return false, err
}

currentMigrationRevision, err := mds.driver.Version()
currentMigrationRevision, err := mds.driver.Version(ctx)
if err != nil {
return false, err
}
Expand Down
12 changes: 6 additions & 6 deletions internal/datastore/mysql/datastore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -630,14 +630,14 @@ func TestMySQLMigrations(t *testing.T) {
db := datastoreDB(t, false)
migrationDriver := migrations.NewMySQLDriverFromDB(db, "")

version, err := migrationDriver.Version()
version, err := migrationDriver.Version(context.Background())
req.NoError(err)
req.Equal("", version)

err = migrations.Manager.Run(migrationDriver, migrate.Head, migrate.LiveRun)
err = migrations.Manager.Run(context.Background(), migrationDriver, migrate.Head, migrate.LiveRun)
req.NoError(err)

version, err = migrationDriver.Version()
version, err = migrationDriver.Version(context.Background())
req.NoError(err)

headVersion, err := migrations.Manager.HeadRevision()
Expand All @@ -652,14 +652,14 @@ func TestMySQLMigrationsWithPrefix(t *testing.T) {
db := datastoreDB(t, false)
migrationDriver := migrations.NewMySQLDriverFromDB(db, prefix)

version, err := migrationDriver.Version()
version, err := migrationDriver.Version(context.Background())
req.NoError(err)
req.Equal("", version)

err = migrations.Manager.Run(migrationDriver, migrate.Head, migrate.LiveRun)
err = migrations.Manager.Run(context.Background(), migrationDriver, migrate.Head, migrate.LiveRun)
req.NoError(err)

version, err = migrationDriver.Version()
version, err = migrationDriver.Version(context.Background())
req.NoError(err)

headVersion, err := migrations.Manager.HeadRevision()
Expand Down
10 changes: 5 additions & 5 deletions internal/datastore/mysql/migrations/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,21 +67,21 @@ func columnNameToRevision(columnName string) (string, bool) {

// Version returns the version of the schema to which the connected database
// has been migrated.
func (driver *MySQLDriver) Version() (string, error) {
func (driver *MySQLDriver) Version(ctx context.Context) (string, error) {
query, args, err := sb.Select("*").From(driver.migrationVersion()).ToSql()
if err != nil {
return "", fmt.Errorf("unable to load driver migration revision: %w", err)
}

rows, err := driver.db.Query(query, args...)
rows, err := driver.db.QueryContext(ctx, query, args...)
if err != nil {
var mysqlError *sqlDriver.MySQLError
if errors.As(err, &mysqlError) && mysqlError.Number == mysqlMissingTableErrorNumber {
return "", nil
}
return "", fmt.Errorf("unable to load driver migration revision: %w", err)
}
defer LogOnError(context.Background(), rows.Close)
defer LogOnError(ctx, rows.Close)
if rows.Err() != nil {
return "", fmt.Errorf("unable to load driver migration revision: %w", err)
}
Expand All @@ -100,13 +100,13 @@ func (driver *MySQLDriver) Version() (string, error) {

// WriteVersion overwrites the _meta_version_ column name which encodes the version
// of the database schema.
func (driver *MySQLDriver) WriteVersion(version, replaced string) error {
func (driver *MySQLDriver) WriteVersion(ctx context.Context, version, replaced string) error {
stmt := fmt.Sprintf("ALTER TABLE %s CHANGE %s %s VARCHAR(255) NOT NULL",
driver.migrationVersion(),
revisionToColumnName(replaced),
revisionToColumnName(version),
)
if _, err := driver.db.Exec(stmt); err != nil {
if _, err := driver.db.ExecContext(ctx, stmt); err != nil {
return fmt.Errorf("unable to version: %w", err)
}

Expand Down
8 changes: 4 additions & 4 deletions internal/datastore/postgres/migrations/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,10 @@ func NewAlembicPostgresDriver(url string) (*AlembicPostgresDriver, error) {

// Version returns the version of the schema to which the connected database
// has been migrated.
func (apd *AlembicPostgresDriver) Version() (string, error) {
func (apd *AlembicPostgresDriver) Version(ctx context.Context) (string, error) {
var loaded string

if err := apd.db.QueryRow(context.Background(), "SELECT version_num from alembic_version").Scan(&loaded); err != nil {
if err := apd.db.QueryRow(ctx, "SELECT version_num from alembic_version").Scan(&loaded); err != nil {
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) && pgErr.Code == postgresMissingTableErrorCode {
return "", nil
Expand All @@ -55,9 +55,9 @@ func (apd *AlembicPostgresDriver) Version() (string, error) {

// WriteVersion overwrites the value stored to track the version of the
// database schema.
func (apd *AlembicPostgresDriver) WriteVersion(version, replaced string) error {
func (apd *AlembicPostgresDriver) WriteVersion(ctx context.Context, version, replaced string) error {
result, err := apd.db.Exec(
context.Background(),
ctx,
"UPDATE alembic_version SET version_num=$1 WHERE version_num=$2",
version,
replaced,
Expand Down
2 changes: 1 addition & 1 deletion internal/datastore/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ func (pgd *pgDatastore) IsReady(ctx context.Context) (bool, error) {
}
defer currentRevision.Close()

version, err := currentRevision.Version()
version, err := currentRevision.Version(ctx)
if err != nil {
return false, err
}
Expand Down
8 changes: 4 additions & 4 deletions internal/datastore/spanner/migrations/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ func NewSpannerDriver(database, credentialsFilePath, emulatorHost string) (Spann
return SpannerMigrationDriver{client, adminClient}, nil
}

func (smd SpannerMigrationDriver) Version() (string, error) {
func (smd SpannerMigrationDriver) Version(ctx context.Context) (string, error) {
rows := smd.client.Single().Read(
context.Background(),
ctx,
tableSchemaVersion,
spanner.AllKeys(),
[]string{colVersionNum},
Expand All @@ -72,8 +72,8 @@ func (smd SpannerMigrationDriver) Version() (string, error) {
return schemaRevision, nil
}

func (smd SpannerMigrationDriver) WriteVersion(version, replaced string) error {
_, err := smd.client.ReadWriteTransaction(context.Background(), func(c context.Context, rwt *spanner.ReadWriteTransaction) error {
func (smd SpannerMigrationDriver) WriteVersion(ctx context.Context, version, replaced string) error {
_, err := smd.client.ReadWriteTransaction(ctx, func(c context.Context, rwt *spanner.ReadWriteTransaction) error {
return rwt.BufferWrite([]*spanner.Mutation{
spanner.Delete(tableSchemaVersion, spanner.KeySetFromKeys(spanner.Key{replaced})),
spanner.Insert(tableSchemaVersion, []string{colVersionNum}, []interface{}{version}),
Expand Down
2 changes: 1 addition & 1 deletion internal/datastore/spanner/spanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func (sd spannerDatastore) IsReady(ctx context.Context) (bool, error) {
}
defer currentRevision.Close()

version, err := currentRevision.Version()
version, err := currentRevision.Version(ctx)
if err != nil {
return false, err
}
Expand Down
2 changes: 1 addition & 1 deletion internal/testserver/datastore/crdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func (r *crdbTester) NewDatastore(t testing.TB, initFunc InitFunc) datastore.Dat

migrationDriver, err := crdbmigrations.NewCRDBDriver(connectStr)
require.NoError(t, err)
require.NoError(t, crdbmigrations.CRDBMigrations.Run(migrationDriver, migrate.Head, migrate.LiveRun))
require.NoError(t, crdbmigrations.CRDBMigrations.Run(context.Background(), migrationDriver, migrate.Head, migrate.LiveRun))

return initFunc("cockroachdb", connectStr)
}
5 changes: 2 additions & 3 deletions internal/testserver/datastore/mysql.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package datastore

import (
"context"
"database/sql"
"fmt"
"testing"
Expand Down Expand Up @@ -109,9 +110,7 @@ func (mb *mysqlTester) NewDatabase(t testing.TB) string {
func (mb *mysqlTester) runMigrate(t testing.TB, dsn string) {
driver, err := migrations.NewMySQLDriverFromDSN(dsn, mb.options.Prefix)
require.NoError(t, err, "failed to create migration driver: %s", err)
err = migrations.Manager.Run(driver, migrate.Head, migrate.LiveRun)
require.NoError(t, err, "failed to run migration: %s", err)
err = migrations.Manager.Run(driver, migrate.Head, migrate.LiveRun)
err = migrations.Manager.Run(context.Background(), driver, migrate.Head, migrate.LiveRun)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the duplication here might be in error

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed, thanks ! This must have slipped through my attention while resolving the rebase conflicts ...

require.NoError(t, err, "failed to run migration: %s", err)
}

Expand Down
2 changes: 1 addition & 1 deletion internal/testserver/datastore/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ func (b *postgresTester) NewDatastore(t testing.TB, initFunc InitFunc) datastore

migrationDriver, err := pgmigrations.NewAlembicPostgresDriver(connectStr)
require.NoError(t, err)
require.NoError(t, pgmigrations.DatabaseMigrations.Run(migrationDriver, migrate.Head, migrate.LiveRun))
require.NoError(t, pgmigrations.DatabaseMigrations.Run(context.Background(), migrationDriver, migrate.Head, migrate.LiveRun))

return initFunc("postgres", connectStr)
}
2 changes: 1 addition & 1 deletion internal/testserver/datastore/spanner.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func (b *spannerTest) NewDatastore(t testing.TB, initFunc InitFunc) datastore.Da
migrationDriver, err := migrations.NewSpannerDriver(db, "", os.Getenv("SPANNER_EMULATOR_HOST"))
require.NoError(t, err)

err = migrations.SpannerMigrations.Run(migrationDriver, migrate.Head, migrate.LiveRun)
err = migrations.SpannerMigrations.Run(context.Background(), migrationDriver, migrate.Head, migrate.LiveRun)
require.NoError(t, err)

return initFunc("spanner", db)
Expand Down
3 changes: 2 additions & 1 deletion pkg/cmd/migrate.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cmd

import (
"context"
"fmt"

"github.com/fatih/color"
Expand Down Expand Up @@ -96,7 +97,7 @@ func migrateRun(cmd *cobra.Command, args []string) error {
targetRevision := args[0]

log.Info().Str("targetRevision", targetRevision).Msg("running migrations")
if err := manager.Run(migrationDriver, targetRevision, migrate.LiveRun); err != nil {
if err := manager.Run(context.Background(), migrationDriver, targetRevision, migrate.LiveRun); err != nil {
log.Fatal().Err(err).Msg("unable to complete requested migrations")
}

Expand Down
13 changes: 7 additions & 6 deletions pkg/migrate/migrate.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package migrate

import (
"context"
"errors"
"fmt"
"reflect"
Expand All @@ -27,10 +28,10 @@ type Driver interface {
// Version returns the current version of the schema in the backing datastore.
// If the datastore is brand new, version should return the empty string without
// an error.
Version() (string, error)
Version(ctx context.Context) (string, error)

// WriteVersion records the newly migrated version to the backing datastore.
WriteVersion(version, replaced string) error
WriteVersion(ctx context.Context, version, replaced string) error

// Close frees up any resources in use by the driver.
Close() error
Expand Down Expand Up @@ -81,8 +82,8 @@ func (m *Manager) Register(version, replaces string, up interface{}) error {

// Run will actually perform the necessary migrations to bring the backing datastore
// from its current revision to the specified revision.
func (m *Manager) Run(driver Driver, throughRevision string, dryRun RunType) error {
starting, err := driver.Version()
func (m *Manager) Run(ctx context.Context, driver Driver, throughRevision string, dryRun RunType) error {
starting, err := driver.Version(ctx)
if err != nil {
return fmt.Errorf("unable to compute target revision: %w", err)
}
Expand Down Expand Up @@ -112,7 +113,7 @@ func (m *Manager) Run(driver Driver, throughRevision string, dryRun RunType) err
if !dryRun {
for _, migrationToRun := range toRun {
// Double check that the current version reported is the one we expect
currentVersion, err := driver.Version()
currentVersion, err := driver.Version(ctx)
if err != nil {
return fmt.Errorf("unable to load version from driver: %w", err)
}
Expand All @@ -131,7 +132,7 @@ func (m *Manager) Run(driver Driver, throughRevision string, dryRun RunType) err
return fmt.Errorf("error running migration up function: %v", errArg)
}

if err := driver.WriteVersion(migrationToRun.version, migrationToRun.replaces); err != nil {
if err := driver.WriteVersion(ctx, migrationToRun.version, migrationToRun.replaces); err != nil {
return fmt.Errorf("error writing migration version to driver: %w", err)
}
}
Expand Down
5 changes: 3 additions & 2 deletions pkg/migrate/migrate_test.go
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
package migrate

import (
"context"
"testing"

"github.com/stretchr/testify/require"
)

type fakeDriver struct{}

func (*fakeDriver) Version() (string, error) {
func (*fakeDriver) Version(ctx context.Context) (string, error) {
return "", nil
}

func (*fakeDriver) WriteVersion(version, replaced string) error {
func (*fakeDriver) WriteVersion(ctx context.Context, version, replaced string) error {
return nil
}

Expand Down