Skip to content

Commit

Permalink
Postgres: Add a check to determine if table already exists to elide C…
Browse files Browse the repository at this point in the history
…REATE query (#526)

* Squash commits

* Format

* Minor refactoring

* Address PR feedback; Add mustRun

* Fix a test assert
  • Loading branch information
testtest959 authored Mar 19, 2021
1 parent 63aff4b commit 511ae9f
Show file tree
Hide file tree
Showing 4 changed files with 325 additions and 12 deletions.
19 changes: 18 additions & 1 deletion database/pgx/pgx.go
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,24 @@ func (p *Postgres) ensureVersionTable() (err error) {
}
}()

query := `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)`
// This block checks whether the `MigrationsTable` already exists. This is useful because it allows read only postgres
// users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the
// `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission.
// Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258
var count int
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
row := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable)

err = row.Scan(&count)
if err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}

if count == 1 {
return nil
}

query = `CREATE TABLE IF NOT EXISTS ` + quoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)`
if _, err = p.conn.ExecContext(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
Expand Down
148 changes: 144 additions & 4 deletions database/pgx/pgx_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"context"
"database/sql"
sqldriver "database/sql/driver"
"errors"
"fmt"
"log"

Expand Down Expand Up @@ -76,6 +77,14 @@ func isReady(ctx context.Context, c dktest.ContainerInfo) bool {
return true
}

func mustRun(t *testing.T, d database.Driver, statements []string) {
for _, statement := range statements {
if err := d.Run(strings.NewReader(statement)); err != nil {
t.Fatal(err)
}
}
}

func Test(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
Expand Down Expand Up @@ -309,6 +318,141 @@ func TestWithSchema(t *testing.T) {
})
}

func TestFailToCreateTableWithoutPermissions(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}

addr := pgConnectionString(ip, port)

// Check that opening the postgres connection returns NilVersion
p := &Postgres{}

d, err := p.Open(addr)

if err != nil {
t.Fatal(err)
}

defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()

// create user who is not the owner. Although we're concatenating strings in an sql statement it should be fine
// since this is a test environment and we're not expecting to the pgPassword to be malicious
mustRun(t, d, []string{
"CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'",
"CREATE SCHEMA barfoo AUTHORIZATION postgres",
"GRANT USAGE ON SCHEMA barfoo TO not_owner",
"REVOKE CREATE ON SCHEMA barfoo FROM PUBLIC",
"REVOKE CREATE ON SCHEMA barfoo FROM not_owner",
})

// re-connect using that schema
d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
pgPassword, ip, port))

defer func() {
if d2 == nil {
return
}
if err := d2.Close(); err != nil {
t.Fatal(err)
}
}()

var e *database.Error
if !errors.As(err, &e) || err == nil {
t.Fatal("Unexpected error, want permission denied error. Got: ", err)
}

if !strings.Contains(e.OrigErr.Error(), "permission denied for schema barfoo") {
t.Fatal(e)
}
})
}

func TestCheckBeforeCreateTable(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
if err != nil {
t.Fatal(err)
}

addr := pgConnectionString(ip, port)

// Check that opening the postgres connection returns NilVersion
p := &Postgres{}

d, err := p.Open(addr)

if err != nil {
t.Fatal(err)
}

defer func() {
if err := d.Close(); err != nil {
t.Error(err)
}
}()

// create user who is not the owner. Although we're concatenating strings in an sql statement it should be fine
// since this is a test environment and we're not expecting to the pgPassword to be malicious
mustRun(t, d, []string{
"CREATE USER not_owner WITH ENCRYPTED PASSWORD '" + pgPassword + "'",
"CREATE SCHEMA barfoo AUTHORIZATION postgres",
"GRANT USAGE ON SCHEMA barfoo TO not_owner",
"GRANT CREATE ON SCHEMA barfoo TO not_owner",
})

// re-connect using that schema
d2, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
pgPassword, ip, port))

if err != nil {
t.Fatal(err)
}

if err := d2.Close(); err != nil {
t.Fatal(err)
}

// revoke privileges
mustRun(t, d, []string{
"REVOKE CREATE ON SCHEMA barfoo FROM PUBLIC",
"REVOKE CREATE ON SCHEMA barfoo FROM not_owner",
})

// re-connect using that schema
d3, err := p.Open(fmt.Sprintf("postgres://not_owner:%s@%v:%v/postgres?sslmode=disable&search_path=barfoo",
pgPassword, ip, port))

if err != nil {
t.Fatal(err)
}

version, _, err := d3.Version()

if err != nil {
t.Fatal(err)
}

if version != database.NilVersion {
t.Fatal("Unexpected version, want database.NilVersion. Got: ", version)
}

defer func() {
if err := d3.Close(); err != nil {
t.Fatal(err)
}
}()
})
}

func TestParallelSchema(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
Expand Down Expand Up @@ -375,10 +519,6 @@ func TestParallelSchema(t *testing.T) {
})
}

func TestWithInstance(t *testing.T) {

}

func TestPostgres_Lock(t *testing.T) {
dktesting.ParallelTest(t, specs, func(t *testing.T, c dktest.ContainerInfo) {
ip, port, err := c.FirstPort()
Expand Down
19 changes: 18 additions & 1 deletion database/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,24 @@ func (p *Postgres) ensureVersionTable() (err error) {
}
}()

query := `CREATE TABLE IF NOT EXISTS ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)`
// This block checks whether the `MigrationsTable` already exists. This is useful because it allows read only postgres
// users to also check the current version of the schema. Previously, even if `MigrationsTable` existed, the
// `CREATE TABLE IF NOT EXISTS...` query would fail because the user does not have the CREATE permission.
// Taken from https://github.com/mattes/migrate/blob/master/database/postgres/postgres.go#L258
var count int
query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1`
row := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable)

err = row.Scan(&count)
if err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}

if count == 1 {
return nil
}

query = `CREATE TABLE IF NOT EXISTS ` + pq.QuoteIdentifier(p.config.MigrationsTable) + ` (version bigint not null primary key, dirty boolean not null)`
if _, err = p.conn.ExecContext(context.Background(), query); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
Expand Down
Loading

0 comments on commit 511ae9f

Please sign in to comment.