Skip to content

Commit

Permalink
add ability to specify limits and the target version to migrate to
Browse files Browse the repository at this point in the history
  • Loading branch information
TECHNOFAB11 committed Jan 7, 2022
1 parent 5b60f68 commit a55233c
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 36 deletions.
12 changes: 11 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ func NewApp() *cli.App {
},
},
Action: action(func(db *dbmate.DB, c *cli.Context) error {
db.TargetVersion = c.Args().First()
db.Verbose = c.Bool("verbose")
return db.CreateAndMigrate()
}),
Expand All @@ -129,7 +130,7 @@ func NewApp() *cli.App {
},
{
Name: "migrate",
Usage: "Migrate to the latest version",
Usage: "Migrate to the specified or latest version",
Flags: []cli.Flag{
&cli.BoolFlag{
Name: "verbose",
Expand All @@ -139,6 +140,7 @@ func NewApp() *cli.App {
},
},
Action: action(func(db *dbmate.DB, c *cli.Context) error {
db.TargetVersion = c.Args().First()
db.Verbose = c.Bool("verbose")
return db.Migrate()
}),
Expand All @@ -154,8 +156,16 @@ func NewApp() *cli.App {
EnvVars: []string{"DBMATE_VERBOSE"},
Usage: "print the result of each statement execution",
},
&cli.IntFlag{
Name: "limit",
Aliases: []string{"l"},
Usage: "Limits the amount of rollbacks (defaults to 1 if no target version is specified)",
Value: -1,
},
},
Action: action(func(db *dbmate.DB, c *cli.Context) error {
db.TargetVersion = c.Args().First()
db.Limit = c.Int("limit")
db.Verbose = c.Bool("verbose")
return db.Rollback()
}),
Expand Down
107 changes: 72 additions & 35 deletions pkg/dbmate/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ type DB struct {
WaitBefore bool
WaitInterval time.Duration
WaitTimeout time.Duration
Limit int
TargetVersion string
Log io.Writer
}

Expand All @@ -64,6 +66,8 @@ func New(databaseURL *url.URL) *DB {
WaitBefore: false,
WaitInterval: DefaultWaitInterval,
WaitTimeout: DefaultWaitTimeout,
Limit: -1,
TargetVersion: "",
Log: os.Stdout,
}
}
Expand Down Expand Up @@ -336,14 +340,14 @@ func (db *DB) migrate(drv Driver) error {
}
defer dbutil.MustClose(sqlDB)

applied, err := drv.SelectMigrations(sqlDB, -1)
applied, err := drv.SelectMigrations(sqlDB, db.Limit)
if err != nil {
return err
}

for _, filename := range files {
ver := migrationVersion(filename)
if ok := applied[ver]; ok {
if ok := applied[ver]; ok && ver != db.TargetVersion {
// migration already applied
continue
}
Expand Down Expand Up @@ -379,6 +383,11 @@ func (db *DB) migrate(drv Driver) error {
if err != nil {
return err
}

if ver == db.TargetVersion {
fmt.Fprintf(db.Log, "Reached target version %s\n", ver)
break
}
}

// automatically update schema file, silence errors
Expand Down Expand Up @@ -469,55 +478,83 @@ func (db *DB) Rollback() error {
}
defer dbutil.MustClose(sqlDB)

applied, err := drv.SelectMigrations(sqlDB, 1)
limit := db.Limit
// default limit is -1, if we don't specify a version it should only rollback one version, not all
if limit <= 0 && db.TargetVersion == "" {
limit = 1
}

applied, err := drv.SelectMigrations(sqlDB, limit)
if err != nil {
return err
}

// grab most recent applied migration (applied has len=1)
latest := ""
for ver := range applied {
latest = ver
}
if latest == "" {
return fmt.Errorf("can't rollback: no migrations have been applied")
if len(applied) == 0 {
return fmt.Errorf("can't rollback, no migrations found")
}

filename, err := findMigrationFile(db.MigrationsDir, latest)
if err != nil {
return err
var versions []string
for v := range applied {
versions = append(versions, v)
}

fmt.Fprintf(db.Log, "Rolling back: %s\n", filename)
// new → old
sort.Sort(sort.Reverse(sort.StringSlice(versions)))

_, down, err := parseMigration(filepath.Join(db.MigrationsDir, filename))
if err != nil {
return err
if db.TargetVersion != "" {
cache := map[string]bool{}
found := false

// latest version comes first, so take every version until the version matches
for _, ver := range versions {
if ver == db.TargetVersion {
found = true
break
}
cache[ver] = true
}
if !found {
return fmt.Errorf("target version not found")
}
applied = cache
}

execMigration := func(tx dbutil.Transaction) error {
// rollback migration
result, err := tx.Exec(down.Contents)
for version := range applied {
filename, err := findMigrationFile(db.MigrationsDir, version)
if err != nil {
return err
} else if db.Verbose {
db.printVerbose(result)
}

// remove migration record
return drv.DeleteMigration(tx, latest)
}
fmt.Fprintf(db.Log, "Rolling back: %s\n", filename)
_, down, err := parseMigration(filepath.Join(db.MigrationsDir, filename))
if err != nil {
return err
}

if down.Options.Transaction() {
// begin transaction
err = doTransaction(sqlDB, execMigration)
} else {
// run outside of transaction
err = execMigration(sqlDB)
}
execMigration := func(tx dbutil.Transaction) error {
// rollback migration
result, err := tx.Exec(down.Contents)
if err != nil {
return err
} else if db.Verbose {
db.printVerbose(result)
}

if err != nil {
return err
// remove migration record
return drv.DeleteMigration(tx, version)
}

if down.Options.Transaction() {
// begin transaction
err = doTransaction(sqlDB, execMigration)
} else {
// run outside of transaction
err = execMigration(sqlDB)
}

if err != nil {
return err
}
}

// automatically update schema file, silence errors
Expand Down Expand Up @@ -582,7 +619,7 @@ func (db *DB) CheckMigrationsStatus(drv Driver) ([]StatusResult, error) {
}
defer dbutil.MustClose(sqlDB)

applied, err := drv.SelectMigrations(sqlDB, -1)
applied, err := drv.SelectMigrations(sqlDB, db.Limit)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit a55233c

Please sign in to comment.