Skip to content

Commit

Permalink
Duplicate foreign key constraints
Browse files Browse the repository at this point in the history
Ensure that foreign key constraints, including multi-column constraints,
are duplicated correctly.
  • Loading branch information
andrew-farries committed Nov 22, 2024
1 parent d12406a commit 9fe8d9b
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 32 deletions.
64 changes: 32 additions & 32 deletions pkg/migrations/duplicate.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ const (
cCreateUniqueIndexSQL = `CREATE UNIQUE INDEX CONCURRENTLY %s ON %s (%s)`
cSetDefaultSQL = `ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s`
cAlterTableAddCheckConstraintSQL = `ALTER TABLE %s ADD CONSTRAINT %s %s NOT VALID`
cAlterTableAddForeignKeySQL = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s) ON DELETE %s`
)

// NewColumnDuplicator creates a new Duplicator for a column.
Expand Down Expand Up @@ -91,7 +92,6 @@ func (d *Duplicator) Duplicate(ctx context.Context) error {
colNames = append(colNames, name)

// Duplicate the column with the new type
// and check and fk constraints
if sql := d.stmtBuilder.duplicateColumn(c.column, c.asName, c.withoutNotNull, c.withType, d.withoutConstraint); sql != "" {
_, err := d.conn.ExecContext(ctx, sql)
if err != nil {
Expand All @@ -108,6 +108,7 @@ func (d *Duplicator) Duplicate(ctx context.Context) error {
}
}

// Duplicate the column's comment
if sql := d.stmtBuilder.duplicateComment(c.column, c.asName); sql != "" {
_, err := d.conn.ExecContext(ctx, sql)
if err != nil {
Expand All @@ -120,7 +121,6 @@ func (d *Duplicator) Duplicate(ctx context.Context) error {
// if the check constraint is not valid for the new column type, in which case
// the error is ignored.
for _, sql := range d.stmtBuilder.duplicateCheckConstraints(d.withoutConstraint, colNames...) {
// Update the check constraint expression to use the new column names if any of the columns are duplicated
_, err := d.conn.ExecContext(ctx, sql)
err = errorIgnoringErrorCode(err, undefinedFunctionErrorCode)
if err != nil {
Expand All @@ -132,12 +132,21 @@ func (d *Duplicator) Duplicate(ctx context.Context) error {
// The constraint is duplicated by adding a unique index on the column concurrently.
// The index is converted into a unique constraint on migration completion.
for _, sql := range d.stmtBuilder.duplicateUniqueConstraints(d.withoutConstraint, colNames...) {
// Update the unique constraint columns to use the new column names if any of the columns are duplicated
if _, err := d.conn.ExecContext(ctx, sql); err != nil {
return err
}
}

// Generate SQL to duplicate any foreign key constraints on the columns.
// If the foreign key constraint is not valid for a new column type, the error is ignored.
for _, sql := range d.stmtBuilder.duplicateForeignKeyConstraints(d.withoutConstraint, colNames...) {
_, err := d.conn.ExecContext(ctx, sql)
err = errorIgnoringErrorCode(err, dataTypeMismatchErrorCode)
if err != nil {
return err
}
}

return nil
}

Expand Down Expand Up @@ -175,6 +184,26 @@ func (d *duplicatorStmtBuilder) duplicateUniqueConstraints(withoutConstraint []s
return stmts
}

func (d *duplicatorStmtBuilder) duplicateForeignKeyConstraints(withoutConstraint []string, colNames ...string) []string {
stmts := make([]string, 0, len(d.table.ForeignKeys))
for _, fk := range d.table.ForeignKeys {
if slices.Contains(withoutConstraint, fk.Name) {
continue
}
if duplicatedMember, constraintColumns := d.allConstraintColumns(fk.Columns, colNames...); duplicatedMember {
stmts = append(stmts, fmt.Sprintf(cAlterTableAddForeignKeySQL,
pq.QuoteIdentifier(d.table.Name),
pq.QuoteIdentifier(DuplicationName(fk.Name)),
strings.Join(quoteColumnNames(constraintColumns), ", "),
pq.QuoteIdentifier(fk.ReferencedTable),
strings.Join(quoteColumnNames(fk.ReferencedColumns), ", "),
fk.OnDelete,
))
}
}
return stmts
}

// duplicatedConstraintColumns returns a new slice of constraint columns with
// the columns that are duplicated replaced with temporary names.
func (d *duplicatorStmtBuilder) duplicatedConstraintColumns(constraintColumns []string, duplicatedColumns ...string) []string {
Expand Down Expand Up @@ -213,7 +242,6 @@ func (d *duplicatorStmtBuilder) duplicateColumn(
) string {
const (
cAlterTableSQL = `ALTER TABLE %s ADD COLUMN %s %s`
cAddForeignKeySQL = `ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s) ON DELETE %s`
cAddCheckConstraintSQL = `ADD CONSTRAINT %s %s NOT VALID`
)

Expand All @@ -232,23 +260,6 @@ func (d *duplicatorStmtBuilder) duplicateColumn(
)
}

// Generate SQL to duplicate any foreign key constraints on the column
for _, fk := range d.table.ForeignKeys {
if slices.Contains(withoutConstraint, fk.Name) {
continue
}

if slices.Contains(fk.Columns, column.Name) {
sql += fmt.Sprintf(", "+cAddForeignKeySQL,
pq.QuoteIdentifier(DuplicationName(fk.Name)),
strings.Join(quoteColumnNames(copyAndReplace(fk.Columns, column.Name, asName)), ", "),
pq.QuoteIdentifier(fk.ReferencedTable),
strings.Join(quoteColumnNames(fk.ReferencedColumns), ", "),
fk.OnDelete,
)
}
}

return sql
}

Expand Down Expand Up @@ -295,17 +306,6 @@ func StripDuplicationPrefix(name string) string {
return strings.TrimPrefix(name, "_pgroll_dup_")
}

func copyAndReplace(xs []string, oldValue, newValue string) []string {
ys := slices.Clone(xs)

for i, c := range ys {
if c == oldValue {
ys[i] = newValue
}
}
return ys
}

func errorIgnoringErrorCode(err error, code pq.ErrorCode) error {
pqErr := &pq.Error{}
if ok := errors.As(err, &pqErr); ok {
Expand Down
53 changes: 53 additions & 0 deletions pkg/migrations/duplicate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ var table = &schema.Table{
"new_york_adults": {Name: "new_york_adults", Columns: []string{"city", "age"}, Definition: `"city" = 'New York' AND "age" > 21`},
"different_nick": {Name: "different_nick", Columns: []string{"name", "nick"}, Definition: `"name" != "nick"`},
},
ForeignKeys: map[string]schema.ForeignKey{
"fk_city": {Name: "fk_city", Columns: []string{"city"}, ReferencedTable: "cities", ReferencedColumns: []string{"id"}, OnDelete: "NO ACTION"},
"fk_name_nick": {Name: "fk_name_nick", Columns: []string{"name", "nick"}, ReferencedTable: "users", ReferencedColumns: []string{"name", "nick"}, OnDelete: "CASCADE"},
},
}

func TestDuplicateStmtBuilderCheckConstraints(t *testing.T) {
Expand Down Expand Up @@ -121,3 +125,52 @@ func TestDuplicateStmtBuilderUniqueConstraints(t *testing.T) {
})
}
}

func TestDuplicateStmtBuilderForeignKeyConstraints(t *testing.T) {
d := &duplicatorStmtBuilder{table}
for name, testCases := range map[string]struct {
columns []string
expectedStmts []string
}{
"duplicate single column with no FK constraint": {
columns: []string{"description"},
expectedStmts: []string{},
},
"single-column FK with single column duplicated": {
columns: []string{"city"},
expectedStmts: []string{
`ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_fk_city" FOREIGN KEY ("_pgroll_new_city") REFERENCES "cities" ("id") ON DELETE NO ACTION`,
},
},
"single-column FK with multiple columns duplicated": {
columns: []string{"city", "description"},
expectedStmts: []string{
`ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_fk_city" FOREIGN KEY ("_pgroll_new_city") REFERENCES "cities" ("id") ON DELETE NO ACTION`,
},
},
"multi-column FK with single column duplicated": {
columns: []string{"name"},
expectedStmts: []string{
`ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_fk_name_nick" FOREIGN KEY ("_pgroll_new_name", "nick") REFERENCES "users" ("name", "nick") ON DELETE CASCADE`,
},
},
"multi-column FK with multiple unrelated column duplicated": {
columns: []string{"name", "description"},
expectedStmts: []string{
`ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_fk_name_nick" FOREIGN KEY ("_pgroll_new_name", "nick") REFERENCES "users" ("name", "nick") ON DELETE CASCADE`,
},
},
"multi-column FK with multiple columns": {
columns: []string{"name", "nick"},
expectedStmts: []string{`ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_fk_name_nick" FOREIGN KEY ("_pgroll_new_name", "_pgroll_new_nick") REFERENCES "users" ("name", "nick") ON DELETE CASCADE`},
},
} {
t.Run(name, func(t *testing.T) {
stmts := d.duplicateForeignKeyConstraints(nil, testCases.columns...)
assert.Equal(t, len(testCases.expectedStmts), len(stmts))
for _, stmt := range stmts {
assert.Contains(t, testCases.expectedStmts, stmt)
}
})
}
}

0 comments on commit 9fe8d9b

Please sign in to comment.