From 2e1a9fa48a7e4a2f629f722d32845a29edaab425 Mon Sep 17 00:00:00 2001 From: Ting-Lan Wang Date: Wed, 10 Sep 2025 19:11:04 -0400 Subject: [PATCH 1/2] Add support for 'OnUpdate' --- oracle/migrator.go | 126 +++++++++++++++++++++++++++++++++++++ tests/associations_test.go | 6 +- tests/go.mod | 4 +- tests/migrate_test.go | 56 +++++++++++++++++ 4 files changed, 186 insertions(+), 6 deletions(-) diff --git a/oracle/migrator.go b/oracle/migrator.go index 9686a5f..83c63cb 100644 --- a/oracle/migrator.go +++ b/oracle/migrator.go @@ -145,6 +145,17 @@ func (m Migrator) CreateTable(values ...interface{}) error { } if constraint := rel.ParseConstraint(); constraint != nil { if constraint.Schema == stmt.Schema { + // Oracle doesn’t support OnUpdate on foreign keys. + // Use a trigger instead to propagate the update to the child table instead. + if len(constraint.References) > 0 && constraint.OnUpdate != "" { + constraint.OnUpdate = "" + defer func(tx *gorm.DB, table string, constraint *schema.Constraint) { + if err == nil { + err = m.createUpadateCascadeTrigger(tx, constraint) + } + }(tx, stmt.Table, constraint) + } + // If the same set of foreign keys already references the parent column, // remove duplicates to avoid ORA-02274: duplicate referential constraint specifications var foreignKeys []string @@ -399,6 +410,32 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { return columnTypes, execErr } +// CreateConstraint creates constraint based on the given 'value' and 'name' +func (m Migrator) CreateConstraint(value interface{}, name string) error { + return m.RunWithValue(value, func(stmt *gorm.Statement) error { + constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) + if constraint != nil { + if c, ok := constraint.(*schema.Constraint); ok { + // Oracle doesn’t support OnUpdate on foreign keys. + // Use a trigger instead to propagate the update to the child table instead. + if len(c.References) > 0 && c.OnUpdate != "" { + c.OnUpdate = "" + constraint = c + m.createUpadateCascadeTrigger(m.DB, c) + } + } + + vars := []interface{}{clause.Table{Name: table}} + if stmt.TableExpr != nil { + vars[0] = stmt.TableExpr + } + sql, values := constraint.Build() + return m.DB.Exec("ALTER TABLE ? ADD "+sql, append(vars, values...)...).Error + } + return nil + }) +} + // HasConstraint checks whether the table for the given `value` contains the specified constraint `name` func (m Migrator) HasConstraint(value interface{}, name string) bool { var count int64 @@ -418,6 +455,33 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool { return count > 0 } +// DropConstraint drops constraint based on the given 'value' and 'name' +func (m Migrator) DropConstraint(value interface{}, name string) error { + if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { + + constraint, _ := m.GuessConstraintInterfaceAndTable(stmt, name) + + if c, ok := constraint.(*schema.Constraint); ok && c != nil { + if len(c.References) > 0 && c.OnUpdate != "" { + for i, fk := range c.ForeignKeys { + triggerName := m.FkTriggerName( + c.ReferenceSchema.Table, + c.References[i].DBName, + c.Schema.Table, + fk.DBName, + ) + return m.DB.Exec("DROP TRIGGER ?", clause.Column{Name: triggerName}).Error + } + } + } + return nil + }); err != nil { + return err + } + + return m.Migrator.DropConstraint(value, name) +} + // DropIndex drops the index with the specified `name` from the table associated with `value` func (m Migrator) DropIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { @@ -569,3 +633,65 @@ func (m Migrator) isNumeric(s string) bool { _, err := strconv.ParseFloat(s, 64) return err == nil } + +func (m Migrator) FkTriggerName(refTable string, refField string, table string, field string) string { + return fmt.Sprintf("fk_trigger_%s_%s_%s_%s", refTable, refField, table, field) +} + +// Creates a trigger to cascade the update to the child table +func (m Migrator) createUpadateCascadeTrigger(tx *gorm.DB, constraint *schema.Constraint) error { + for i, fk := range constraint.ForeignKeys { + var ( + tmpBuilder strings.Builder + plsqlBuilder strings.Builder + parentTable string = constraint.ReferenceSchema.Table + parentField string = constraint.References[i].DBName + table string = constraint.Schema.Table + field string = fk.DBName + + triggerName string = m.FkTriggerName(parentTable, parentField, table, field) + + quotedParentTable string + quotedParentField string + quotedTable string + quotedField string + quotedTriggerName string + ) + + // Initialize quoted variables according to the driver’s quoting rules + writeQuotedIdentifier(&tmpBuilder, parentTable) + quotedParentTable = tmpBuilder.String() + tmpBuilder.Reset() + + writeQuotedIdentifier(&tmpBuilder, parentField) + quotedParentField = tmpBuilder.String() + tmpBuilder.Reset() + + writeQuotedIdentifier(&tmpBuilder, table) + quotedTable = tmpBuilder.String() + tmpBuilder.Reset() + + writeQuotedIdentifier(&tmpBuilder, field) + quotedField = tmpBuilder.String() + tmpBuilder.Reset() + + writeQuotedIdentifier(&tmpBuilder, triggerName) + quotedTriggerName = tmpBuilder.String() + tmpBuilder.Reset() + + // Start PL/SQL block + plsqlBuilder.WriteString("CREATE OR REPLACE TRIGGER " + quotedTriggerName + "\n") + plsqlBuilder.WriteString("AFTER UPDATE OF " + quotedParentField + " ON " + quotedParentTable + "\n") + plsqlBuilder.WriteString("FOR EACH ROW\n") + plsqlBuilder.WriteString("BEGIN\n") + plsqlBuilder.WriteString(" UPDATE " + quotedTable + "\n") + plsqlBuilder.WriteString(" SET " + quotedField + " = :NEW." + quotedParentField + "\n") + plsqlBuilder.WriteString(" WHERE " + quotedField + " = :OLD." + quotedParentField) + plsqlBuilder.WriteString(";\n") + plsqlBuilder.WriteString("END;") + if err := tx.Exec(plsqlBuilder.String()).Error; err != nil { + return err + } + } + return nil +} diff --git a/tests/associations_test.go b/tests/associations_test.go index 50ef99a..59cd185 100644 --- a/tests/associations_test.go +++ b/tests/associations_test.go @@ -112,7 +112,6 @@ func TestAssociationNotNullClear(t *testing.T) { } func TestForeignKeyConstraints(t *testing.T) { - t.Skip() type Profile struct { ID uint Name string @@ -121,7 +120,7 @@ func TestForeignKeyConstraints(t *testing.T) { type Member struct { ID uint - Refer uint `gorm:"uniqueIndex"` + Refer uint `gorm:"unique"` Name string Profile Profile `gorm:"Constraint:OnUpdate:CASCADE,OnDelete:CASCADE;FOREIGNKEY:MemberID;References:Refer"` } @@ -168,11 +167,10 @@ func TestForeignKeyConstraints(t *testing.T) { } func TestForeignKeyConstraintsBelongsTo(t *testing.T) { - t.Skip() type Profile struct { ID uint Name string - Refer uint `gorm:"uniqueIndex"` + Refer uint `gorm:"unique"` } type Member struct { diff --git a/tests/go.mod b/tests/go.mod index 5ed2155..99110cb 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -5,7 +5,8 @@ go 1.24.4 require gorm.io/gorm v1.30.0 require ( - github.com/oracle-samples/gorm-oracle v0.0.1 + github.com/godror/godror v0.49.0 + github.com/oracle-samples/gorm-oracle v0.1.0 github.com/stretchr/testify v1.10.0 ) @@ -13,7 +14,6 @@ require ( github.com/VictoriaMetrics/easyproto v0.1.4 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-logfmt/logfmt v0.6.0 // indirect - github.com/godror/godror v0.49.0 // indirect github.com/godror/knownpb v0.3.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 5564837..7c416bb 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -1721,3 +1721,59 @@ func TestAutoMigrateDecimal(t *testing.T) { decimalColumnsTest[MigrateDecimalColumn, MigrateDecimalColumn2](t, expectedSql) } } + +func TestMigrateOnUpdateConstraint(t *testing.T) { + type Owner struct { + ID int + Name string + } + + type Pen struct { + gorm.Model + OwnerID int + Owner Owner `gorm:"constraint:OnUpdate:CASCADE,OnDelete:SET NULL;"` + } + + DB.Migrator().DropTable(&Pen{}, &Owner{}) + + // Verify the trigger is created using CreateTable() + if err := DB.Migrator().CreateTable(&Owner{}, &Pen{}); err != nil { + t.Fatalf("Failed to create table, got error: %v", err) + } + + triggerName := "fk_trigger_owners_id_pens_owner_id" + + var count int + DB.Raw("SELECT count(*) FROM user_triggers where trigger_name = ?", triggerName).Scan(&count) + if count != 1 { + t.Errorf("Should find the trigger %s", triggerName) + } + + // Verify the trigger is created using CreateConstraint() + constraintName := "fk_pens_owner" + if err := DB.Migrator().DropConstraint(&Pen{}, constraintName); err != nil { + t.Errorf("failed to drop constraint %v, got error %v", constraintName, err) + } + + if err := DB.Migrator().CreateConstraint(&Pen{}, constraintName); err != nil { + t.Errorf("failed to create constraint %v, got error %v", constraintName, err) + } + + DB.Raw("SELECT count(*) FROM user_triggers where trigger_name = ?", triggerName).Scan(&count) + if count != 1 { + t.Errorf("Should find the trigger %s", triggerName) + } + + // Verify the trigger works + user := Pen{Owner: Owner{ID: 1, Name: "John"}} + DB.Create(&user) + + DB.Model(user.Owner).Update("id", 100) + + var user2 Pen + if err := DB.First(&user2, "\"id\" = ?", user.ID).Error; err != nil { + panic(fmt.Errorf("failed to find member, got error: %v", err)) + } else if user2.OwnerID != 100 { + panic(fmt.Errorf("company id is not equal: expects: %v, got: %v", 100, user2.OwnerID)) + } +} From c2a262cecb055fa77712d4ff930b5da0de7ae83c Mon Sep 17 00:00:00 2001 From: Ting-Lan Wang Date: Thu, 11 Sep 2025 14:09:42 -0400 Subject: [PATCH 2/2] Add support for SET NULL and SET DEFAULT --- README.md | 43 +++++++++++++++++ oracle/common.go | 6 +++ oracle/migrator.go | 103 ++++++++++++++++++++--------------------- tests/migrate_test.go | 104 ++++++++++++++++++++++++++++++------------ 4 files changed, 177 insertions(+), 79 deletions(-) diff --git a/README.md b/README.md index 46519f7..f45bd01 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,49 @@ func main() { } ``` +## Documentation + +### OnUpdate Foreign Key Constraint + +Since Oracle doesn’t support `ON UPDATE` in foreign keys, the driver simulates it using **triggers**. + +When a field has a constraint tagged with `OnUpdate`, the driver: + +1. Skips generating the unsupported `ON UPDATE` clause in the foreign key definition. +2. Creates a trigger on the parent table that automatically cascades updates to the child table(s) whenever the referenced column is changed. + +The `OnUpdate` tag accepts the following values (case-insensitive): `CASCADE`, `SET NULL`, and `SET DEFAULT`. + +Take the following struct for an example: + +```go +type Profile struct { + ID uint + Name string + Refer uint +} + +type Member struct { + ID uint + Name string + ProfileID uint + Profile Profile `gorm:"Constraint:OnUpdate:CASCADE"` +} +``` + +Trigger SQL created by the driver when migrating: + +```sql +CREATE OR REPLACE TRIGGER "fk_trigger_profiles_id_members_profile_id" +AFTER UPDATE OF "id" ON "profiles" +FOR EACH ROW +BEGIN + UPDATE "members" + SET "profile_id" = :NEW."id" + WHERE "profile_id" = :OLD."id"; +END; +``` + ## Contributing This project welcomes contributions from the community. Before submitting a pull request, please [review our contribution guide](./CONTRIBUTING.md) diff --git a/oracle/common.go b/oracle/common.go index 42571e6..4d8932c 100644 --- a/oracle/common.go +++ b/oracle/common.go @@ -424,6 +424,12 @@ func writeQuotedIdentifier(builder *strings.Builder, identifier string) { builder.WriteByte('"') } +func QuoteIdentifier(identifier string) string { + var builder strings.Builder + writeQuotedIdentifier(&builder, identifier) + return builder.String() +} + // writeTableRecordCollectionDecl writes the PL/SQL declarations needed to // define a custom record type and a collection of that record type, // based on the schema of the given table. diff --git a/oracle/migrator.go b/oracle/migrator.go index 83c63cb..e57a0c3 100644 --- a/oracle/migrator.go +++ b/oracle/migrator.go @@ -148,12 +148,14 @@ func (m Migrator) CreateTable(values ...interface{}) error { // Oracle doesn’t support OnUpdate on foreign keys. // Use a trigger instead to propagate the update to the child table instead. if len(constraint.References) > 0 && constraint.OnUpdate != "" { - constraint.OnUpdate = "" - defer func(tx *gorm.DB, table string, constraint *schema.Constraint) { + defer func(tx *gorm.DB, table string, constraint *schema.Constraint, onUpdate string) { if err == nil { + // retore the OnUpdate value + constraint.OnUpdate = onUpdate err = m.createUpadateCascadeTrigger(tx, constraint) } - }(tx, stmt.Table, constraint) + }(tx, stmt.Table, constraint, constraint.OnUpdate) + constraint.OnUpdate = "" } // If the same set of foreign keys already references the parent column, @@ -419,9 +421,9 @@ func (m Migrator) CreateConstraint(value interface{}, name string) error { // Oracle doesn’t support OnUpdate on foreign keys. // Use a trigger instead to propagate the update to the child table instead. if len(c.References) > 0 && c.OnUpdate != "" { + m.createUpadateCascadeTrigger(m.DB, c) c.OnUpdate = "" constraint = c - m.createUpadateCascadeTrigger(m.DB, c) } } @@ -640,58 +642,57 @@ func (m Migrator) FkTriggerName(refTable string, refField string, table string, // Creates a trigger to cascade the update to the child table func (m Migrator) createUpadateCascadeTrigger(tx *gorm.DB, constraint *schema.Constraint) error { + onUpdate := strings.TrimSpace(strings.ToLower(constraint.OnUpdate)) + if onUpdate != "cascade" && onUpdate != "set null" && onUpdate != "set default" { + return nil + } + + parentTable := constraint.ReferenceSchema.Table + quotedParentTable := QuoteIdentifier(parentTable) + table := constraint.Schema.Table + quotedTable := QuoteIdentifier(table) + for i, fk := range constraint.ForeignKeys { - var ( - tmpBuilder strings.Builder - plsqlBuilder strings.Builder - parentTable string = constraint.ReferenceSchema.Table - parentField string = constraint.References[i].DBName - table string = constraint.Schema.Table - field string = fk.DBName - - triggerName string = m.FkTriggerName(parentTable, parentField, table, field) - - quotedParentTable string - quotedParentField string - quotedTable string - quotedField string - quotedTriggerName string + parentField := constraint.References[i].DBName + quotedParentField := QuoteIdentifier(parentField) + field := fk.DBName + quotedField := QuoteIdentifier(field) + triggerName := m.FkTriggerName(parentTable, parentField, table, field) + quotedTriggerName := QuoteIdentifier(triggerName) + + var updateValue string + switch onUpdate { + case "cascade": + updateValue = ":NEW." + quotedParentField + case "set null": + updateValue = "NULL" + case "set default": + updateValue = "DEFAULT" + } + + plsql := fmt.Sprintf( + `CREATE OR REPLACE TRIGGER %s +AFTER UPDATE OF %s ON %s +FOR EACH ROW +BEGIN + UPDATE %s + SET %s = %s + WHERE %s = :OLD.%s; +END;`, + quotedTriggerName, + quotedParentField, + quotedParentTable, + quotedTable, + quotedField, + updateValue, + quotedField, + quotedParentField, ) - // Initialize quoted variables according to the driver’s quoting rules - writeQuotedIdentifier(&tmpBuilder, parentTable) - quotedParentTable = tmpBuilder.String() - tmpBuilder.Reset() - - writeQuotedIdentifier(&tmpBuilder, parentField) - quotedParentField = tmpBuilder.String() - tmpBuilder.Reset() - - writeQuotedIdentifier(&tmpBuilder, table) - quotedTable = tmpBuilder.String() - tmpBuilder.Reset() - - writeQuotedIdentifier(&tmpBuilder, field) - quotedField = tmpBuilder.String() - tmpBuilder.Reset() - - writeQuotedIdentifier(&tmpBuilder, triggerName) - quotedTriggerName = tmpBuilder.String() - tmpBuilder.Reset() - - // Start PL/SQL block - plsqlBuilder.WriteString("CREATE OR REPLACE TRIGGER " + quotedTriggerName + "\n") - plsqlBuilder.WriteString("AFTER UPDATE OF " + quotedParentField + " ON " + quotedParentTable + "\n") - plsqlBuilder.WriteString("FOR EACH ROW\n") - plsqlBuilder.WriteString("BEGIN\n") - plsqlBuilder.WriteString(" UPDATE " + quotedTable + "\n") - plsqlBuilder.WriteString(" SET " + quotedField + " = :NEW." + quotedParentField + "\n") - plsqlBuilder.WriteString(" WHERE " + quotedField + " = :OLD." + quotedParentField) - plsqlBuilder.WriteString(";\n") - plsqlBuilder.WriteString("END;") - if err := tx.Exec(plsqlBuilder.String()).Error; err != nil { + if err := tx.Exec(plsql).Error; err != nil { return err } } + return nil } diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 7c416bb..8685a16 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -1728,52 +1728,100 @@ func TestMigrateOnUpdateConstraint(t *testing.T) { Name string } - type Pen struct { + type Pen1 struct { gorm.Model OwnerID int - Owner Owner `gorm:"constraint:OnUpdate:CASCADE,OnDelete:SET NULL;"` + Owner Owner `gorm:"constraint:OnUpdate:CASCADE;"` } - DB.Migrator().DropTable(&Pen{}, &Owner{}) + type Pen2 struct { + gorm.Model + OwnerID int `gorm:"default: 18"` + Owner Owner `gorm:"constraint:OnUpdate:SET DEFAULT;"` + } - // Verify the trigger is created using CreateTable() - if err := DB.Migrator().CreateTable(&Owner{}, &Pen{}); err != nil { - t.Fatalf("Failed to create table, got error: %v", err) + type Pen3 struct { + gorm.Model + OwnerID int + Owner Owner `gorm:"constraint:OnUpdate:SET NULL;"` } - triggerName := "fk_trigger_owners_id_pens_owner_id" + DB.Migrator().DropTable(&Owner{}, &Pen1{}, &Pen2{}, &Pen3{}) - var count int - DB.Raw("SELECT count(*) FROM user_triggers where trigger_name = ?", triggerName).Scan(&count) - if count != 1 { - t.Errorf("Should find the trigger %s", triggerName) + // Test 1: Verify the trigger is created using CreateTable() + if err := DB.Migrator().CreateTable(&Owner{}, &Pen1{}, &Pen2{}, &Pen3{}); err != nil { + t.Fatalf("Failed to create table, got error: %v", err) } - // Verify the trigger is created using CreateConstraint() - constraintName := "fk_pens_owner" - if err := DB.Migrator().DropConstraint(&Pen{}, constraintName); err != nil { - t.Errorf("failed to drop constraint %v, got error %v", constraintName, err) + triggerNames := []string{ + "fk_trigger_owners_id_pen1_owner_id", + "fk_trigger_owners_id_pen2_owner_id", + "fk_trigger_owners_id_pen3_owner_id", } - if err := DB.Migrator().CreateConstraint(&Pen{}, constraintName); err != nil { - t.Errorf("failed to create constraint %v, got error %v", constraintName, err) + for _, triggerName := range triggerNames { + var count int + DB.Raw("SELECT count(*) FROM user_triggers where trigger_name = ?", triggerName).Scan(&count) + if count != 1 { + t.Errorf("Should find the trigger %s", triggerName) + } } - DB.Raw("SELECT count(*) FROM user_triggers where trigger_name = ?", triggerName).Scan(&count) - if count != 1 { - t.Errorf("Should find the trigger %s", triggerName) + // Test 2: Verify the trigger is created using CreateConstraint() + penStructs := []interface{}{&Pen1{}, &Pen2{}, &Pen3{}} + constraintNames := []string{"fk_pen1_owner", "fk_pen2_owner", "fk_pen3_owner"} + for i := range 3 { + if err := DB.Migrator().DropConstraint(penStructs[i], constraintNames[i]); err != nil { + t.Errorf("failed to drop constraint %v, got error %v", constraintNames[i], err) + } + + if err := DB.Migrator().CreateConstraint(penStructs[i], constraintNames[i]); err != nil { + t.Errorf("failed to create constraint %v, got error %v", constraintNames[i], err) + } + + var count int + DB.Raw("SELECT count(*) FROM user_triggers where trigger_name = ?", triggerNames[i]).Scan(&count) + if count != 1 { + t.Errorf("Should find the trigger %s", triggerNames[i]) + } } - // Verify the trigger works - user := Pen{Owner: Owner{ID: 1, Name: "John"}} - DB.Create(&user) + // Test 3: Verify each trigger work + pen1 := Pen1{Owner: Owner{ID: 1, Name: "John"}} + DB.Create(&pen1) + DB.Model(pen1.Owner).Update("id", 100) + + var updatedPen1 Pen1 + if err := DB.First(&updatedPen1, "\"id\" = ?", pen1.ID).Error; err != nil { + panic(fmt.Errorf("failed to find member, got error: %v", err)) + } else if updatedPen1.OwnerID != 100 { + panic(fmt.Errorf("company id is not equal: expects: %v, got: %v", 100, updatedPen1.OwnerID)) + } + + pen2 := Pen2{Owner: Owner{ID: 2, Name: "Mary"}} + DB.Create(&pen2) + // When the ID in the owners table is updated, the primary key in pen2 (owner_id column) + // is set to its default value (18). To avoid violating the foreign key constraint in pen2, + // we need to insert this record into the owners table in advance. + owner := Owner{ID: 18, Name: "MaryBackup"} + DB.Create(&owner) + DB.Model(pen2.Owner).Update("id", 200) + + var updatedPen2 Pen2 + if err := DB.First(&updatedPen2, "\"id\" = ?", pen2.ID).Error; err != nil { + panic(fmt.Errorf("failed to find member, got error: %v", err)) + } else if updatedPen2.OwnerID != 18 { + panic(fmt.Errorf("company id is not equal: expects: %v, got: %v", 18, updatedPen2.OwnerID)) + } - DB.Model(user.Owner).Update("id", 100) + pen3 := Pen3{Owner: Owner{ID: 3, Name: "Jane"}} + DB.Create(&pen3) + DB.Model(pen3.Owner).Update("id", 300) - var user2 Pen - if err := DB.First(&user2, "\"id\" = ?", user.ID).Error; err != nil { + var updatedPen3 Pen3 + if err := DB.First(&updatedPen3, "\"id\" = ?", pen3.ID).Error; err != nil { panic(fmt.Errorf("failed to find member, got error: %v", err)) - } else if user2.OwnerID != 100 { - panic(fmt.Errorf("company id is not equal: expects: %v, got: %v", 100, user2.OwnerID)) + } else if updatedPen3.OwnerID != 0 { + panic(fmt.Errorf("company id is not equal: expects: %v, got: %v", 0, updatedPen3.OwnerID)) } }