From cf2fbf4aba261189258a0f84c217aec3bf9331bd Mon Sep 17 00:00:00 2001 From: Stephane Klein Date: Sun, 28 Mar 2021 22:15:13 +0200 Subject: [PATCH] Postgres and pgx - Add x-migrations-table-quoted url query option to postgres and pgx drivers (#95) By default, gomigrate quote migrations table name, if `x-migrations-table-quoted` is enabled, then you must to quote migrations table name manually, for instance `"gomigrate"."schema_migrations"` --- database/pgx/README.md | 1 + database/pgx/pgx.go | 52 +++++++++++--- database/pgx/pgx_test.go | 107 +++++++++++++++++++++++++++ database/postgres/README.md | 1 + database/postgres/postgres.go | 57 ++++++++++++--- database/postgres/postgres_test.go | 111 ++++++++++++++++++++++++++++- 6 files changed, 309 insertions(+), 20 deletions(-) diff --git a/database/pgx/README.md b/database/pgx/README.md index dca317fdc..2e9e7c75c 100644 --- a/database/pgx/README.md +++ b/database/pgx/README.md @@ -5,6 +5,7 @@ | URL Query | WithInstance Config | Description | |------------|---------------------|-------------| | `x-migrations-table` | `MigrationsTable` | Name of the migrations table | +| `x-migrations-table-quoted` | `MigrationsTableQuoted` | By default, gomigrate quote migrations table name, if `x-migrations-table-quoted` is enabled, then you must quote migrations table name manually, for instance `"gomigrate"."schema_migrations"` | | `x-statement-timeout` | `StatementTimeout` | Abort any statement that takes more than the specified number of milliseconds | | `x-multi-statement` | `MultiStatementEnabled` | Enable multi-statement execution (default: false) | | `x-multi-statement-max-size` | `MultiStatementMaxSize` | Maximum size of single statement in bytes (default: 10MB) | diff --git a/database/pgx/pgx.go b/database/pgx/pgx.go index dda70a77c..080906f6a 100644 --- a/database/pgx/pgx.go +++ b/database/pgx/pgx.go @@ -9,6 +9,7 @@ import ( "io" "io/ioutil" nurl "net/url" + "regexp" "strconv" "strings" "time" @@ -46,6 +47,7 @@ type Config struct { DatabaseName string SchemaName string StatementTimeout time.Duration + MigrationsTableQuoted bool MultiStatementEnabled bool MultiStatementMaxSize int } @@ -137,6 +139,17 @@ func (p *Postgres) Open(url string) (database.Driver, error) { } migrationsTable := purl.Query().Get("x-migrations-table") + migrationsTableQuoted := false + if s := purl.Query().Get("x-migrations-table-quoted"); len(s) > 0 { + migrationsTableQuoted, err = strconv.ParseBool(s) + if err != nil { + return nil, fmt.Errorf("Unable to parse option x-migrations-table-quoted: %w", err) + } + } + if (len(migrationsTable) > 0) && (migrationsTableQuoted) && ((migrationsTable[0] != '"') || (migrationsTable[len(migrationsTable)-1] != '"')) { + return nil, fmt.Errorf("x-migrations-table must be quoted (for instance '\"gomigrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: %s", migrationsTable) + } + statementTimeoutString := purl.Query().Get("x-statement-timeout") statementTimeout := 0 if statementTimeoutString != "" { @@ -168,6 +181,7 @@ func (p *Postgres) Open(url string) (database.Driver, error) { px, err := WithInstance(db, &Config{ DatabaseName: purl.Path, MigrationsTable: migrationsTable, + MigrationsTableQuoted: migrationsTableQuoted, StatementTimeout: time.Duration(statementTimeout) * time.Millisecond, MultiStatementEnabled: multiStatementEnabled, MultiStatementMaxSize: multiStatementMaxSize, @@ -321,7 +335,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error { return &database.Error{OrigErr: err, Err: "transaction start failed"} } - query := `TRUNCATE ` + quoteIdentifier(p.config.MigrationsTable) + query := `TRUNCATE ` + p.quoteIdentifier(p.config.MigrationsTable) if _, err := tx.Exec(query); err != nil { if errRollback := tx.Rollback(); errRollback != nil { err = multierror.Append(err, errRollback) @@ -333,7 +347,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error { // empty schema version for failed down migration on the first migration // See: https://github.com/golang-migrate/migrate/issues/330 if version >= 0 || (version == database.NilVersion && dirty) { - query = `INSERT INTO ` + quoteIdentifier(p.config.MigrationsTable) + + query = `INSERT INTO ` + p.quoteIdentifier(p.config.MigrationsTable) + ` (version, dirty) VALUES ($1, $2)` if _, err := tx.Exec(query, version, dirty); err != nil { if errRollback := tx.Rollback(); errRollback != nil { @@ -351,7 +365,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error { } func (p *Postgres) Version() (version int, dirty bool, err error) { - query := `SELECT version, dirty FROM ` + quoteIdentifier(p.config.MigrationsTable) + ` LIMIT 1` + query := `SELECT version, dirty FROM ` + p.quoteIdentifier(p.config.MigrationsTable) + ` LIMIT 1` err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) switch { case err == sql.ErrNoRows: @@ -401,7 +415,7 @@ func (p *Postgres) Drop() (err error) { if len(tableNames) > 0 { // delete one by one ... for _, t := range tableNames { - query = `DROP TABLE IF EXISTS ` + quoteIdentifier(t) + ` CASCADE` + query = `DROP TABLE IF EXISTS ` + p.quoteIdentifier(t) + ` CASCADE` if _, err := p.conn.ExecContext(context.Background(), query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } @@ -433,10 +447,27 @@ func (p *Postgres) ensureVersionTable() (err error) { // users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the // `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission. // Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258 - var count int - query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` - row := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable) + var row *sql.Row + tableName := []byte(p.config.MigrationsTable) + schemaName := []byte("") + if p.config.MigrationsTableQuoted { + re := regexp.MustCompile(`"(.*?)"`) + result := re.FindAllSubmatch([]byte(p.config.MigrationsTable), -1) + tableName = result[len(result)-1][1] + if len(result) > 1 { + schemaName = result[0][1] + } + } + var query string + if len(schemaName) > 0 { + query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = $2 LIMIT 1` + row = p.conn.QueryRowContext(context.Background(), query, tableName, schemaName) + } else { + query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` + row = p.conn.QueryRowContext(context.Background(), query, tableName) + } + var count int err = row.Scan(&count) if err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} @@ -446,7 +477,7 @@ func (p *Postgres) ensureVersionTable() (err error) { return nil } - query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)` + query = `CREATE TABLE IF NOT EXISTS ` + p.quoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)` if _, err = p.conn.ExecContext(context.Background(), query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } @@ -455,7 +486,10 @@ func (p *Postgres) ensureVersionTable() (err error) { } // Copied from lib/pq implementation: https://github.com/lib/pq/blob/v1.9.0/conn.go#L1611 -func quoteIdentifier(name string) string { +func (p *Postgres) quoteIdentifier(name string) string { + if p.config.MigrationsTableQuoted { + return name + } end := strings.IndexRune(name, 0) if end > -1 { name = name[:end] diff --git a/database/pgx/pgx_test.go b/database/pgx/pgx_test.go index a79f9611e..1de7d9d08 100644 --- a/database/pgx/pgx_test.go +++ b/database/pgx/pgx_test.go @@ -318,6 +318,65 @@ func TestWithSchema(t *testing.T) { }) } +func TestMigrationTableOption(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ip, port, err := c.FirstPort() + if err != nil { + t.Fatal(err) + } + + addr := pgConnectionString(ip, port) + p := &Postgres{} + d, _ := p.Open(addr) + defer func() { + if err := d.Close(); err != nil { + t.Fatal(err) + } + }() + + // create gomigrate schema + if err := d.Run(strings.NewReader("CREATE SCHEMA gomigrate AUTHORIZATION postgres")); err != nil { + t.Fatal(err) + } + + // bad unquoted x-migrations-table parameter + d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=gomigrate.schema_migrations&x-migrations-table-quoted=1", + pgPassword, ip, port)) + if err == nil { + t.Fatal("expected x-migrations-table must be quoted...") + } + + // good quoted x-migrations-table parameter + d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"gomigrate\".\"schema_migrations\"&x-migrations-table-quoted=1", + pgPassword, ip, port)) + if err != nil { + t.Fatal(err) + } + + // make sure gomigrate.schema_migrations table exists + var exists bool + if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'schema_migrations' AND table_schema = 'gomigrate')").Scan(&exists); err != nil { + t.Fatal(err) + } + if !exists { + t.Fatalf("expected table gomigrate.schema_migrations to exist") + } + + d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=gomigrate.schema_migrations", + pgPassword, ip, port)) + if err != nil { + t.Fatal(err) + } + if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'gomigrate.schema_migrations' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil { + t.Fatal(err) + } + if !exists { + t.Fatalf("expected table 'gomigrate.schema_migrations' to exist") + } + + }) +} + func TestFailToCreateTableWithoutPermissions(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { ip, port, err := c.FirstPort() @@ -373,6 +432,18 @@ func TestFailToCreateTableWithoutPermissions(t *testing.T) { if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") { t.Fatal(e) } + + // re-connect using that x-migrations-table and x-migrations-table-quoted + d2, err = p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"barfoo\".\"schema_migrations\"&x-migrations-table-quoted=1", + pgPassword, ip, port)) + + if !errors.As(err, &e) || err == nil { + t.Fatal("Unexpected error, want permission denied error. Got: ", err) + } + + if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") { + t.Fatal(e) + } }) } @@ -679,5 +750,41 @@ func Test_computeLineFromPos(t *testing.T) { run(true, true) }) } +} +func Test_quoteIdentifier(t *testing.T) { + testcases := []struct { + migrationsTableQuoted bool + migrationsTable string + want string + }{ + { + false, + "schema_name.table_name", + "\"schema_name.table_name\"", + }, + { + false, + "table_name", + "\"table_name\"", + }, + { + true, + "\"schema_name\".\"table.name\"", + "\"schema_name\".\"table.name\"", + }, + } + p := &Postgres{ + config: &Config{ + MigrationsTableQuoted: false, + }, + } + + for _, tc := range testcases { + p.config.MigrationsTableQuoted = tc.migrationsTableQuoted + got := p.quoteIdentifier(tc.migrationsTable) + if tc.want != got { + t.Fatalf("expected %s but got %s", tc.want, got) + } + } } diff --git a/database/postgres/README.md b/database/postgres/README.md index e60a5f6cb..94cfb790f 100644 --- a/database/postgres/README.md +++ b/database/postgres/README.md @@ -5,6 +5,7 @@ | URL Query | WithInstance Config | Description | |------------|---------------------|-------------| | `x-migrations-table` | `MigrationsTable` | Name of the migrations table | +| `x-migrations-table-quoted` | `MigrationsTableQuoted` | By default, gomigrate quote migrations table name, if `x-migrations-table-quoted` is enabled, then you must quote migrations table name manually, for instance `"gomigrate"."schema_migrations"` | | `x-statement-timeout` | `StatementTimeout` | Abort any statement that takes more than the specified number of milliseconds | | `x-multi-statement` | `MultiStatementEnabled` | Enable multi-statement execution (default: false) | | `x-multi-statement-max-size` | `MultiStatementMaxSize` | Maximum size of single statement in bytes (default: 10MB) | diff --git a/database/postgres/postgres.go b/database/postgres/postgres.go index 4926e1d84..75bc1a837 100644 --- a/database/postgres/postgres.go +++ b/database/postgres/postgres.go @@ -9,6 +9,7 @@ import ( "io" "io/ioutil" nurl "net/url" + "regexp" "strconv" "strings" "time" @@ -42,10 +43,11 @@ var ( type Config struct { MigrationsTable string + MigrationsTableQuoted bool + MultiStatementEnabled bool DatabaseName string SchemaName string StatementTimeout time.Duration - MultiStatementEnabled bool MultiStatementMaxSize int } @@ -99,7 +101,6 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { if len(config.MigrationsTable) == 0 { config.MigrationsTable = DefaultMigrationsTable } - conn, err := instance.Conn(context.Background()) if err != nil { @@ -131,6 +132,17 @@ func (p *Postgres) Open(url string) (database.Driver, error) { } migrationsTable := purl.Query().Get("x-migrations-table") + migrationsTableQuoted := false + if s := purl.Query().Get("x-migrations-table-quoted"); len(s) > 0 { + migrationsTableQuoted, err = strconv.ParseBool(s) + if err != nil { + return nil, fmt.Errorf("Unable to parse option x-migrations-table-quoted: %w", err) + } + } + if (len(migrationsTable) > 0) && (migrationsTableQuoted) && ((migrationsTable[0] != '"') || (migrationsTable[len(migrationsTable)-1] != '"')) { + return nil, fmt.Errorf("x-migrations-table must be quoted (for instance '\"gomigrate\".\"schema_migrations\"') when x-migrations-table-quoted is enabled, current value is: %s", migrationsTable) + } + statementTimeoutString := purl.Query().Get("x-statement-timeout") statementTimeout := 0 if statementTimeoutString != "" { @@ -162,6 +174,7 @@ func (p *Postgres) Open(url string) (database.Driver, error) { px, err := WithInstance(db, &Config{ DatabaseName: purl.Path, MigrationsTable: migrationsTable, + MigrationsTableQuoted: migrationsTableQuoted, StatementTimeout: time.Duration(statementTimeout) * time.Millisecond, MultiStatementEnabled: multiStatementEnabled, MultiStatementMaxSize: multiStatementMaxSize, @@ -312,13 +325,20 @@ func runesLastIndex(input []rune, target rune) int { return -1 } +func (p *Postgres) quoteIdentifier(name string) string { + if p.config.MigrationsTableQuoted { + return name + } + return pq.QuoteIdentifier(name) +} + func (p *Postgres) SetVersion(version int, dirty bool) error { tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) if err != nil { return &database.Error{OrigErr: err, Err: "transaction start failed"} } - query := `TRUNCATE ` + pq.QuoteIdentifier(p.config.MigrationsTable) + query := `TRUNCATE ` + p.quoteIdentifier(p.config.MigrationsTable) if _, err := tx.Exec(query); err != nil { if errRollback := tx.Rollback(); errRollback != nil { err = multierror.Append(err, errRollback) @@ -330,7 +350,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error { // empty schema version for failed down migration on the first migration // See: https://github.com/golang-migrate/migrate/issues/330 if version >= 0 || (version == database.NilVersion && dirty) { - query = `INSERT INTO ` + pq.QuoteIdentifier(p.config.MigrationsTable) + + query = `INSERT INTO ` + p.quoteIdentifier(p.config.MigrationsTable) + ` (version, dirty) VALUES ($1, $2)` if _, err := tx.Exec(query, version, dirty); err != nil { if errRollback := tx.Rollback(); errRollback != nil { @@ -348,7 +368,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error { } func (p *Postgres) Version() (version int, dirty bool, err error) { - query := `SELECT version, dirty FROM ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` LIMIT 1` + query := `SELECT version, dirty FROM ` + p.quoteIdentifier(p.config.MigrationsTable) + ` LIMIT 1` err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) switch { case err == sql.ErrNoRows: @@ -398,7 +418,7 @@ func (p *Postgres) Drop() (err error) { if len(tableNames) > 0 { // delete one by one ... for _, t := range tableNames { - query = `DROP TABLE IF EXISTS ` + pq.QuoteIdentifier(t) + ` CASCADE` + query = `DROP TABLE IF EXISTS ` + p.quoteIdentifier(t) + ` CASCADE` if _, err := p.conn.ExecContext(context.Background(), query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } @@ -430,10 +450,27 @@ func (p *Postgres) ensureVersionTable() (err error) { // users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the // `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission. // Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258 - var count int - query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` - row := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable) + var row *sql.Row + tableName := []byte(p.config.MigrationsTable) + schemaName := []byte("") + if p.config.MigrationsTableQuoted { + re := regexp.MustCompile(`"(.*?)"`) + result := re.FindAllSubmatch([]byte(p.config.MigrationsTable), -1) + tableName = result[len(result)-1][1] + if len(result) > 1 { + schemaName = result[0][1] + } + } + var query string + if len(schemaName) > 0 { + query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = $2 LIMIT 1` + row = p.conn.QueryRowContext(context.Background(), query, tableName, schemaName) + } else { + query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` + row = p.conn.QueryRowContext(context.Background(), query, tableName) + } + var count int err = row.Scan(&count) if err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} @@ -443,7 +480,7 @@ func (p *Postgres) ensureVersionTable() (err error) { return nil } - query = `CREATE TABLE IF NOT EXISTS ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)` + query = `CREATE TABLE IF NOT EXISTS ` + p.quoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)` if _, err = p.conn.ExecContext(context.Background(), query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } diff --git a/database/postgres/postgres_test.go b/database/postgres/postgres_test.go index c6cc9e7f8..46829248f 100644 --- a/database/postgres/postgres_test.go +++ b/database/postgres/postgres_test.go @@ -8,7 +8,6 @@ import ( sqldriver "database/sql/driver" "errors" "fmt" - "github.com/golang-migrate/migrate/v4" "io" "log" "strconv" @@ -16,6 +15,8 @@ import ( "sync" "testing" + "github.com/golang-migrate/migrate/v4" + "github.com/dhui/dktest" "github.com/golang-migrate/migrate/v4/database" @@ -318,6 +319,65 @@ func TestWithSchema(t *testing.T) { }) } +func TestMigrationTableOption(t *testing.T) { + dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { + ip, port, err := c.FirstPort() + if err != nil { + t.Fatal(err) + } + + addr := pgConnectionString(ip, port) + p := &Postgres{} + d, _ := p.Open(addr) + defer func() { + if err := d.Close(); err != nil { + t.Fatal(err) + } + }() + + // create gomigrate schema + if err := d.Run(strings.NewReader("CREATE SCHEMA gomigrate AUTHORIZATION postgres")); err != nil { + t.Fatal(err) + } + + // bad unquoted x-migrations-table parameter + d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=gomigrate.schema_migrations&x-migrations-table-quoted=1", + pgPassword, ip, port)) + if err == nil { + t.Fatal("expected x-migrations-table must be quoted...") + } + + // good quoted x-migrations-table parameter + d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"gomigrate\".\"schema_migrations\"&x-migrations-table-quoted=1", + pgPassword, ip, port)) + if err != nil { + t.Fatal(err) + } + + // make sure gomigrate.schema_migrations table exists + var exists bool + if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'schema_migrations' AND table_schema = 'gomigrate')").Scan(&exists); err != nil { + t.Fatal(err) + } + if !exists { + t.Fatalf("expected table gomigrate.schema_migrations to exist") + } + + d, err = p.Open(fmt.Sprintf("postgres://postgres:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=gomigrate.schema_migrations", + pgPassword, ip, port)) + if err != nil { + t.Fatal(err) + } + if err := d.(*Postgres).conn.QueryRowContext(context.Background(), "SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'gomigrate.schema_migrations' AND table_schema = (SELECT current_schema()))").Scan(&exists); err != nil { + t.Fatal(err) + } + if !exists { + t.Fatalf("expected table 'gomigrate.schema_migrations' to exist") + } + + }) +} + func TestFailToCreateTableWithoutPermissions(t *testing.T) { dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) { ip, port, err := c.FirstPort() @@ -373,6 +433,18 @@ func TestFailToCreateTableWithoutPermissions(t *testing.T) { if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") { t.Fatal(e) } + + // re-connect using that x-migrations-table and x-migrations-table-quoted + d2, err = p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&x-migrations-table=\"barfoo\".\"schema_migrations\"&x-migrations-table-quoted=1", + pgPassword, ip, port)) + + if !errors.As(err, &e) || err == nil { + t.Fatal("Unexpected error, want permission denied error. Got: ", err) + } + + if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") { + t.Fatal(e) + } }) } @@ -683,3 +755,40 @@ func Test_computeLineFromPos(t *testing.T) { } } + +func Test_quoteIdentifier(t *testing.T) { + testcases := []struct { + migrationsTableQuoted bool + migrationsTable string + want string + }{ + { + false, + "schema_name.table_name", + "\"schema_name.table_name\"", + }, + { + false, + "table_name", + "\"table_name\"", + }, + { + true, + "\"schema_name\".\"table.name\"", + "\"schema_name\".\"table.name\"", + }, + } + p := &Postgres{ + config: &Config{ + MigrationsTableQuoted: false, + }, + } + + for _, tc := range testcases { + p.config.MigrationsTableQuoted = tc.migrationsTableQuoted + got := p.quoteIdentifier(tc.migrationsTable) + if tc.want != got { + t.Fatalf("expected %s but got %s", tc.want, got) + } + } +}