Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

logger: implement logger option #31

Merged
merged 1 commit into from
Apr 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,23 @@ func main() {
```

Notes on examples above:
- Migrator creates/manages a table named `migrations` to keep track of the applied versions. However, if want to customize the table name `migrator.TableName("my_migrations")` can be passed to `migrator.New` function as an additional option.

- Migrator creates/manages a table named `migrations` to keep track of the applied versions. However, if you want to customize the table name `migrator.TableName("my_migrations")` can be passed to `migrator.New` function as an additional option.

### Logging

By default, migrator prints applying/applied migration info to stdout.
If that's enough for you, you can skip this section.

If you need some special formatting or want to use a 3rd party logging library, this could be done by using `WithLogger` option as follows:

```go
logger := migrator.WithLogger(migrator.LoggerFunc(func(msg string, args ...interface{}) {
// Your code here
})))
```

Then you will only need to pass the logger as an option to `migrator.New`.

### Looking for more examples?

Expand Down
42 changes: 33 additions & 9 deletions migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@ import (
"database/sql"
"errors"
"fmt"
"log"
"os"
)

const defaultTableName = "migrations"

// Migrator is the migrator implementation
type Migrator struct {
tableName string
logger Logger
migrations []interface{}
}

Expand All @@ -24,6 +27,26 @@ func TableName(tableName string) Option {
}
}

// Logger interface
type Logger interface {
Printf(string, ...interface{})
}

// LoggerFunc is a bridge between Logger and any third party logger
type LoggerFunc func(string, ...interface{})

// Printf implements Logger interface
func (f LoggerFunc) Printf(msg string, args ...interface{}) {
f(msg, args...)
}

// WithLogger creates an option to allow overriding the stdout logging
func WithLogger(logger Logger) Option {
return func(m *Migrator) {
m.logger = logger
}
}

// Migrations creates an option with provided migrations
func Migrations(migrations ...interface{}) Option {
return func(m *Migrator) {
Expand All @@ -34,6 +57,7 @@ func Migrations(migrations ...interface{}) Option {
// New creates a new migrator instance
func New(opts ...Option) (*Migrator, error) {
m := &Migrator{
logger: log.New(os.Stdout, "migrator: ", 0),
tableName: defaultTableName,
}
for _, opt := range opts {
Expand Down Expand Up @@ -83,13 +107,13 @@ func (m *Migrator) Migrate(db *sql.DB) error {
// plan migrations
for idx, migration := range m.migrations[count:len(m.migrations)] {
insertVersion := fmt.Sprintf("INSERT INTO %s (id, version) VALUES (%d, '%s')", m.tableName, idx+count, migration.(fmt.Stringer).String())
switch m := migration.(type) {
switch mig := migration.(type) {
case *Migration:
if err := migrate(db, insertVersion, m); err != nil {
if err := migrate(db, m.logger, insertVersion, mig); err != nil {
return fmt.Errorf("migrator: error while running migrations: %v", err)
}
case *MigrationNoTx:
if err := migrateNoTx(db, insertVersion, m); err != nil {
if err := migrateNoTx(db, m.logger, insertVersion, mig); err != nil {
return fmt.Errorf("migrator: error while running migrations: %v", err)
}
}
Expand Down Expand Up @@ -149,7 +173,7 @@ func (m *MigrationNoTx) String() string {
return m.Name
}

func migrate(db *sql.DB, insertVersion string, migration *Migration) error {
func migrate(db *sql.DB, logger Logger, insertVersion string, migration *Migration) error {
tx, err := db.Begin()
if err != nil {
return err
Expand All @@ -163,27 +187,27 @@ func migrate(db *sql.DB, insertVersion string, migration *Migration) error {
}
err = tx.Commit()
}()
fmt.Println(fmt.Sprintf("migrator: applying migration named '%s'...", migration.Name))
logger.Printf("applying migration named '%s'...", migration.Name)
if err = migration.Func(tx); err != nil {
return fmt.Errorf("error executing golang migration: %s", err)
}
if _, err = tx.Exec(insertVersion); err != nil {
return fmt.Errorf("error updating migration versions: %s", err)
}
fmt.Println(fmt.Sprintf("migrator: applied migration named '%s'", migration.Name))
logger.Printf("applied migration named '%s'", migration.Name)

return err
}

func migrateNoTx(db *sql.DB, insertVersion string, migration *MigrationNoTx) error {
fmt.Println(fmt.Sprintf("migrator: applying no tx migration named '%s'...", migration.Name))
func migrateNoTx(db *sql.DB, logger Logger, insertVersion string, migration *MigrationNoTx) error {
logger.Printf("applying no tx migration named '%s'...", migration.Name)
if err := migration.Func(db); err != nil {
return fmt.Errorf("error executing golang migration: %s", err)
}
if _, err := db.Exec(insertVersion); err != nil {
return fmt.Errorf("error updating migration versions: %s", err)
}
fmt.Println(fmt.Sprintf("migrator: applied no tx migration named '%s'...", migration.Name))
logger.Printf("applied no tx migration named '%s'", migration.Name)

return nil
}
5 changes: 3 additions & 2 deletions migrator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package migrator
import (
"database/sql"
"fmt"
"log"
"os"
"strings"
"testing"
Expand Down Expand Up @@ -167,7 +168,7 @@ func TestBadMigrate(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if err := migrate(db, "BAD INSERT VERSION", &Migration{Name: "bad insert version", Func: func(tx *sql.Tx) error {
if err := migrate(db, log.New(os.Stdout, "migrator: ", 0), "BAD INSERT VERSION", &Migration{Name: "bad insert version", Func: func(tx *sql.Tx) error {
return nil
}}); err == nil {
t.Fatal("BAD INSERT VERSION should fail!")
Expand All @@ -179,7 +180,7 @@ func TestBadMigrateNoTx(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if err := migrateNoTx(db, "BAD INSERT VERSION", &MigrationNoTx{Name: "bad migrate no tx", Func: func(db *sql.DB) error {
if err := migrateNoTx(db, log.New(os.Stdout, "migrator: ", 0), "BAD INSERT VERSION", &MigrationNoTx{Name: "bad migrate no tx", Func: func(db *sql.DB) error {
return nil
}}); err == nil {
t.Fatal("BAD INSERT VERSION should fail!")
Expand Down