Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions database/pgx/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
| URL Query | WithInstance Config | Description |
|------------|---------------------|-------------|
| `x-migrations-table` | `MigrationsTable` | Name of the migrations table |
| `x-migrations-table-quoted` | `MigrationsTableQuoted` | By default, migrate quotes the migration table for SQL injection safety reasons. This option disable quoting and naively checks that you have quoted the migration table name. e.g. `"my_schema"."schema_migrations"` |
| `x-statement-timeout` | `StatementTimeout` | Abort any statement that takes more than the specified number of milliseconds |
| `x-multi-statement` | `MultiStatementEnabled` | Enable multi-statement execution (default: false) |
| `x-multi-statement-max-size` | `MultiStatementMaxSize` | Maximum size of single statement in bytes (default: 10MB) |
Expand Down
54 changes: 45 additions & 9 deletions database/pgx/pgx.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"io"
"io/ioutil"
nurl "net/url"
"regexp"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -46,6 +47,7 @@ type Config struct {
DatabaseName string
SchemaName string
StatementTimeout time.Duration
MigrationsTableQuoted bool
MultiStatementEnabled bool
MultiStatementMaxSize int
}
Expand Down Expand Up @@ -137,6 +139,17 @@ func (p *Postgres) Open(url string) (database.Driver, error) {
}

migrationsTable := purl.Query().Get("x-migrations-table")
migrationsTableQuoted := false
if s := purl.Query().Get("x-migrations-table-quoted"); len(s) > 0 {
migrationsTableQuoted, err = strconv.ParseBool(s)
if err != nil {
return nil, fmt.Errorf("Unable to parse option x-migrations-table-quoted: %w", err)
}
}
if (len(migrationsTable) > 0) && (migrationsTableQuoted) && ((migrationsTable[0] != '"') || (migrationsTable[len(migrationsTable)-1] != '"')) {
return nil, fmt.Errorf("x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: %s", migrationsTable)
}

