diff --git a/README.md b/README.md index dcbe7f9..3ddf73d 100644 --- a/README.md +++ b/README.md @@ -77,8 +77,23 @@ func main() { ``` Notes on examples above: + - Migrator creates/manages a table named `migrations` to keep track of the applied versions. However, if want to customize the table name `migrator.TableName("my_migrations")` can be passed to `migrator.New` function as an additional option. +### Logging + +By default, migrator prints applying/applied migration info to stdout, it that's enough for you, you can skip this section. + +If you need some special formatting or want to use a 3rd party logging library, this could be done by using `WithLogger` option as follows: + +```go +logger := migrator.WithLogger(LoggerFunc(func(str string, args ...interface{}) { + // Your code here +}))) +``` + +Then you only will need to pass the logger as an option to `migrator.New`. + ### Looking for more examples? Just examine the [migrator_test.go](migrator_test.go) file. diff --git a/migrator.go b/migrator.go index afa9d30..955909b 100644 --- a/migrator.go +++ b/migrator.go @@ -4,6 +4,8 @@ import ( "database/sql" "errors" "fmt" + "log" + "os" ) const defaultTableName = "migrations" @@ -11,6 +13,7 @@ const defaultTableName = "migrations" // Migrator is the migrator implementation type Migrator struct { tableName string + logger Logger migrations []interface{} } @@ -24,6 +27,26 @@ func TableName(tableName string) Option { } } +// Logger interface +type Logger interface { + Printf(string, ...interface{}) +} + +// LoggerFunc is a bridge between Logger and any third party logger +type LoggerFunc func(string, ...interface{}) + +// Printf implements Logger interface +func (f LoggerFunc) Printf(msg string, args ...interface{}) { + f(msg, args...) +} + +// WithLogger creates an option to allow overriding the stdout logging +func WithLogger(logger Logger) Option { + return func(m *Migrator) { + m.logger = logger + } +} + // Migrations creates an option with provided migrations func Migrations(migrations ...interface{}) Option { return func(m *Migrator) { @@ -34,6 +57,7 @@ func Migrations(migrations ...interface{}) Option { // New creates a new migrator instance func New(opts ...Option) (*Migrator, error) { m := &Migrator{ + logger: log.New(os.Stdout, "migrator: ", 0), tableName: defaultTableName, } for _, opt := range opts { @@ -83,13 +107,13 @@ func (m *Migrator) Migrate(db *sql.DB) error { // plan migrations for idx, migration := range m.migrations[count:len(m.migrations)] { insertVersion := fmt.Sprintf("INSERT INTO %s (id, version) VALUES (%d, '%s')", m.tableName, idx+count, migration.(fmt.Stringer).String()) - switch m := migration.(type) { + switch mig := migration.(type) { case *Migration: - if err := migrate(db, insertVersion, m); err != nil { + if err := migrate(db, m.logger, insertVersion, mig); err != nil { return fmt.Errorf("migrator: error while running migrations: %v", err) } case *MigrationNoTx: - if err := migrateNoTx(db, insertVersion, m); err != nil { + if err := migrateNoTx(db, m.logger, insertVersion, mig); err != nil { return fmt.Errorf("migrator: error while running migrations: %v", err) } } @@ -149,7 +173,7 @@ func (m *MigrationNoTx) String() string { return m.Name } -func migrate(db *sql.DB, insertVersion string, migration *Migration) error { +func migrate(db *sql.DB, logger Logger, insertVersion string, migration *Migration) error { tx, err := db.Begin() if err != nil { return err @@ -163,27 +187,27 @@ func migrate(db *sql.DB, insertVersion string, migration *Migration) error { } err = tx.Commit() }() - fmt.Println(fmt.Sprintf("migrator: applying migration named '%s'...", migration.Name)) + logger.Printf("applying migration named '%s'...", migration.Name) if err = migration.Func(tx); err != nil { return fmt.Errorf("error executing golang migration: %s", err) } if _, err = tx.Exec(insertVersion); err != nil { return fmt.Errorf("error updating migration versions: %s", err) } - fmt.Println(fmt.Sprintf("migrator: applied migration named '%s'", migration.Name)) + logger.Printf("applied migration named '%s'", migration.Name) return err } -func migrateNoTx(db *sql.DB, insertVersion string, migration *MigrationNoTx) error { - fmt.Println(fmt.Sprintf("migrator: applying no tx migration named '%s'...", migration.Name)) +func migrateNoTx(db *sql.DB, logger Logger, insertVersion string, migration *MigrationNoTx) error { + logger.Printf("applying no tx migration named '%s'...", migration.Name) if err := migration.Func(db); err != nil { return fmt.Errorf("error executing golang migration: %s", err) } if _, err := db.Exec(insertVersion); err != nil { return fmt.Errorf("error updating migration versions: %s", err) } - fmt.Println(fmt.Sprintf("migrator: applied no tx migration named '%s'...", migration.Name)) + logger.Printf("applied no tx migration named '%s'", migration.Name) return nil } diff --git a/migrator_test.go b/migrator_test.go index a77d863..7879284 100644 --- a/migrator_test.go +++ b/migrator_test.go @@ -5,6 +5,7 @@ package migrator import ( "database/sql" "fmt" + "log" "os" "strings" "testing" @@ -167,7 +168,7 @@ func TestBadMigrate(t *testing.T) { if err != nil { t.Fatal(err) } - if err := migrate(db, "BAD INSERT VERSION", &Migration{Name: "bad insert version", Func: func(tx *sql.Tx) error { + if err := migrate(db, log.New(os.Stdout, "migrator: ", 0), "BAD INSERT VERSION", &Migration{Name: "bad insert version", Func: func(tx *sql.Tx) error { return nil }}); err == nil { t.Fatal("BAD INSERT VERSION should fail!") @@ -179,7 +180,7 @@ func TestBadMigrateNoTx(t *testing.T) { if err != nil { t.Fatal(err) } - if err := migrateNoTx(db, "BAD INSERT VERSION", &MigrationNoTx{Name: "bad migrate no tx", Func: func(db *sql.DB) error { + if err := migrateNoTx(db, log.New(os.Stdout, "migrator: ", 0), "BAD INSERT VERSION", &MigrationNoTx{Name: "bad migrate no tx", Func: func(db *sql.DB) error { return nil }}); err == nil { t.Fatal("BAD INSERT VERSION should fail!")