Skip to content

Commit

Permalink
feat: overriding system value for pg identity column (#183)
Browse files Browse the repository at this point in the history
  • Loading branch information
amakmurr authored May 25, 2024
1 parent 1df598f commit e867d65
Show file tree
Hide file tree
Showing 7 changed files with 171 additions and 6 deletions.
3 changes: 2 additions & 1 deletion cockroachdb_test.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
//go:build cockroachdb
// +build cockroachdb

package testfixtures
Expand All @@ -16,7 +17,7 @@ func TestCockroachDB(t *testing.T) {
t,
dialect,
os.Getenv("CRDB_CONN_STRING"),
"testdata/schema/postgresql.sql",
"testdata/schema/cockroachdb.sql",
DangerousSkipTestDatabaseCheck(),
UseDropConstraint(),
)
Expand Down
11 changes: 11 additions & 0 deletions helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package testfixtures
import (
"database/sql"
"fmt"
"strings"
)

const (
Expand All @@ -24,6 +25,7 @@ type helper interface {
quoteKeyword(string) string
whileInsertOnTable(*sql.Tx, string, func() error) error
cleanTableQuery(string) string
buildInsertSQL(q queryable, tableName string, columns, values []string) (string, error)
}

type queryable interface {
Expand Down Expand Up @@ -75,3 +77,12 @@ func (baseHelper) afterLoad(_ queryable) error {
func (baseHelper) cleanTableQuery(tableName string) string {
return fmt.Sprintf("DELETE FROM %s", tableName)
}

func (h baseHelper) buildInsertSQL(_ queryable, tableName string, columns, values []string) (string, error) {
return fmt.Sprintf(
"INSERT INTO %s (%s) VALUES (%s)",
tableName,
strings.Join(columns, ", "),
strings.Join(values, ", "),
), nil
}
4 changes: 4 additions & 0 deletions mocks_for_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ func (h *MockHelper) cleanTableQuery(string) string {
return ""
}

func (h *MockHelper) buildInsertSQL(queryable, string, []string, []string) (string, error) {
return "", nil
}

// NewMockHelper returns MockHelper
func NewMockHelper(dbName string) *MockHelper {
return &MockHelper{dbName: dbName}
Expand Down
91 changes: 91 additions & 0 deletions postgresql.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ package testfixtures
import (
"database/sql"
"fmt"
"regexp"
"strconv"
"strings"
"sync"
)

type postgreSQL struct {
Expand All @@ -19,6 +22,10 @@ type postgreSQL struct {
nonDeferrableConstraints []pgConstraint
constraints []pgConstraint
tablesChecksum map[string]string

version int
tablesHasIdentityColumnMutex sync.Mutex
tablesHasIdentityColumn map[string]bool
}

type pgConstraint struct {
Expand Down Expand Up @@ -50,6 +57,13 @@ func (h *postgreSQL) init(db *sql.DB) error {
return err
}

h.version, err = h.getMajorVersion(db)
if err != nil {
return err
}

h.tablesHasIdentityColumn = make(map[string]bool)

return nil
}

Expand Down Expand Up @@ -383,3 +397,80 @@ func (*postgreSQL) quoteKeyword(s string) string {
}
return strings.Join(parts, ".")
}

func (h *postgreSQL) buildInsertSQL(q queryable, tableName string, columns, values []string) (string, error) {
if h.version >= 10 {
ok, err := h.tableHasIdentityColumn(q, tableName)
if err != nil {
return "", err
}
if ok {
return fmt.Sprintf(
"INSERT INTO %s (%s) OVERRIDING SYSTEM VALUE VALUES (%s)",
tableName,
strings.Join(columns, ", "),
strings.Join(values, ", "),
), nil
}
}

return fmt.Sprintf(
"INSERT INTO %s (%s) VALUES (%s)",
tableName,
strings.Join(columns, ", "),
strings.Join(values, ", "),
), nil
}

func (h *postgreSQL) tableHasIdentityColumn(q queryable, tableName string) (bool, error) {
defer h.tablesHasIdentityColumnMutex.Unlock()
h.tablesHasIdentityColumnMutex.Lock()

hasIdentityColumn, exists := h.tablesHasIdentityColumn[tableName]
if exists {
return hasIdentityColumn, nil
}

parts := strings.Split(tableName, ".")
tableName = parts[0][1 : len(parts[0])-1]
if len(parts) > 1 {
tableName = parts[1][1 : len(parts[1])-1]
}

query := fmt.Sprintf(`
SELECT COUNT(*) AS count
FROM information_schema.columns
WHERE table_name = '%s' AND is_identity = 'YES'
`, tableName)
var count int
if err := q.QueryRow(query).Scan(&count); err != nil {
return false, err
}

h.tablesHasIdentityColumn[tableName] = count > 0
return h.tablesHasIdentityColumn[tableName], nil
}

func (h *postgreSQL) getMajorVersion(q queryable) (int, error) {
var version string
err := q.QueryRow("SELECT VERSION()").Scan(&version)
if err != nil {
return 0, err
}

return h.parseMajorVersion(version)
}

func (*postgreSQL) parseMajorVersion(version string) (int, error) {
re := regexp.MustCompile(`\d+`)
versionNumbers := re.FindAllString(version, -1)
if len(versionNumbers) > 0 {
majorVersion, err := strconv.Atoi(versionNumbers[0])
if err != nil {
return 0, err
}
return majorVersion, nil
}

return 0, fmt.Errorf("testfixtures: could not parse major version from: %s", version)
}
59 changes: 59 additions & 0 deletions testdata/schema/cockroachdb.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
DROP TABLE IF EXISTS votes;
DROP TABLE IF EXISTS comments;
DROP TABLE IF EXISTS posts_tags;
DROP TABLE IF EXISTS posts;
DROP TABLE IF EXISTS tags;
DROP TABLE IF EXISTS users;
DROP TABLE IF EXISTS assets;

CREATE TABLE posts (
id SERIAL PRIMARY KEY
,title VARCHAR(255) NOT NULL
,content TEXT NOT NULL
,created_at TIMESTAMP NOT NULL
,updated_at TIMESTAMP NOT NULL
);

CREATE TABLE tags (
id SERIAL PRIMARY KEY
,name VARCHAR(255) NOT NULL
,created_at TIMESTAMP NOT NULL
,updated_at TIMESTAMP NOT NULL
);

CREATE TABLE posts_tags (
post_id INTEGER NOT NULL
,tag_id INTEGER NOT NULL
,PRIMARY KEY (post_id, tag_id)
,FOREIGN KEY (post_id) REFERENCES posts (id) ON DELETE CASCADE
,FOREIGN KEY (tag_id) REFERENCES tags (id) ON DELETE CASCADE
);

CREATE TABLE comments (
id SERIAL PRIMARY KEY NOT NULL
,post_id INTEGER NOT NULL
,author_name VARCHAR(255) NOT NULL
,author_email VARCHAR(255) NOT NULL
,content TEXT NOT NULL
,created_at TIMESTAMP NOT NULL
,updated_at TIMESTAMP NOT NULL
,FOREIGN KEY (post_id) REFERENCES posts (id) ON DELETE CASCADE
);

CREATE TABLE votes (
id SERIAL PRIMARY KEY NOT NULL
,comment_id INTEGER NOT NULL
,created_at TIMESTAMP NOT NULL
,updated_at TIMESTAMP NOT NULL
,FOREIGN KEY (comment_id) REFERENCES comments (id) ON DELETE CASCADE
);

CREATE TABLE users (
id SERIAL PRIMARY KEY NOT NULL
,attributes JSONB NOT NULL
);

CREATE TABLE assets (
id SERIAL PRIMARY KEY NOT NULL
,data BYTEA NOT NULL
);
2 changes: 1 addition & 1 deletion testdata/schema/postgresql.sql
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ DROP TABLE IF EXISTS users;
DROP TABLE IF EXISTS assets;

CREATE TABLE posts (
id SERIAL PRIMARY KEY
id INT GENERATED ALWAYS AS IDENTITY PRIMARY KEY
,title VARCHAR(255) NOT NULL
,content TEXT NOT NULL
,created_at TIMESTAMP NOT NULL
Expand Down
7 changes: 3 additions & 4 deletions testfixtures.go
Original file line number Diff line number Diff line change
Expand Up @@ -584,11 +584,10 @@ func (l *Loader) buildInsertSQL(f *fixtureFile, record map[string]interface{}) (
i++
}

sqlStr = fmt.Sprintf(
"INSERT INTO %s (%s) VALUES (%s)",
sqlStr, err = l.helper.buildInsertSQL(
l.db,
l.helper.quoteKeyword(f.fileNameWithoutExtension()),
strings.Join(sqlColumns, ", "),
strings.Join(sqlValues, ", "),
sqlColumns, sqlValues,
)
return
}
Expand Down

0 comments on commit e867d65

Please sign in to comment.