diff --git a/file_migrator.go b/file_migrator.go index d4f0ec3e..271e71a6 100644 --- a/file_migrator.go +++ b/file_migrator.go @@ -81,7 +81,15 @@ func (fm *FileMigrator) findMigrations(runner func(mf Migration, tx *Connection) Type: match.Type, Runner: runner, } - fm.Migrations[mf.Direction] = append(fm.Migrations[mf.Direction], mf) + switch mf.Direction { + case "up": + fm.UpMigrations.Migrations = append(fm.UpMigrations.Migrations, mf) + case "down": + fm.DownMigrations.Migrations = append(fm.DownMigrations.Migrations, mf) + default: + // the regex only matches `(up|down)` for direction, so a panic here is appropriate + panic("got unknown migration direction " + mf.Direction) + } } return nil }) diff --git a/migration_box.go b/migration_box.go index 1e7d97ab..0385cc1d 100644 --- a/migration_box.go +++ b/migration_box.go @@ -75,7 +75,15 @@ func (fm *MigrationBox) findMigrations(runner func(f packd.File) func(mf Migrati Type: match.Type, Runner: runner(f), } - fm.Migrations[mf.Direction] = append(fm.Migrations[mf.Direction], mf) + switch mf.Direction { + case "up": + fm.UpMigrations.Migrations = append(fm.UpMigrations.Migrations, mf) + case "down": + fm.DownMigrations.Migrations = append(fm.DownMigrations.Migrations, mf) + default: + // the regex only matches `(up|down)` for direction, so a panic here is appropriate + panic("got unknown migration direction " + mf.Direction) + } return nil }) } diff --git a/migration_box_test.go b/migration_box_test.go index e094077b..7a9ada42 100644 --- a/migration_box_test.go +++ b/migration_box_test.go @@ -32,11 +32,11 @@ func Test_MigrationBox(t *testing.T) { b, err := NewMigrationBox(packr.New("./testdata/migrations/multiple", "./testdata/migrations/multiple"), PDB) r.NoError(err) - r.Equal(4, len(b.Migrations["up"])) - r.Equal("mysql", b.Migrations["up"][0].DBType) - r.Equal("postgres", b.Migrations["up"][1].DBType) - r.Equal("sqlite3", b.Migrations["up"][2].DBType) - r.Equal("all", b.Migrations["up"][3].DBType) + r.Equal(4, len(b.UpMigrations.Migrations)) + r.Equal("mysql", b.UpMigrations.Migrations[0].DBType) + r.Equal("postgres", b.UpMigrations.Migrations[1].DBType) + r.Equal("sqlite3", b.UpMigrations.Migrations[2].DBType) + r.Equal("all", b.UpMigrations.Migrations[3].DBType) }) t.Run("ignores clutter files", func(t *testing.T) { @@ -45,7 +45,7 @@ func Test_MigrationBox(t *testing.T) { b, err := NewMigrationBox(packr.New("./testdata/migrations/cluttered", "./testdata/migrations/cluttered"), PDB) r.NoError(err) - r.Equal(1, len(b.Migrations["up"])) + r.Equal(1, len(b.UpMigrations.Migrations)) r.Equal(1, len(*logs)) r.Equal(logging.Warn, (*logs)[0].lvl) r.Contains((*logs)[0].s, "ignoring file") @@ -58,7 +58,7 @@ func Test_MigrationBox(t *testing.T) { b, err := NewMigrationBox(packr.New("./testdata/migrations/unsupported_dialect", "./testdata/migrations/unsupported_dialect"), PDB) r.NoError(err) - r.Equal(0, len(b.Migrations["up"])) + r.Equal(0, len(b.UpMigrations.Migrations)) r.Equal(1, len(*logs)) r.Equal(logging.Warn, (*logs)[0].lvl) r.Contains((*logs)[0].s, "ignoring migration") diff --git a/migration_info.go b/migration_info.go index 3e6d5ce5..0fa26b14 100644 --- a/migration_info.go +++ b/migration_info.go @@ -36,14 +36,6 @@ func (mfs Migrations) Len() int { return len(mfs) } -func (mfs Migrations) Less(i, j int) bool { - if mfs[i].Version == mfs[j].Version { - // force "all" to the back - return mfs[i].DBType != "all" - } - return mfs[i].Version < mfs[j].Version -} - func (mfs Migrations) Swap(i, j int) { mfs[i], mfs[j] = mfs[j], mfs[i] } @@ -57,3 +49,28 @@ func (mfs *Migrations) Filter(f func(mf Migration) bool) { } *mfs = vsf } + +type ( + UpMigrations struct { + Migrations + } + DownMigrations struct { + Migrations + } +) + +func (mfs UpMigrations) Less(i, j int) bool { + if mfs.Migrations[i].Version == mfs.Migrations[j].Version { + // force "all" to the back + return mfs.Migrations[i].DBType != "all" + } + return mfs.Migrations[i].Version < mfs.Migrations[j].Version +} + +func (mfs DownMigrations) Less(i, j int) bool { + if mfs.Migrations[i].Version == mfs.Migrations[j].Version { + // force "all" to the back + return mfs.Migrations[i].DBType != "all" + } + return mfs.Migrations[i].Version > mfs.Migrations[j].Version +} diff --git a/migration_info_test.go b/migration_info_test.go index f2174551..fd86c64b 100644 --- a/migration_info_test.go +++ b/migration_info_test.go @@ -8,43 +8,65 @@ import ( ) func TestSortingMigrations(t *testing.T) { - t.Run("case=enforces precedence for specific migrations", func(t *testing.T) { - migrations := Migrations{ - { - Version: "1", - DBType: "all", - }, - { - Version: "1", - DBType: "postgres", - }, - { - Version: "2", - DBType: "cockroach", - }, - { - Version: "2", - DBType: "all", - }, - { - Version: "3", - DBType: "all", - }, - { - Version: "3", - DBType: "mysql", - }, + examples := Migrations{ + { + Version: "1", + DBType: "all", + }, + { + Version: "1", + DBType: "postgres", + }, + { + Version: "2", + DBType: "cockroach", + }, + { + Version: "2", + DBType: "all", + }, + { + Version: "3", + DBType: "all", + }, + { + Version: "3", + DBType: "mysql", + }, + } + + t.Run("case=enforces precedence for specific up migrations", func(t *testing.T) { + migrations := make(Migrations, len(examples)) + copy(migrations, examples) + + expectedOrder := Migrations{ + examples[1], + examples[0], + examples[2], + examples[3], + examples[5], + examples[4], } + + sort.Sort(UpMigrations{migrations}) + + assert.Equal(t, expectedOrder, migrations) + }) + + t.Run("case=enforces precedence for specific down migrations", func(t *testing.T) { + migrations := make(Migrations, len(examples)) + copy(migrations, examples) + expectedOrder := Migrations{ - migrations[1], - migrations[0], - migrations[2], - migrations[3], - migrations[5], - migrations[4], + examples[5], + examples[4], + examples[2], + examples[3], + examples[1], + examples[0], } - sort.Sort(migrations) + sort.Sort(DownMigrations{migrations}) assert.Equal(t, expectedOrder, migrations) }) diff --git a/migrator.go b/migrator.go index 6e3abdfc..39e19447 100644 --- a/migrator.go +++ b/migrator.go @@ -23,10 +23,6 @@ var mrx = regexp.MustCompile(`^(\d+)_([^.]+)(\.[a-z0-9]+)?\.(up|down)\.(sql|fizz func NewMigrator(c *Connection) Migrator { return Migrator{ Connection: c, - Migrations: map[string]Migrations{ - "up": {}, - "down": {}, - }, } } @@ -35,9 +31,10 @@ func NewMigrator(c *Connection) Migrator { // When building a new migration system, you should embed this // type into your migrator. type Migrator struct { - Connection *Connection - SchemaPath string - Migrations map[string]Migrations + Connection *Connection + SchemaPath string + UpMigrations UpMigrations + DownMigrations DownMigrations } func (m Migrator) migrationIsCompatible(d dialect, mi Migration) bool { @@ -53,10 +50,10 @@ func (m Migrator) UpLogOnly() error { c := m.Connection return m.exec(func() error { mtn := c.MigrationTableName() - mfs := m.Migrations["up"] + mfs := m.UpMigrations sort.Sort(mfs) return c.Transaction(func(tx *Connection) error { - for _, mi := range mfs { + for _, mi := range mfs.Migrations { if !m.migrationIsCompatible(c.Dialect, mi) { continue } @@ -89,12 +86,12 @@ func (m Migrator) UpTo(step int) (applied int, err error) { c := m.Connection err = m.exec(func() error { mtn := c.MigrationTableName() - mfs := m.Migrations["up"] + mfs := m.UpMigrations mfs.Filter(func(mf Migration) bool { return m.migrationIsCompatible(c.Dialect, mf) }) sort.Sort(mfs) - for _, mi := range mfs { + for _, mi := range mfs.Migrations { exists, err := c.Where("version = ?", mi.Version).Exists(mtn) if err != nil { return errors.Wrapf(err, "problem checking for migration version %s", mi.Version) @@ -139,20 +136,20 @@ func (m Migrator) Down(step int) error { if err != nil { return errors.Wrap(err, "migration down: unable count existing migration") } - mfs := m.Migrations["down"] + mfs := m.DownMigrations mfs.Filter(func(mf Migration) bool { return m.migrationIsCompatible(c.Dialect, mf) }) - sort.Sort(sort.Reverse(mfs)) + sort.Sort(mfs) // skip all ran migration - if len(mfs) > count { - mfs = mfs[len(mfs)-count:] + if len(mfs.Migrations) > count { + mfs.Migrations = mfs.Migrations[len(mfs.Migrations)-count:] } // run only required steps - if step > 0 && len(mfs) >= step { - mfs = mfs[:step] + if step > 0 && len(mfs.Migrations) >= step { + mfs.Migrations = mfs.Migrations[:step] } - for _, mi := range mfs { + for _, mi := range mfs.Migrations { exists, err := c.Where("version = ?", mi.Version).Exists(mtn) if err != nil { return errors.Wrapf(err, "problem checking for migration version %s", mi.Version) @@ -228,7 +225,7 @@ func (m Migrator) Status(out io.Writer) error { } w := tabwriter.NewWriter(out, 0, 0, 3, ' ', tabwriter.TabIndent) _, _ = fmt.Fprintln(w, "Version\tName\tStatus\t") - for _, mf := range m.Migrations["up"] { + for _, mf := range m.UpMigrations.Migrations { exists, err := m.Connection.Where("version = ?", mf.Version).Exists(m.Connection.MigrationTableName()) if err != nil { return errors.Wrapf(err, "problem with migration")