Skip to content

Commit 6890aa5

Browse files
Postgres - add x-migrations-table-quoted url query option (#95)
By default, gomigrate quote migrations table name, if `x-migrations-table-quoted` is enabled, then you must to quote migrations table name manually, for instance `"gomigrate"."schema_migrations"`
1 parent 511ae9f commit 6890aa5

File tree

3 files changed

+157
-10
lines changed

3 files changed

+157
-10
lines changed

database/postgres/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
| URL Query | WithInstance Config | Description |
66
|------------|---------------------|-------------|
77
| `x-migrations-table` | `MigrationsTable` | Name of the migrations table |
8+
| `x-migrations-table-quoted` | `MigrationsTableQuoted` | By default, gomigrate quote migrations table name, if `x-migrations-table-quoted` is enabled, then you must to quote migrations table name manually, for instance `"gomigrate"."schema_migrations"` |
89
| `x-statement-timeout` | `StatementTimeout` | Abort any statement that takes more than the specified number of milliseconds |
910
| `x-multi-statement` | `MultiStatementEnabled` | Enable multi-statement execution (default: false) |
1011
| `x-multi-statement-max-size` | `MultiStatementMaxSize` | Maximum size of single statement in bytes (default: 10MB) |

database/postgres/postgres.go

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"io"
1010
"io/ioutil"
1111
nurl "net/url"
12+
"regexp"
1213
"strconv"
1314
"strings"
1415
"time"
@@ -42,10 +43,11 @@ var (
4243

4344
type Config struct {
4445
MigrationsTable string
46+
MigrationsTableQuoted bool
47+
MultiStatementEnabled bool
4548
DatabaseName string
4649
SchemaName string
4750
StatementTimeout time.Duration
48-
MultiStatementEnabled bool
4951
MultiStatementMaxSize int
5052
}
5153

@@ -99,7 +101,6 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
99101
if len(config.MigrationsTable) == 0 {
100102
config.MigrationsTable = DefaultMigrationsTable
101103
}
102-
103104
conn, err := instance.Conn(context.Background())
104105

105106
if err != nil {
@@ -131,6 +132,17 @@ func (p *Postgres) Open(url string) (database.Driver, error) {
131132
}
132133

133134
migrationsTable := purl.Query().Get("x-migrations-table")
135+
migrationsTableQuoted := false
136+
if s := purl.Query().Get("x-migrations-table-quoted"); len(s) > 0 {
137+
migrationsTableQuoted, err = strconv.ParseBool(s)
138+
if err != nil {
139+
return nil, fmt.Errorf("Unable to parse option x-migrations-table-quoted: %w", err)
140+
}
141+
}
142+
if (len(migrationsTable) > 0) && (migrationsTableQuoted) && ((migrationsTable[0] != '"') || (migrationsTable[len(migrationsTable)-1] != '"')) {
143+
return nil, fmt.Errorf("x-migrations-table must be quoted (for instance '\"gomigrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: %s", migrationsTable)
144+
}
145+
134146
statementTimeoutString := purl.Query().Get("x-statement-timeout")
135147
statementTimeout := 0
136148
if statementTimeoutString != "" {
@@ -162,6 +174,7 @@ func (p *Postgres) Open(url string) (database.Driver, error) {
162174
px, err := WithInstance(db, &Config{
163175
DatabaseName: purl.Path,
164176
MigrationsTable: migrationsTable,
177+
MigrationsTableQuoted: migrationsTableQuoted,
165178
StatementTimeout: time.Duration(statementTimeout) * time.Millisecond,
166179
MultiStatementEnabled: multiStatementEnabled,
167180
MultiStatementMaxSize: multiStatementMaxSize,
@@ -312,13 +325,20 @@ func runesLastIndex(input []rune, target rune) int {
312325
return -1
313326
}
314327

328+
func (p *Postgres) quoteIdentifier(name string) string {
329+
if p.config.MigrationsTableQuoted {
330+
return name
331+
}
332+
return pq.QuoteIdentifier(name)
333+
}
334+
315335
func (p *Postgres) SetVersion(version int, dirty bool) error {
316336
tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
317337
if err != nil {
318338
return &database.Error{OrigErr: err, Err: "transaction start failed"}
319339
}
320340

321-
query := `TRUNCATE ` + pq.QuoteIdentifier(p.config.MigrationsTable)
341+
query := `TRUNCATE ` + p.quoteIdentifier(p.config.MigrationsTable)
322342
if _, err := tx.Exec(query); err != nil {
323343
if errRollback := tx.Rollback(); errRollback != nil {
324344
err = multierror.Append(err, errRollback)
@@ -330,7 +350,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
330350
// empty schema version for failed down migration on the first migration
331351
// See: https://github.com/golang-migrate/migrate/issues/330
332352
if version >= 0 || (version == database.NilVersion && dirty) {
333-
query = `INSERT INTO ` + pq.QuoteIdentifier(p.config.MigrationsTable) +
353+
query = `INSERT INTO ` + p.quoteIdentifier(p.config.MigrationsTable) +
334354
` (version, dirty) VALUES ($1, $2)`
335355
if _, err := tx.Exec(query, version, dirty); err != nil {
336356
if errRollback := tx.Rollback(); errRollback != nil {
@@ -348,7 +368,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
348368
}
349369

350370
func (p *Postgres) Version() (version int, dirty bool, err error) {
351-
query := `SELECT version, dirty FROM ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` LIMIT 1`
371+
query := `SELECT version, dirty FROM ` + p.quoteIdentifier(p.config.MigrationsTable) + ` LIMIT 1`
352372
err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
353373
switch {
354374
case err == sql.ErrNoRows:
@@ -398,7 +418,7 @@ func (p *Postgres) Drop() (err error) {
398418
if len(tableNames) > 0 {
399419
// delete one by one ...
400420
for _, t := range tableNames {
401-
query = `DROP TABLE IF EXISTS ` + pq.QuoteIdentifier(t) + ` CASCADE`
421+
query = `DROP TABLE IF EXISTS ` + p.quoteIdentifier(t) + ` CASCADE`
402422
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
403423
return &database.Error{OrigErr: err, Query: []byte(query)}
404424
}
@@ -431,8 +451,25 @@ func (p *Postgres) ensureVersionTable() (err error) {
431451
// `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission.
432452
// Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258
433453
var count int
434-
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
435-
row := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable)
454+
var row *sql.Row
455+
tableName := []byte(p.config.MigrationsTable)
456+
schemaName := []byte("")
457+
if p.config.MigrationsTableQuoted {
458+
re := regexp.MustCompile(`"(.*?)"`)
459+
result := re.FindAllSubmatch([]byte(p.config.MigrationsTable), -1)
460+
tableName = result[len(result)-1][1]
461+
if len(result) > 1 {
462+
schemaName = result[0][1]
463+
}
464+
}
465+
var query string
466+
if len(schemaName) > 0 {
467+
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = $2 LIMIT 1`
468+
row = p.conn.QueryRowContext(context.Background(), query, tableName, schemaName)
469+
} else {
470+
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
471+
row = p.conn.QueryRowContext(context.Background(), query, tableName)
472+
}
436473

437474
err = row.Scan(&count)
438475
if err != nil {
@@ -443,7 +480,7 @@ func (p *Postgres) ensureVersionTable() (err error) {
443480
return nil
444481
}
445482

446-
query = `CREATE TABLE IF NOT EXISTS ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)`
483+
query = `CREATE TABLE IF NOT EXISTS ` + p.quoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)`
447484
if _, err = p.conn.ExecContext(context.Background(), query); err != nil {
448485
return &database.Error{OrigErr: err, Query: []byte(query)}
449486
}

database/postgres/postgres_test.go

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,15 @@ import (
88
sqldriver "database/sql/driver"
99
"errors"
1010
"fmt"
11-
"github.com/golang-migrate/migrate/v4"
1211
"io"
1312
"log"
1413
"strconv"
1514
"strings"
1615
"sync"
1716
"testing"
1817

18+
"github.com/golang-migrate/migrate/v4"
19+
1920
"github.com/dhui/dktest"
2021

2122
"github.com/golang-migrate/migrate/v4/database"
@@ -318,6 +319,65 @@ func TestWithSchema(t *testing.T) {
318319
})
319320
}
320321

322+
func TestMigrationTableOption(t *testing.T) {
323+
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
324+
ip, port, err := c.FirstPort()
325+
if err != nil {
326+
t.Fatal(err)
327+
}
328+
329+
addr := pgConnectionString(ip, port)
330+
p := &Postgres{}
331+
d, _ := p.Open(addr)
332+
defer func() {
333+
if err := d.Close(); err != nil {
334+
t.Fatal(err)
335+
}
336+
}()
337+
338+
// create gomigrate schema
339+
if err := d.Run(strings.NewReader("CREATE SCHEMA gomigrate AUTHORIZATION postgres")); err != nil {
340+
t.Fatal(err)
341+
}
342+
343+
// bad unquoted x-migrations-table parameter
344+
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=gomigrate.schema_migrations&x-migrations-table-quoted=1",
345+
pgPassword, ip, port))
346+
if err == nil {
347+
t.Fatal("expected x-migrations-table must be quoted...")
348+
}
349+
350+
// good quoted x-migrations-table parameter
351+
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"gomigrate\".\"schema_migrations\"&x-migrations-table-quoted=1",
352+
pgPassword, ip, port))
353+
if err != nil {
354+
t.Fatal(err)
355+
}
356+
357+
// make sure gomigrate.schema_migrations table exists
358+
var exists bool
359+
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'schema_migrations' AND table_schema = 'gomigrate')").Scan(&exists); err != nil {
360+
t.Fatal(err)
361+
}
362+
if !exists {
363+
t.Fatalf("expected table gomigrate.schema_migrations to exist")
364+
}
365+
366+
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=gomigrate.schema_migrations",
367+
pgPassword, ip, port))
368+
if err != nil {
369+
t.Fatal(err)
370+
}
371+
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'gomigrate.schema_migrations' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
372+
t.Fatal(err)
373+
}
374+
if !exists {
375+
t.Fatalf("expected table 'gomigrate.schema_migrations' to exist")
376+
}
377+
378+
})
379+
}
380+
321381
func TestFailToCreateTableWithoutPermissions(t *testing.T) {
322382
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
323383
ip, port, err := c.FirstPort()
@@ -373,6 +433,18 @@ func TestFailToCreateTableWithoutPermissions(t *testing.T) {
373433
if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") {
374434
t.Fatal(e)
375435
}
436+
437+
// re-connect using that x-migrations-table and x-migrations-table-quoted
438+
d2, err = p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"barfoo\".\"schema_migrations\"&x-migrations-table-quoted=1",
439+
pgPassword, ip, port))
440+
441+
if !errors.As(err, &e) || err == nil {
442+
t.Fatal("Unexpected error, want permission denied error. Got: ", err)
443+
}
444+
445+
if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") {
446+
t.Fatal(e)
447+
}
376448
})
377449
}
378450

@@ -683,3 +755,40 @@ func Test_computeLineFromPos(t *testing.T) {
683755
}
684756

685757
}
758+
759+
func Test_quoteIdentifier(t *testing.T) {
760+
testcases := []struct {
761+
migrationsTableQuoted bool
762+
migrationsTable string
763+
want string
764+
}{
765+
{
766+
false,
767+
"schema_name.table_name",
768+
"\"schema_name.table_name\"",
769+
},
770+
{
771+
false,
772+
"table_name",
773+
"\"table_name\"",
774+
},
775+
{
776+
true,
777+
"\"schema_name\".\"table.name\"",
778+
"\"schema_name\".\"table.name\"",
779+
},
780+
}
781+
p := &Postgres{
782+
config: &Config{
783+
MigrationsTableQuoted: false,
784+
},
785+
}
786+
787+
for _, tc := range testcases {
788+
p.config.MigrationsTableQuoted = tc.migrationsTableQuoted
789+
got := p.quoteIdentifier(tc.migrationsTable)
790+
if tc.want != got {
791+
t.Fatalf("expected %s but got %s", tc.want, got)
792+
}
793+
}
794+
}

0 commit comments

Comments
 (0)