From 820ec91404a58dde22e6cf8e0e2de89557df61fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phane=20Klein?= Date: Sun, 29 Nov 2020 15:22:47 +0100 Subject: [PATCH] Postgres - Add schema name support in x-migrations-table value (#95) Add x-migrations-table-has-schema to enable schema name support in x-migrations-table value. --- database/postgres/README.md | 33 ++++++++++++----------- database/postgres/postgres.go | 43 ++++++++++++++++++++++-------- database/postgres/postgres_test.go | 41 ++++++++++++++++++++++++++++ 3 files changed, 90 insertions(+), 27 deletions(-) diff --git a/database/postgres/README.md b/database/postgres/README.md index 1527ea354..8877ce285 100644 --- a/database/postgres/README.md +++ b/database/postgres/README.md @@ -2,22 +2,23 @@ `postgres://user:password@host:port/dbname?query` (`postgresql://` works, too) -| URL Query | WithInstance Config | Description | -|------------|---------------------|-------------| -| `x-migrations-table` | `MigrationsTable` | Name of the migrations table | -| `x-statement-timeout` | `StatementTimeout` | Abort any statement that takes more than the specified number of milliseconds | -| `dbname` | `DatabaseName` | The name of the database to connect to | -| `search_path` | | This variable specifies the order in which schemas are searched when an object is referenced by a simple name with no schema specified. | -| `user` | | The user to sign in as | -| `password` | | The user's password | -| `host` | | The host to connect to. Values that start with / are for unix domain sockets. (default is localhost) | -| `port` | | The port to bind to. (default is 5432) | -| `fallback_application_name` | | An application_name to fall back to if one isn't provided. | -| `connect_timeout` | | Maximum wait for connection, in seconds. Zero or not specified means wait indefinitely. | -| `sslcert` | | Cert file location. The file must contain PEM encoded data. | -| `sslkey` | | Key file location. The file must contain PEM encoded data. | -| `sslrootcert` | | The location of the root certificate file. The file must contain PEM encoded data. | -| `sslmode` | | Whether or not to use SSL (disable\|require\|verify-ca\|verify-full) | +| URL Query | WithInstance Config | Description | Default value | +|------------|---------------------|-------------|---------------| +| `x-migrations-table-has-schema` | `MigrationsTableHasSchema` | Enable schema name support in `x-migrations-table` parameter | `false` | +| `x-migrations-table` | `MigrationsTable` | Name of the migrations table, if `x-migrations-table-has-schema` is enabled then the first dot is treated as the schema and table name separator, for instance `gomigrate.schema_migrations` | `schema_migrations` | +| `x-statement-timeout` | `StatementTimeout` | Abort any statement that takes more than the specified number of milliseconds | `0` | +| `dbname` | `DatabaseName` | The name of the database to connect to | `SELECT CURRENT_DATABASE()` result | +| `search_path` | | This variable specifies the order in which schemas are searched when an object is referenced by a simple name with no schema specified. | `SHOW search_path` result | +| `user` | | The user to sign in as | | +| `password` | | The user's password | | +| `host` | | The host to connect to. Values that start with / are for unix domain sockets | `localhost` | +| `port` | | The port to bind to| `5432` | +| `fallback_application_name` | | An application_name to fall back to if one isn't provided. | | +| `connect_timeout` | | Maximum wait for connection, in seconds. Zero or not specified means wait indefinitely. | | +| `sslcert` | | Cert file location. The file must contain PEM encoded data. | | +| `sslkey` | | Key file location. The file must contain PEM encoded data. | | +| `sslrootcert` | | The location of the root certificate file. The file must contain PEM encoded data. | | +| `sslmode` | | Whether or not to use SSL (disable\|require\|verify-ca\|verify-full) | | ## Upgrading from v1 diff --git a/database/postgres/postgres.go b/database/postgres/postgres.go index f111f36ae..40585e4c7 100644 --- a/database/postgres/postgres.go +++ b/database/postgres/postgres.go @@ -35,10 +35,11 @@ var ( ) type Config struct { - MigrationsTable string - DatabaseName string - SchemaName string - StatementTimeout time.Duration + MigrationsTableHasSchema bool + MigrationsTable string + DatabaseName string + SchemaName string + StatementTimeout time.Duration } type Postgres struct { @@ -124,6 +125,7 @@ func (p *Postgres) Open(url string) (database.Driver, error) { migrationsTable := purl.Query().Get("x-migrations-table") statementTimeoutString := purl.Query().Get("x-statement-timeout") + migrationsTableHasSchemaString := purl.Query().Get("x-migrations-table-has-schema") statementTimeout := 0 if statementTimeoutString != "" { statementTimeout, err = strconv.Atoi(statementTimeoutString) @@ -132,10 +134,19 @@ func (p *Postgres) Open(url string) (database.Driver, error) { } } + migrationsTableHasSchema := false + if migrationsTableHasSchemaString != "" { + migrationsTableHasSchema, err = strconv.ParseBool(migrationsTableHasSchemaString) + if err != nil { + return nil, err + } + } + px, err := WithInstance(db, &Config{ - DatabaseName: purl.Path, - MigrationsTable: migrationsTable, - StatementTimeout: time.Duration(statementTimeout) * time.Millisecond, + DatabaseName: purl.Path, + MigrationsTableHasSchema: migrationsTableHasSchema, + MigrationsTable: migrationsTable, + StatementTimeout: time.Duration(statementTimeout) * time.Millisecond, }) if err != nil { @@ -266,13 +277,23 @@ func runesLastIndex(input []rune, target rune) int { return -1 } +func (p *Postgres) quoteIdentifier(name string) string { + if p.config.MigrationsTableHasSchema { + firstDotPosition := strings.Index(name, ".") + if firstDotPosition != -1 { + return pq.QuoteIdentifier(name[0:firstDotPosition]) + "." + pq.QuoteIdentifier(name[firstDotPosition+1:]) + } + } + return pq.QuoteIdentifier(name) +} + func (p *Postgres) SetVersion(version int, dirty bool) error { tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{}) if err != nil { return &database.Error{OrigErr: err, Err: "transaction start failed"} } - query := `TRUNCATE ` + pq.QuoteIdentifier(p.config.MigrationsTable) + query := `TRUNCATE ` + p.quoteIdentifier(p.config.MigrationsTable) if _, err := tx.Exec(query); err != nil { if errRollback := tx.Rollback(); errRollback != nil { err = multierror.Append(err, errRollback) @@ -284,7 +305,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error { // empty schema version for failed down migration on the first migration // See: https://github.com/golang-migrate/migrate/issues/330 if version >= 0 || (version == database.NilVersion && dirty) { - query = `INSERT INTO ` + pq.QuoteIdentifier(p.config.MigrationsTable) + + query = `INSERT INTO ` + p.quoteIdentifier(p.config.MigrationsTable) + ` (version, dirty) VALUES ($1, $2)` if _, err := tx.Exec(query, version, dirty); err != nil { if errRollback := tx.Rollback(); errRollback != nil { @@ -302,7 +323,7 @@ func (p *Postgres) SetVersion(version int, dirty bool) error { } func (p *Postgres) Version() (version int, dirty bool, err error) { - query := `SELECT version, dirty FROM ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` LIMIT 1` + query := `SELECT version, dirty FROM ` + p.quoteIdentifier(p.config.MigrationsTable) + ` LIMIT 1` err = p.conn.QueryRowContext(context.Background(), query).Scan(&version, &dirty) switch { case err == sql.ErrNoRows: @@ -380,7 +401,7 @@ func (p *Postgres) ensureVersionTable() (err error) { } }() - query := `CREATE TABLE IF NOT EXISTS ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)` + query := `CREATE TABLE IF NOT EXISTS ` + p.quoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)` if _, err = p.conn.ExecContext(context.Background(), query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } diff --git a/database/postgres/postgres_test.go b/database/postgres/postgres_test.go index ecde45171..1868eaa17 100644 --- a/database/postgres/postgres_test.go +++ b/database/postgres/postgres_test.go @@ -508,5 +508,46 @@ func Test_computeLineFromPos(t *testing.T) { run(true, true) }) } +} +func Test_quoteIdentifier(t *testing.T) { + testcases := []struct { + migrationsTableHasSchema bool + name string + want string + }{ + { + true, + "schema_name.table_name", + "\"schema_name\".\"table_name\"", + }, + { + true, + "schema_name.table.name", + "\"schema_name\".\"table.name\"", + }, + { + false, + "table_name", + "\"table_name\"", + }, + { + false, + "table.name", + "\"table.name\"", + }, + } + p := &Postgres{ + config: &Config{ + MigrationsTableHasSchema: true, + }, + } + + for _, tc := range testcases { + p.config.MigrationsTableHasSchema = tc.migrationsTableHasSchema + got := p.quoteIdentifier(tc.name) + if tc.want != got { + t.Fatalf("expected %s but got %s", tc.want, got) + } + } }