statementTimeoutString := purl.Query().Get("x-statement-timeout")
statementTimeout := 0
if statementTimeoutString != "" {
Expand Down Expand Up @@ -168,6 +181,7 @@ func (p *Postgres) Open(url string) (database.Driver, error) {
px, err := WithInstance(db, &Config{
DatabaseName: purl.Path,
MigrationsTable: migrationsTable,
MigrationsTableQuoted: migrationsTableQuoted,
StatementTimeout: time.Duration(statementTimeout) * time.Millisecond,
MultiStatementEnabled: multiStatementEnabled,
MultiStatementMaxSize: multiStatementMaxSize,
Expand Down Expand Up @@ -321,7 +335,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
return &database.Error{OrigErr: err, Err: "transaction start failed"}
}

query := `TRUNCATE ` + quoteIdentifier(p.config.MigrationsTable)
query := `TRUNCATE ` + p.quoteIdentifier(p.config.MigrationsTable)
if _, err := tx.Exec(query); err != nil {
if errRollback := tx.Rollback(); errRollback != nil {
err = multierror.Append(err, errRollback)
Expand All @@ -333,7 +347,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
// empty schema version for failed down migration on the first migration
// See: https://github.com/golang-migrate/migrate/issues/330
if version >= 0 || (version == database.NilVersion && dirty) {
query = `INSERT INTO ` + quoteIdentifier(p.config.MigrationsTable) +
query = `INSERT INTO ` + p.quoteIdentifier(p.config.MigrationsTable) +
` (version, dirty) VALUES ($1, $2)`
if _, err := tx.Exec(query, version, dirty); err != nil {
if errRollback := tx.Rollback(); errRollback != nil {
Expand All @@ -351,7 +365,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error {
}

func (p *Postgres) Version() (version int, dirty bool, err error) {
query := `SELECT version, dirty FROM ` + quoteIdentifier(p.config.MigrationsTable) + ` LIMIT 1`
query := `SELECT version, dirty FROM ` + p.quoteIdentifier(p.config.MigrationsTable) + ` LIMIT 1`
err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty)
switch {
case err == sql.ErrNoRows:
Expand Down Expand Up @@ -401,7 +415,7 @@ func (p *Postgres) Drop() (err error) {
if len(tableNames) > 0 {
// delete one by one ...
for _, t := range tableNames {
query = `DROP TABLE IF EXISTS ` + quoteIdentifier(t) + ` CASCADE`
query = `DROP TABLE IF EXISTS ` + p.quoteIdentifier(t) + ` CASCADE`
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
Expand Down Expand Up @@ -433,10 +447,29 @@ func (p *Postgres) ensureVersionTable() (err error) {
// users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the
// `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission.
// Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258
var count int
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
row := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable)
var row *sql.Row
tableName := p.config.MigrationsTable
schemaName := ""
if p.config.MigrationsTableQuoted {
re := regexp.MustCompile(`"(.*?)"`)
result := re.FindAllStringSubmatch(p.config.MigrationsTable, -1)
tableName = result[len(result)-1][1]
if len(result) == 2 {
schemaName = result[0][1]
} else if len(result) > 2 {
return fmt.Errorf("\"%s\" MigrationsTable contains too many dot characters", p.config.MigrationsTable)
}
}
var query string
if len(schemaName) > 0 {
query = `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = $2 LIMIT 1`
row = p.conn.QueryRowContext(context.Background(), query, tableName, schemaName)
} else {
query = `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
row = p.conn.QueryRowContext(context.Background(), query, tableName)
}

var count int
err = row.Scan(&count)
if err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
Expand All @@ -446,7 +479,7 @@ func (p *Postgres) ensureVersionTable() (err error) {
return nil
}

query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)`
query = `CREATE TABLE IF NOT EXISTS ` + p.quoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)`
if _, err = p.conn.ExecContext(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
Expand All @@ -455,7 +488,10 @@ func (p *Postgres) ensureVersionTable() (err error) {
}

// Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611
func quoteIdentifier(name string) string {
func (p *Postgres) quoteIdentifier(name string) string {
if p.config.MigrationsTableQuoted {
return name
}
end := strings.IndexRune(name, 0)
if end > -1 {
name = name[:end]
Expand Down
119 changes: 119 additions & 0 deletions database/pgx/pgx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,74 @@ func TestWithSchema(t *testing.T) {
})
}

func TestMigrationTableOption(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}

addr := pgConnectionString(ip, port)
p := &Postgres{}
d, _ := p.Open(addr)
defer func() {
if err := d.Close(); err != nil {
t.Fatal(err)
}
}()

// create migrate schema
if err := d.Run(strings.NewReader("CREATE SCHEMA migrate AUTHORIZATION postgres")); err != nil {
t.Fatal(err)
}

// bad unquoted x-migrations-table parameter
wantErr := "x-migrations-table must be quoted (for instance '\"migrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: migrate.schema_migrations"
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations&x-migrations-table-quoted=1",
pgPassword, ip, port))
if (err != nil) && (err.Error() != wantErr) {
t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
}

// too many quoted x-migrations-table parameters
wantErr = "\"\"migrate\".\"schema_migrations\".\"toomany\"\" MigrationsTable contains too many dot characters"
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\".\"toomany\"&x-migrations-table-quoted=1",
pgPassword, ip, port))
if (err != nil) && (err.Error() != wantErr) {
t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
}

// good quoted x-migrations-table parameter
d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"migrate\".\"schema_migrations\"&x-migrations-table-quoted=1",
pgPassword, ip, port))
if err != nil {
t.Fatal(err)
}

// make sure migrate.schema_migrations table exists
var exists bool
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'schema_migrations' AND table_schema = 'migrate')").Scan(&exists); err != nil {
t.Fatal(err)
}
if !exists {
t.Fatalf("expected table migrate.schema_migrations to exist")
}

d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=migrate.schema_migrations",
pgPassword, ip, port))
if err != nil {
t.Fatal(err)
}
if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'migrate.schema_migrations' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil {
t.Fatal(err)
}
if !exists {
t.Fatalf("expected table 'migrate.schema_migrations' to exist")
}

})
}

func TestFailToCreateTableWithoutPermissions(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
Expand Down Expand Up @@ -373,6 +441,18 @@ func TestFailToCreateTableWithoutPermissions(t *testing.T) {
if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") {
t.Fatal(e)
}

// re-connect using that x-migrations-table and x-migrations-table-quoted
d2, err = p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"barfoo\".\"schema_migrations\"&x-migrations-table-quoted=1",
pgPassword, ip, port))

if !errors.As(err, &e) || err == nil {
t.Fatal("Unexpected error, want permission denied error. Got: ", err)
}

if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") {
t.Fatal(e)
}
})
}

Expand Down Expand Up @@ -679,5 +759,44 @@ func Test_computeLineFromPos(t *testing.T) {
run(true, true)
})
}
}

func Test_quoteIdentifier(t *testing.T) {
testcases := []struct {
migrationsTableQuoted bool
migrationsTable string
expected string
}{
{
false,
"schema_name.table_name",
"\"schema_name.table_name\"",
},
{
false,
"table_name",
"\"table_name\"",
},
{
true,
"\"schema_name\".\"table_name\"",
"\"schema_name\".\"table_name\"",
},
{
true,
"\"table_name\"",
"\"table_name\"",
},
}
p := &Postgres{
config: &Config{},
}

for _, tc := range testcases {
p.config.MigrationsTableQuoted = tc.migrationsTableQuoted
got := p.quoteIdentifier(tc.migrationsTable)
if tc.expected != got {
t.Fatalf("expected %s but got %s", tc.expected, got)
}
}
}
1 change: 1 addition & 0 deletions database/postgres/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
| URL Query | WithInstance Config | Description |
|------------|---------------------|-------------|
| `x-migrations-table` | `MigrationsTable` | Name of the migrations table |
| `x-migrations-table-quoted` | `MigrationsTableQuoted` | By default, migrate quotes the migration table for SQL injection safety reasons. This option disable quoting and naively checks that you have quoted the migration table name. e.g. `"my_schema"."schema_migrations"` |
| `x-statement-timeout` | `StatementTimeout` | Abort any statement that takes more than the specified number of milliseconds |
| `x-multi-statement` | `MultiStatementEnabled` | Enable multi-statement execution (default: false) |
| `x-multi-statement-max-size` | `MultiStatementMaxSize` | Maximum size of single statement in bytes (default: 10MB) |
Expand Down
Loading