99 "io"
1010 "io/ioutil"
1111 nurl "net/url"
12+ "regexp"
1213 "strconv"
1314 "strings"
1415 "time"
@@ -42,10 +43,11 @@ var (
4243
4344type Config struct {
4445 MigrationsTable string
46+ MigrationsTableQuoted bool
47+ MultiStatementEnabled bool
4548 DatabaseName string
4649 SchemaName string
4750 StatementTimeout time.Duration
48- MultiStatementEnabled bool
4951 MultiStatementMaxSize int
5052}
5153
@@ -99,7 +101,6 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
99101 if len (config .MigrationsTable ) == 0 {
100102 config .MigrationsTable = DefaultMigrationsTable
101103 }
102-
103104 conn , err := instance .Conn (context .Background ())
104105
105106 if err != nil {
@@ -131,6 +132,17 @@ func (p *Postgres) Open(url string) (database.Driver, error) {
131132 }
132133
133134 migrationsTable := purl .Query ().Get ("x-migrations-table" )
135+ migrationsTableQuoted := false
136+ if s := purl .Query ().Get ("x-migrations-table-quoted" ); len (s ) > 0 {
137+ migrationsTableQuoted , err = strconv .ParseBool (s )
138+ if err != nil {
139+ return nil , fmt .Errorf ("Unable to parse option x-migrations-table-quoted: %w" , err )
140+ }
141+ }
142+ if (len (migrationsTable ) > 0 ) && (migrationsTableQuoted ) && ((migrationsTable [0 ] != '"' ) || (migrationsTable [len (migrationsTable )- 1 ] != '"' )) {
143+ return nil , fmt .Errorf ("x-migrations-table must be quoted (for instance '\" gomigrate\" .\" schema_migrations\" ') when x-migrations-table-quoted is enabled, current value is: %s" , migrationsTable )
144+ }
145+
134146 statementTimeoutString := purl .Query ().Get ("x-statement-timeout" )
135147 statementTimeout := 0
136148 if statementTimeoutString != "" {
@@ -162,6 +174,7 @@ func (p *Postgres) Open(url string) (database.Driver, error) {
162174 px , err := WithInstance (db , & Config {
163175 DatabaseName : purl .Path ,
164176 MigrationsTable : migrationsTable ,
177+ MigrationsTableQuoted : migrationsTableQuoted ,
165178 StatementTimeout : time .Duration (statementTimeout ) * time .Millisecond ,
166179 MultiStatementEnabled : multiStatementEnabled ,
167180 MultiStatementMaxSize : multiStatementMaxSize ,
@@ -312,13 +325,20 @@ func runesLastIndex(input []rune, target rune) int {
312325 return - 1
313326}
314327
328+ func (p * Postgres ) quoteIdentifier (name string ) string {
329+ if p .config .MigrationsTableQuoted {
330+ return name
331+ }
332+ return pq .QuoteIdentifier (name )
333+ }
334+
315335func (p * Postgres ) SetVersion (version int , dirty bool ) error {
316336 tx , err := p .conn .BeginTx (context .Background (), & sql.TxOptions {})
317337 if err != nil {
318338 return & database.Error {OrigErr : err , Err : "transaction start failed" }
319339 }
320340
321- query := `TRUNCATE ` + pq . QuoteIdentifier (p .config .MigrationsTable )
341+ query := `TRUNCATE ` + p . quoteIdentifier (p .config .MigrationsTable )
322342 if _ , err := tx .Exec (query ); err != nil {
323343 if errRollback := tx .Rollback (); errRollback != nil {
324344 err = multierror .Append (err , errRollback )
@@ -330,7 +350,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
330350 // empty schema version for failed down migration on the first migration
331351 // See: https://github.com/golang-migrate/migrate/issues/330
332352 if version >= 0 || (version == database .NilVersion && dirty ) {
333- query = `INSERT INTO ` + pq . QuoteIdentifier (p .config .MigrationsTable ) +
353+ query = `INSERT INTO ` + p . quoteIdentifier (p .config .MigrationsTable ) +
334354 ` (version, dirty) VALUES ($1, $2)`
335355 if _ , err := tx .Exec (query , version , dirty ); err != nil {
336356 if errRollback := tx .Rollback (); errRollback != nil {
@@ -348,7 +368,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
348368}
349369
350370func (p * Postgres ) Version () (version int , dirty bool , err error ) {
351- query := `SELECT version, dirty FROM ` + pq . QuoteIdentifier (p .config .MigrationsTable ) + ` LIMIT 1`
371+ query := `SELECT version, dirty FROM ` + p . quoteIdentifier (p .config .MigrationsTable ) + ` LIMIT 1`
352372 err = p .conn .QueryRowContext (context .Background (), query ).Scan (& version , & dirty )
353373 switch {
354374 case err == sql .ErrNoRows :
@@ -398,7 +418,7 @@ func (p *Postgres) Drop() (err error) {
398418 if len (tableNames ) > 0 {
399419 // delete one by one ...
400420 for _ , t := range tableNames {
401- query = `DROP TABLE IF EXISTS ` + pq . QuoteIdentifier (t ) + ` CASCADE`
421+ query = `DROP TABLE IF EXISTS ` + p . quoteIdentifier (t ) + ` CASCADE`
402422 if _ , err := p .conn .ExecContext (context .Background (), query ); err != nil {
403423 return & database.Error {OrigErr : err , Query : []byte (query )}
404424 }
@@ -431,8 +451,25 @@ func (p *Postgres) ensureVersionTable() (err error) {
431451 // `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission.
432452 // Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258
433453 var count int
434- query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
435- row := p .conn .QueryRowContext (context .Background (), query , p .config .MigrationsTable )
454+ var row * sql.Row
455+ tableName := []byte (p .config .MigrationsTable )
456+ schemaName := []byte ("" )
457+ if p .config .MigrationsTableQuoted {
458+ re := regexp .MustCompile (`"(.*?)"` )
459+ result := re .FindAllSubmatch ([]byte (p .config .MigrationsTable ), - 1 )
460+ tableName = result [len (result )- 1 ][1 ]
461+ if len (result ) > 1 {
462+ schemaName = result [0 ][1 ]
463+ }
464+ }
465+ var query string
466+ if len (schemaName ) > 0 {
467+ query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = $2 LIMIT 1`
468+ row = p .conn .QueryRowContext (context .Background (), query , tableName , schemaName )
469+ } else {
470+ query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
471+ row = p .conn .QueryRowContext (context .Background (), query , tableName )
472+ }
436473
437474 err = row .Scan (& count )
438475 if err != nil {
@@ -443,7 +480,7 @@ func (p *Postgres) ensureVersionTable() (err error) {
443480 return nil
444481 }
445482
446- query = `CREATE TABLE IF NOT EXISTS ` + pq . QuoteIdentifier (p .config .MigrationsTable ) + ` (version bigint not null primary key, dirty boolean not null)`
483+ query = `CREATE TABLE IF NOT EXISTS ` + p . quoteIdentifier (p .config .MigrationsTable ) + ` (version bigint not null primary key, dirty boolean not null)`
447484 if _ , err = p .conn .ExecContext (context .Background (), query ); err != nil {
448485 return & database.Error {OrigErr : err , Query : []byte (query )}
449486 }
0 commit comments