Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for creating foreign key constraints using create_constraint #471

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions docs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1102,7 +1102,7 @@ Example **create table** migrations:

A create constraint operation adds a new constraint to an existing table.

Only `UNIQUE` and `CHECK` constraints are supported.
`UNIQUE`, `CHECK` and `FOREIGN KEY` constraints are supported.

Required fields: `name`, `table`, `type`, `up`, `down`.

Expand All @@ -1114,7 +1114,14 @@ Required fields: `name`, `table`, `type`, `up`, `down`.
"table": "name of table",
"name": "my_unique_constraint",
"columns": ["col1", "col2"],
"type": "unique"
"type": "unique"| "check" | "foreign_key",
"check": "SQL expression for CHECK constraint",
"references": {
"name": "name of foreign key reference",
"table": "name of referenced table",
"columns": "[names of referenced columns]",
"on_delete": "ON DELETE behaviour, can be CASCADE, SET NULL, RESTRICT, or NO ACTION. Default is NO ACTION",
},
"up": {
"col1": "col1 || random()",
"col2": "col2 || random()"
Expand All @@ -1131,7 +1138,7 @@ Example **create constraint** migrations:

* [44_add_table_unique_constraint.json](../examples/44_add_table_unique_constraint.json)
* [45_add_table_check_constraint.json](../examples/45_add_table_check_constraint.json)

* [46_add_table_foreign_key_constraint.json](../examples/46_add_table_foreign_key_constraint.json)

### Drop column

Expand Down
1 change: 1 addition & 0 deletions examples/.ledger
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,4 @@
44_add_table_unique_constraint.json
45_add_table_check_constraint.json
46_alter_column_drop_default.json
47_add_table_foreign_key_constraint.json
31 changes: 31 additions & 0 deletions examples/47_add_table_foreign_key_constraint.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
{
"name": "47_add_table_foreign_key_constraint",
"operations": [
{
"create_constraint": {
"type": "foreign_key",
"table": "tickets",
"name": "fk_sellers",
"columns": [
"sellers_name",
"sellers_zip"
],
"references": {
"table": "sellers",
"columns": [
"name",
"zip"
]
},
"up": {
"sellers_name": "sellers_name",
"sellers_zip": "sellers_zip"
},
"down": {
"sellers_name": "sellers_name",
"sellers_zip": "sellers_zip"
}
}
}
]
}
64 changes: 32 additions & 32 deletions pkg/migrations/duplicate.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ const (
cCreateUniqueIndexSQL = `CREATE UNIQUE INDEX CONCURRENTLY %s ON %s (%s)`
cSetDefaultSQL = `ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s`
cAlterTableAddCheckConstraintSQL = `ALTER TABLE %s ADD CONSTRAINT %s %s NOT VALID`
cAlterTableAddForeignKeySQL = `ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s) ON DELETE %s`
)

// NewColumnDuplicator creates a new Duplicator for a column.
Expand Down Expand Up @@ -91,7 +92,6 @@ func (d *Duplicator) Duplicate(ctx context.Context) error {
colNames = append(colNames, name)

// Duplicate the column with the new type
// and check and fk constraints
if sql := d.stmtBuilder.duplicateColumn(c.column, c.asName, c.withoutNotNull, c.withType, d.withoutConstraint); sql != "" {
_, err := d.conn.ExecContext(ctx, sql)
if err != nil {
Expand All @@ -108,6 +108,7 @@ func (d *Duplicator) Duplicate(ctx context.Context) error {
}
}

// Duplicate the column's comment
if sql := d.stmtBuilder.duplicateComment(c.column, c.asName); sql != "" {
_, err := d.conn.ExecContext(ctx, sql)
if err != nil {
Expand All @@ -120,7 +121,6 @@ func (d *Duplicator) Duplicate(ctx context.Context) error {
// if the check constraint is not valid for the new column type, in which case
// the error is ignored.
for _, sql := range d.stmtBuilder.duplicateCheckConstraints(d.withoutConstraint, colNames...) {
// Update the check constraint expression to use the new column names if any of the columns are duplicated
_, err := d.conn.ExecContext(ctx, sql)
err = errorIgnoringErrorCode(err, undefinedFunctionErrorCode)
if err != nil {
Expand All @@ -132,12 +132,21 @@ func (d *Duplicator) Duplicate(ctx context.Context) error {
// The constraint is duplicated by adding a unique index on the column concurrently.
// The index is converted into a unique constraint on migration completion.
for _, sql := range d.stmtBuilder.duplicateUniqueConstraints(d.withoutConstraint, colNames...) {
// Update the unique constraint columns to use the new column names if any of the columns are duplicated
if _, err := d.conn.ExecContext(ctx, sql); err != nil {
return err
}
}

// Generate SQL to duplicate any foreign key constraints on the columns.
// If the foreign key constraint is not valid for a new column type, the error is ignored.
for _, sql := range d.stmtBuilder.duplicateForeignKeyConstraints(d.withoutConstraint, colNames...) {
_, err := d.conn.ExecContext(ctx, sql)
err = errorIgnoringErrorCode(err, dataTypeMismatchErrorCode)
if err != nil {
return err
}
}

return nil
}

Expand Down Expand Up @@ -175,6 +184,26 @@ func (d *duplicatorStmtBuilder) duplicateUniqueConstraints(withoutConstraint []s
return stmts
}

func (d *duplicatorStmtBuilder) duplicateForeignKeyConstraints(withoutConstraint []string, colNames ...string) []string {
stmts := make([]string, 0, len(d.table.ForeignKeys))
for _, fk := range d.table.ForeignKeys {
if slices.Contains(withoutConstraint, fk.Name) {
continue
}
if duplicatedMember, constraintColumns := d.allConstraintColumns(fk.Columns, colNames...); duplicatedMember {
stmts = append(stmts, fmt.Sprintf(cAlterTableAddForeignKeySQL,
pq.QuoteIdentifier(d.table.Name),
pq.QuoteIdentifier(DuplicationName(fk.Name)),
strings.Join(quoteColumnNames(constraintColumns), ", "),
pq.QuoteIdentifier(fk.ReferencedTable),
strings.Join(quoteColumnNames(fk.ReferencedColumns), ", "),
fk.OnDelete,
))
}
}
return stmts
}

// duplicatedConstraintColumns returns a new slice of constraint columns with
// the columns that are duplicated replaced with temporary names.
func (d *duplicatorStmtBuilder) duplicatedConstraintColumns(constraintColumns []string, duplicatedColumns ...string) []string {
Expand Down Expand Up @@ -213,7 +242,6 @@ func (d *duplicatorStmtBuilder) duplicateColumn(
) string {
const (
cAlterTableSQL = `ALTER TABLE %s ADD COLUMN %s %s`
cAddForeignKeySQL = `ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s) ON DELETE %s`
cAddCheckConstraintSQL = `ADD CONSTRAINT %s %s NOT VALID`
)

Expand All @@ -232,23 +260,6 @@ func (d *duplicatorStmtBuilder) duplicateColumn(
)
}

// Generate SQL to duplicate any foreign key constraints on the column
for _, fk := range d.table.ForeignKeys {
if slices.Contains(withoutConstraint, fk.Name) {
continue
}

if slices.Contains(fk.Columns, column.Name) {
sql += fmt.Sprintf(", "+cAddForeignKeySQL,
pq.QuoteIdentifier(DuplicationName(fk.Name)),
strings.Join(quoteColumnNames(copyAndReplace(fk.Columns, column.Name, asName)), ", "),
pq.QuoteIdentifier(fk.ReferencedTable),
strings.Join(quoteColumnNames(fk.ReferencedColumns), ", "),
fk.OnDelete,
)
}
}

return sql
}

Expand Down Expand Up @@ -295,17 +306,6 @@ func StripDuplicationPrefix(name string) string {
return strings.TrimPrefix(name, "_pgroll_dup_")
}

func copyAndReplace(xs []string, oldValue, newValue string) []string {
ys := slices.Clone(xs)

for i, c := range ys {
if c == oldValue {
ys[i] = newValue
}
}
return ys
}

func errorIgnoringErrorCode(err error, code pq.ErrorCode) error {
pqErr := &pq.Error{}
if ok := errors.As(err, &pqErr); ok {
Expand Down
53 changes: 53 additions & 0 deletions pkg/migrations/duplicate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ var table = &schema.Table{
"new_york_adults": {Name: "new_york_adults", Columns: []string{"city", "age"}, Definition: `"city" = 'New York' AND "age" > 21`},
"different_nick": {Name: "different_nick", Columns: []string{"name", "nick"}, Definition: `"name" != "nick"`},
},
ForeignKeys: map[string]schema.ForeignKey{
"fk_city": {Name: "fk_city", Columns: []string{"city"}, ReferencedTable: "cities", ReferencedColumns: []string{"id"}, OnDelete: "NO ACTION"},
"fk_name_nick": {Name: "fk_name_nick", Columns: []string{"name", "nick"}, ReferencedTable: "users", ReferencedColumns: []string{"name", "nick"}, OnDelete: "CASCADE"},
},
}

func TestDuplicateStmtBuilderCheckConstraints(t *testing.T) {
Expand Down Expand Up @@ -121,3 +125,52 @@ func TestDuplicateStmtBuilderUniqueConstraints(t *testing.T) {
})
}
}

func TestDuplicateStmtBuilderForeignKeyConstraints(t *testing.T) {
d := &duplicatorStmtBuilder{table}
for name, testCases := range map[string]struct {
columns []string
expectedStmts []string
}{
"duplicate single column with no FK constraint": {
columns: []string{"description"},
expectedStmts: []string{},
},
"single-column FK with single column duplicated": {
columns: []string{"city"},
expectedStmts: []string{
`ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_fk_city" FOREIGN KEY ("_pgroll_new_city") REFERENCES "cities" ("id") ON DELETE NO ACTION`,
},
},
"single-column FK with multiple columns duplicated": {
columns: []string{"city", "description"},
expectedStmts: []string{
`ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_fk_city" FOREIGN KEY ("_pgroll_new_city") REFERENCES "cities" ("id") ON DELETE NO ACTION`,
},
},
"multi-column FK with single column duplicated": {
columns: []string{"name"},
expectedStmts: []string{
`ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_fk_name_nick" FOREIGN KEY ("_pgroll_new_name", "nick") REFERENCES "users" ("name", "nick") ON DELETE CASCADE`,
},
},
"multi-column FK with multiple unrelated column duplicated": {
columns: []string{"name", "description"},
expectedStmts: []string{
`ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_fk_name_nick" FOREIGN KEY ("_pgroll_new_name", "nick") REFERENCES "users" ("name", "nick") ON DELETE CASCADE`,
},
},
"multi-column FK with multiple columns": {
columns: []string{"name", "nick"},
expectedStmts: []string{`ALTER TABLE "test_table" ADD CONSTRAINT "_pgroll_dup_fk_name_nick" FOREIGN KEY ("_pgroll_new_name", "_pgroll_new_nick") REFERENCES "users" ("name", "nick") ON DELETE CASCADE`},
},
} {
t.Run(name, func(t *testing.T) {
stmts := d.duplicateForeignKeyConstraints(nil, testCases.columns...)
assert.Equal(t, len(testCases.expectedStmts), len(stmts))
for _, stmt := range stmts {
assert.Contains(t, testCases.expectedStmts, stmt)
}
})
}
}
2 changes: 1 addition & 1 deletion pkg/migrations/op_add_column.go
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ func (w ColumnSQLWriter) Write(col Column) (string, error) {
sql += fmt.Sprintf(" DEFAULT %s", d)
}
if col.References != nil {
onDelete := "NO ACTION"
onDelete := string(ForeignKeyReferenceOnDeleteNOACTION)
if col.References.OnDelete != "" {
onDelete = strings.ToUpper(string(col.References.OnDelete))
}
Expand Down
2 changes: 1 addition & 1 deletion pkg/migrations/op_add_column_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -626,7 +626,7 @@ func TestAddForeignKeyColumn(t *testing.T) {
Name: "fk_users_id",
Table: "users",
Column: "id",
OnDelete: "CASCADE",
OnDelete: migrations.ForeignKeyReferenceOnDeleteCASCADE,
},
},
},
Expand Down
48 changes: 48 additions & 0 deletions pkg/migrations/op_create_constraint.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ func (o *OpCreateConstraint) Start(ctx context.Context, conn db.DB, latestSchema
return table, o.addUniqueIndex(ctx, conn)
case OpCreateConstraintTypeCheck:
return table, o.addCheckConstraint(ctx, conn)
case OpCreateConstraintTypeForeignKey:
return table, o.addForeignKeyConstraint(ctx, conn)
}

return table, nil
Expand Down Expand Up @@ -97,6 +99,17 @@ func (o *OpCreateConstraint) Complete(ctx context.Context, conn db.DB, tr SQLTra
if err != nil {
return err
}
case OpCreateConstraintTypeForeignKey:
fkOp := &OpSetForeignKey{
Table: o.Table,
References: ForeignKeyReference{
Name: o.Name,
},
}
err := fkOp.Complete(ctx, conn, tr, s)
if err != nil {
return err
}
}

// remove old columns
Expand Down Expand Up @@ -198,6 +211,22 @@ func (o *OpCreateConstraint) Validate(ctx context.Context, s *schema.Schema) err
if o.Check == nil || *o.Check == "" {
return FieldRequiredError{Name: "check"}
}
case OpCreateConstraintTypeForeignKey:
if o.References == nil {
return FieldRequiredError{Name: "references"}
}
table := s.GetTable(o.References.Table)
if table == nil {
return TableDoesNotExistError{Name: o.References.Table}
}
for _, col := range o.References.Columns {
if table.GetColumn(col) == nil {
return ColumnDoesNotExistError{
Table: o.References.Table,
Name: col,
}
}
}
}

return nil
Expand All @@ -223,6 +252,25 @@ func (o *OpCreateConstraint) addCheckConstraint(ctx context.Context, conn db.DB)
return err
}

func (o *OpCreateConstraint) addForeignKeyConstraint(ctx context.Context, conn db.DB) error {
onDelete := "NO ACTION"
if o.References.OnDelete != "" {
onDelete = strings.ToUpper(string(o.References.OnDelete))
}

_, err := conn.ExecContext(ctx,
fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s (%s) ON DELETE %s NOT VALID",
pq.QuoteIdentifier(o.Table),
pq.QuoteIdentifier(o.Name),
strings.Join(quotedTemporaryNames(o.Columns), ","),
pq.QuoteIdentifier(o.References.Table),
strings.Join(quoteColumnNames(o.References.Columns), ","),
onDelete,
))

return err
}

func quotedTemporaryNames(columns []string) []string {
names := make([]string, len(columns))
for i, col := range columns {
Expand Down
Loading