Skip to content

Commit

Permalink
Merge pull request #247 from sarajmunjal/saraj/ctx
Browse files Browse the repository at this point in the history
Add context support to Exec methods
  • Loading branch information
rubenv authored Jun 5, 2023
2 parents 7084132 + fda37a1 commit 61ee1bf
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 8 deletions.
57 changes: 49 additions & 8 deletions migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package migrate

import (
"bytes"
"context"
"database/sql"
"errors"
"fmt"
Expand Down Expand Up @@ -429,12 +430,24 @@ type SqlExecutor interface {
//
// Returns the number of applied migrations.
func Exec(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection) (int, error) {
return ExecMax(db, dialect, m, dir, 0)
return ExecMaxContext(context.Background(), db, dialect, m, dir, 0)
}

// Returns the number of applied migrations.
func (ms MigrationSet) Exec(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection) (int, error) {
return ms.ExecMax(db, dialect, m, dir, 0)
return ms.ExecMaxContext(context.Background(), db, dialect, m, dir, 0)
}

// Execute a set of migrations with an input context.
//
// Returns the number of applied migrations.
func ExecContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection) (int, error) {
return ExecMaxContext(ctx, db, dialect, m, dir, 0)
}

// Returns the number of applied migrations.
func (ms MigrationSet) ExecContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection) (int, error) {
return ms.ExecMaxContext(ctx, db, dialect, m, dir, 0)
}

// Execute a set of migrations
Expand All @@ -446,50 +459,78 @@ func ExecMax(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirecti
return migSet.ExecMax(db, dialect, m, dir, max)
}

// Execute a set of migrations with an input context.
//
// Will apply at most `max` migrations. Pass 0 for no limit (or use Exec).
//
// Returns the number of applied migrations.
func ExecMaxContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, max int) (int, error) {
return migSet.ExecMaxContext(ctx, db, dialect, m, dir, max)
}

// Execute a set of migrations
//
// Will apply at the target `version` of migration. Cannot be a negative value.
//
// Returns the number of applied migrations.
func ExecVersion(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, version int64) (int, error) {
return ExecVersionContext(context.Background(), db, dialect, m, dir, version)
}

// Execute a set of migrations with an input context.
//
// Will apply at the target `version` of migration. Cannot be a negative value.
//
// Returns the number of applied migrations.
func ExecVersionContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, version int64) (int, error) {
if version < 0 {
return 0, fmt.Errorf("target version %d should not be negative", version)
}
return migSet.ExecVersion(db, dialect, m, dir, version)
return migSet.ExecVersionContext(ctx, db, dialect, m, dir, version)
}

// Returns the number of applied migrations.
func (ms MigrationSet) ExecMax(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, max int) (int, error) {
return ms.ExecMaxContext(context.Background(), db, dialect, m, dir, max)
}

// Returns the number of applied migrations, but applies with an input context.
func (ms MigrationSet) ExecMaxContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, max int) (int, error) {
migrations, dbMap, err := ms.PlanMigration(db, dialect, m, dir, max)
if err != nil {
return 0, err
}
return ms.applyMigrations(dir, migrations, dbMap)
return ms.applyMigrations(ctx, dir, migrations, dbMap)
}

// Returns the number of applied migrations.
func (ms MigrationSet) ExecVersion(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, version int64) (int, error) {
return ms.ExecVersionContext(context.Background(), db, dialect, m, dir, version)
}

func (ms MigrationSet) ExecVersionContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, version int64) (int, error) {
migrations, dbMap, err := ms.PlanMigrationToVersion(db, dialect, m, dir, version)
if err != nil {
return 0, err
}
return ms.applyMigrations(dir, migrations, dbMap)
return ms.applyMigrations(ctx, dir, migrations, dbMap)
}

// Applies the planned migrations and returns the number of applied migrations.
func (MigrationSet) applyMigrations(dir MigrationDirection, migrations []*PlannedMigration, dbMap *gorp.DbMap) (int, error) {
func (MigrationSet) applyMigrations(ctx context.Context, dir MigrationDirection, migrations []*PlannedMigration, dbMap *gorp.DbMap) (int, error) {
applied := 0
for _, migration := range migrations {
var executor SqlExecutor
var err error

if migration.DisableTransaction {
executor = dbMap
executor = dbMap.WithContext(ctx)
} else {
executor, err = dbMap.Begin()
e, err := dbMap.Begin()
if err != nil {
return applied, newTxError(migration, err)
}
executor = e.WithContext(ctx)
}

for _, stmt := range migration.Queries {
Expand Down
38 changes: 38 additions & 0 deletions migrate_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package migrate

import (
"context"
"database/sql"
"net/http"
"time"

"github.com/go-gorp/gorp/v3"
"github.com/gobuffalo/packr/v2"
Expand Down Expand Up @@ -757,3 +759,39 @@ func (s *SqliteMigrateSuite) TestGetMigrationDbMapWithDisableCreateTable(c *C) {
_, err := migSet.getMigrationDbMap(s.Db, "postgres")
c.Assert(err, IsNil)
}

func (s *SqliteMigrateSuite) TestContextTimeout(c *C) {
// This statement will run for a long time: 1,000,000 iterations of the fibonacci sequence
fibonacciLoopStmt := `WITH RECURSIVE
fibo (curr, next)
AS
( SELECT 1,1
UNION ALL
SELECT next, curr+next FROM fibo
LIMIT 1000000 )
SELECT group_concat(curr) FROM fibo;
`
migrations := &MemoryMigrationSource{
Migrations: []*Migration{
sqliteMigrations[0],
sqliteMigrations[1],
{
Id: "125",
Up: []string{fibonacciLoopStmt},
Down: []string{}, // Not important here
},
{
Id: "125",
Up: []string{"INSERT INTO people (id, first_name) VALUES (1, 'Test')", "SELECT fail"},
Down: []string{}, // Not important here
},
},
}

// Should never run the insert
ctx, cancelFunc := context.WithTimeout(context.Background(), 10*time.Millisecond)
defer cancelFunc()
n, err := ExecContext(ctx, s.Db, "sqlite3", migrations, Up)
c.Assert(err, Not(IsNil))
c.Assert(n, Equals, 2)
}

0 comments on commit 61ee1bf

Please sign in to comment.