Skip to content
This repository has been archived by the owner on Aug 23, 2024. It is now read-only.

Commit

Permalink
Postgres - add x-migrations-table-quoted url query option (golang-mig…
Browse files Browse the repository at this point in the history
…rate#95)

By default, gomigrate quote migrations table name, if `x-migrations-table-quoted` is enabled, then you must to quote migrations table name manually, for instance `"gomigrate"."schema_migrations"`
  • Loading branch information
stephane-klein committed Mar 28, 2021
1 parent 511ae9f commit 345f235
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 6 deletions.
1 change: 1 addition & 0 deletions database/postgres/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
| URL Query | WithInstance Config | Description |
|------------|---------------------|-------------|
| `x-migrations-table` | `MigrationsTable` | Name of the migrations table |
| `x-migrations-table-quoted` | `MigrationsTableQuoted` | By default, gomigrate quote migrations table name, if `x-migrations-table-quoted` is enabled, then you must to quote migrations table name manually, for instance `"gomigrate"."schema_migrations"` |
| `x-statement-timeout` | `StatementTimeout` | Abort any statement that takes more than the specified number of milliseconds |
| `x-multi-statement` | `MultiStatementEnabled` | Enable multi-statement execution (default: false) |
| `x-multi-statement-max-size` | `MultiStatementMaxSize` | Maximum size of single statement in bytes (default: 10MB) |
Expand Down
30 changes: 25 additions & 5 deletions database/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ var (

type Config struct {
MigrationsTable string
MigrationsTableQuoted bool
DatabaseName string
SchemaName string
StatementTimeout time.Duration
Expand Down Expand Up @@ -131,6 +132,17 @@ func (p *Postgres) Open(url string) (database.Driver, error) {
}

migrationsTable := purl.Query().Get("x-migrations-table")
migrationsTableQuoted := false
if s := purl.Query().Get("x-migrations-table-quoted"); len(s) > 0 {
migrationsTableQuoted, err = strconv.ParseBool(s)
if err != nil {
return nil, fmt.Errorf("Unable to parse option x-migrations-table-quoted: %w", err)
}
}
if (migrationsTableQuoted) && ((migrationsTable[0] != '"') || (migrationsTable[len(migrationsTable)-1] != '"')) {
return nil, fmt.Errorf("x-migrations-table must be quoted (for instance '\"gomigrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled")
}

statementTimeoutString := purl.Query().Get("x-statement-timeout")
statementTimeout := 0
if statementTimeoutString != "" {
Expand Down Expand Up @@ -162,6 +174,7 @@ func (p *Postgres) Open(url string) (database.Driver, error) {
px, err := WithInstance(db, &Config{
DatabaseName: purl.Path,
MigrationsTable: migrationsTable,
MigrationsTableQuoted: migrationsTableQuoted,
StatementTimeout: time.Duration(statementTimeout) * time.Millisecond,
MultiStatementEnabled: multiStatementEnabled,
MultiStatementMaxSize: multiStatementMaxSize,
Expand Down Expand Up @@ -312,13 +325,20 @@ func runesLastIndex(input []rune, target rune) int {
return -1
}

func (p *Postgres) quoteIdentifier(name string) string {
if p.config.MigrationsTableQuoted {
return name
}
return pq.QuoteIdentifier(name)
}

func (p *Postgres) SetVersion(version int, dirty bool) error {
tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
if err != nil {
return &database.Error{OrigErr: err, Err: "transaction start failed"}
}

query := `TRUNCATE ` + pq.QuoteIdentifier(p.config.MigrationsTable)
query := `TRUNCATE ` + p.quoteIdentifier(p.config.MigrationsTable)
if _, err := tx.Exec(query); err != nil {
if errRollback := tx.Rollback(); errRollback != nil {
err = multierror.Append(err, errRollback)
Expand All @@ -330,7 +350,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
// empty schema version for failed down migration on the first migration
// See: https://github.com/golang-migrate/migrate/issues/330
if version >= 0 || (version == database.NilVersion && dirty) {
query = `INSERT INTO ` + pq.QuoteIdentifier(p.config.MigrationsTable) +
query = `INSERT INTO ` + p.quoteIdentifier(p.config.MigrationsTable) +
` (version, dirty) VALUES ($1, $2)`
if _, err := tx.Exec(query, version, dirty); err != nil {
if errRollback := tx.Rollback(); errRollback != nil {
Expand All @@ -348,7 +368,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
}

func (p *Postgres) Version() (version int, dirty bool, err error) {
query := `SELECT version, dirty FROM ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` LIMIT 1`
query := `SELECT version, dirty FROM ` + p.quoteIdentifier(p.config.MigrationsTable) + ` LIMIT 1`
err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
switch {
case err == sql.ErrNoRows:
Expand Down Expand Up @@ -398,7 +418,7 @@ func (p *Postgres) Drop() (err error) {
if len(tableNames) > 0 {
// delete one by one ...
for _, t := range tableNames {
query = `DROP TABLE IF EXISTS ` + pq.QuoteIdentifier(t) + ` CASCADE`
query = `DROP TABLE IF EXISTS ` + p.quoteIdentifier(t) + ` CASCADE`
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
Expand Down Expand Up @@ -443,7 +463,7 @@ func (p *Postgres) ensureVersionTable() (err error) {
return nil
}

query = `CREATE TABLE IF NOT EXISTS ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)`
query = `CREATE TABLE IF NOT EXISTS ` + p.quoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)`
if _, err = p.conn.ExecContext(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
Expand Down
40 changes: 39 additions & 1 deletion database/postgres/postgres_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@ import (
sqldriver "database/sql/driver"
"errors"
"fmt"
"github.com/golang-migrate/migrate/v4"
"io"
"log"
"strconv"
"strings"
"sync"
"testing"

"github.com/golang-migrate/migrate/v4"

"github.com/dhui/dktest"

"github.com/golang-migrate/migrate/v4/database"
Expand Down Expand Up @@ -683,3 +684,40 @@ func Test_computeLineFromPos(t *testing.T) {
}

}

func Test_quoteIdentifier(t *testing.T) {
testcases := []struct {
migrationsTableQuoted bool
migrationsTable string
want string
}{
{
false,
"schema_name.table_name",
"\"schema_name.table_name\"",
},
{
false,
"table_name",
"\"table_name\"",
},
{
true,
"\"schema_name\".\"table.name\"",
"\"schema_name\".\"table.name\"",
},
}
p := &Postgres{
config: &Config{
MigrationsTableQuoted: false,
},
}

for _, tc := range testcases {
p.config.MigrationsTableQuoted = tc.migrationsTableQuoted
got := p.quoteIdentifier(tc.migrationsTable)
if tc.want != got {
t.Fatalf("expected %s but got %s", tc.want, got)
}
}
}

0 comments on commit 345f235

Please sign in to comment.