Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,49 @@ func main() {
}
```

## Documentation

### OnUpdate Foreign Key Constraint

Since Oracle doesn’t support `ON UPDATE` in foreign keys, the driver simulates it using **triggers**.

When a field has a constraint tagged with `OnUpdate`, the driver:

1. Skips generating the unsupported `ON UPDATE` clause in the foreign key definition.
2. Creates a trigger on the parent table that automatically cascades updates to the child table(s) whenever the referenced column is changed.

The `OnUpdate` tag accepts the following values (case-insensitive): `CASCADE`, `SET NULL`, and `SET DEFAULT`.

Take the following struct for an example:

```go
type Profile struct {
ID uint
Name string
Refer uint
}

type Member struct {
ID uint
Name string
ProfileID uint
Profile Profile `gorm:"Constraint:OnUpdate:CASCADE"`
}
```

Trigger SQL created by the driver when migrating:

```sql
CREATE OR REPLACE TRIGGER "fk_trigger_profiles_id_members_profile_id"
AFTER UPDATE OF "id" ON "profiles"
FOR EACH ROW
BEGIN
UPDATE "members"
SET "profile_id" = :NEW."id"
WHERE "profile_id" = :OLD."id";
END;
```

## Contributing

This project welcomes contributions from the community. Before submitting a pull request, please [review our contribution guide](./CONTRIBUTING.md)
Expand Down
6 changes: 6 additions & 0 deletions oracle/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -424,6 +424,12 @@ func writeQuotedIdentifier(builder *strings.Builder, identifier string) {
builder.WriteByte('"')
}

func QuoteIdentifier(identifier string) string {
var builder strings.Builder
writeQuotedIdentifier(&builder, identifier)
return builder.String()
}

// writeTableRecordCollectionDecl writes the PL/SQL declarations needed to
// define a custom record type and a collection of that record type,
// based on the schema of the given table.
Expand Down
127 changes: 127 additions & 0 deletions oracle/migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,19 @@ func (m Migrator) CreateTable(values ...interface{}) error {
}
if constraint := rel.ParseConstraint(); constraint != nil {
if constraint.Schema == stmt.Schema {
// Oracle doesn’t support OnUpdate on foreign keys.
// Use a trigger instead to propagate the update to the child table instead.
if len(constraint.References) > 0 && constraint.OnUpdate != "" {
defer func(tx *gorm.DB, table string, constraint *schema.Constraint, onUpdate string) {
if err == nil {
// retore the OnUpdate value
constraint.OnUpdate = onUpdate
err = m.createUpadateCascadeTrigger(tx, constraint)
}
}(tx, stmt.Table, constraint, constraint.OnUpdate)
constraint.OnUpdate = ""
}

// If the same set of foreign keys already references the parent column,
// remove duplicates to avoid ORA-02274: duplicate referential constraint specifications
var foreignKeys []string
Expand Down Expand Up @@ -399,6 +412,32 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
return columnTypes, execErr
}

// CreateConstraint creates constraint based on the given 'value' and 'name'
func (m Migrator) CreateConstraint(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
if constraint != nil {
if c, ok := constraint.(*schema.Constraint); ok {
// Oracle doesn’t support OnUpdate on foreign keys.
// Use a trigger instead to propagate the update to the child table instead.
if len(c.References) > 0 && c.OnUpdate != "" {
m.createUpadateCascadeTrigger(m.DB, c)
c.OnUpdate = ""
constraint = c
}
}

vars := []interface{}{clause.Table{Name: table}}
if stmt.TableExpr != nil {
vars[0] = stmt.TableExpr
}
sql, values := constraint.Build()
return m.DB.Exec("ALTER TABLE ? ADD "+sql, append(vars, values...)...).Error
}
return nil
})
}

// HasConstraint checks whether the table for the given `value` contains the specified constraint `name`
func (m Migrator) HasConstraint(value interface{}, name string) bool {
var count int64
Expand All @@ -418,6 +457,33 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool {
return count > 0
}

// DropConstraint drops constraint based on the given 'value' and 'name'
func (m Migrator) DropConstraint(value interface{}, name string) error {
if err := m.RunWithValue(value, func(stmt *gorm.Statement) error {

constraint, _ := m.GuessConstraintInterfaceAndTable(stmt, name)

if c, ok := constraint.(*schema.Constraint); ok && c != nil {
if len(c.References) > 0 && c.OnUpdate != "" {
for i, fk := range c.ForeignKeys {
triggerName := m.FkTriggerName(
c.ReferenceSchema.Table,
c.References[i].DBName,
c.Schema.Table,
fk.DBName,
)
return m.DB.Exec("DROP TRIGGER ?", clause.Column{Name: triggerName}).Error
}
}
}
return nil
}); err != nil {
return err
}

return m.Migrator.DropConstraint(value, name)
}

// DropIndex drops the index with the specified `name` from the table associated with `value`
func (m Migrator) DropIndex(value interface{}, name string) error {
return m.RunWithValue(value, func(stmt *gorm.Statement) error {
Expand Down Expand Up @@ -569,3 +635,64 @@ func (m Migrator) isNumeric(s string) bool {
_, err := strconv.ParseFloat(s, 64)
return err == nil
}

func (m Migrator) FkTriggerName(refTable string, refField string, table string, field string) string {
return fmt.Sprintf("fk_trigger_%s_%s_%s_%s", refTable, refField, table, field)
}

// Creates a trigger to cascade the update to the child table
func (m Migrator) createUpadateCascadeTrigger(tx *gorm.DB, constraint *schema.Constraint) error {
onUpdate := strings.TrimSpace(strings.ToLower(constraint.OnUpdate))
if onUpdate != "cascade" && onUpdate != "set null" && onUpdate != "set default" {
return nil
}

parentTable := constraint.ReferenceSchema.Table
quotedParentTable := QuoteIdentifier(parentTable)
table := constraint.Schema.Table
quotedTable := QuoteIdentifier(table)

for i, fk := range constraint.ForeignKeys {
parentField := constraint.References[i].DBName
quotedParentField := QuoteIdentifier(parentField)
field := fk.DBName
quotedField := QuoteIdentifier(field)
triggerName := m.FkTriggerName(parentTable, parentField, table, field)
quotedTriggerName := QuoteIdentifier(triggerName)

var updateValue string
switch onUpdate {
case "cascade":
updateValue = ":NEW." + quotedParentField
case "set null":
updateValue = "NULL"
case "set default":
updateValue = "DEFAULT"
}

plsql := fmt.Sprintf(
`CREATE OR REPLACE TRIGGER %s
AFTER UPDATE OF %s ON %s
FOR EACH ROW
BEGIN
UPDATE %s
SET %s = %s
WHERE %s = :OLD.%s;
END;`,
quotedTriggerName,
quotedParentField,
quotedParentTable,
quotedTable,
quotedField,
updateValue,
quotedField,
quotedParentField,
)

if err := tx.Exec(plsql).Error; err != nil {
return err
}
}

return nil
}
6 changes: 2 additions & 4 deletions tests/associations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,6 @@ func TestAssociationNotNullClear(t *testing.T) {
}

func TestForeignKeyConstraints(t *testing.T) {
t.Skip()
type Profile struct {
ID uint
Name string
Expand All @@ -121,7 +120,7 @@ func TestForeignKeyConstraints(t *testing.T) {

type Member struct {
ID uint
Refer uint `gorm:"uniqueIndex"`
Refer uint `gorm:"unique"`
Name string
Profile Profile `gorm:"Constraint:OnUpdate:CASCADE,OnDelete:CASCADE;FOREIGNKEY:MemberID;References:Refer"`
}
Expand Down Expand Up @@ -168,11 +167,10 @@ func TestForeignKeyConstraints(t *testing.T) {
}

func TestForeignKeyConstraintsBelongsTo(t *testing.T) {
t.Skip()
type Profile struct {
ID uint
Name string
Refer uint `gorm:"uniqueIndex"`
Refer uint `gorm:"unique"`
}

type Member struct {
Expand Down
4 changes: 2 additions & 2 deletions tests/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@ go 1.24.4
require gorm.io/gorm v1.30.0

require (
github.com/oracle-samples/gorm-oracle v0.0.1
github.com/godror/godror v0.49.0
github.com/oracle-samples/gorm-oracle v0.1.0
github.com/stretchr/testify v1.10.0
)

require (
github.com/VictoriaMetrics/easyproto v0.1.4 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/go-logfmt/logfmt v0.6.0 // indirect
github.com/godror/godror v0.49.0 // indirect
github.com/godror/knownpb v0.3.0 // indirect
github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
Expand Down
104 changes: 104 additions & 0 deletions tests/migrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1721,3 +1721,107 @@ func TestAutoMigrateDecimal(t *testing.T) {
decimalColumnsTest[MigrateDecimalColumn, MigrateDecimalColumn2](t, expectedSql)
}
}

func TestMigrateOnUpdateConstraint(t *testing.T) {
type Owner struct {
ID int
Name string
}

type Pen1 struct {
gorm.Model
OwnerID int
Owner Owner `gorm:"constraint:OnUpdate:CASCADE;"`
}

type Pen2 struct {
gorm.Model
OwnerID int `gorm:"default: 18"`
Owner Owner `gorm:"constraint:OnUpdate:SET DEFAULT;"`
}

type Pen3 struct {
gorm.Model
OwnerID int
Owner Owner `gorm:"constraint:OnUpdate:SET NULL;"`
}

DB.Migrator().DropTable(&Owner{}, &Pen1{}, &Pen2{}, &Pen3{})

// Test 1: Verify the trigger is created using CreateTable()
if err := DB.Migrator().CreateTable(&Owner{}, &Pen1{}, &Pen2{}, &Pen3{}); err != nil {
t.Fatalf("Failed to create table, got error: %v", err)
}

triggerNames := []string{
"fk_trigger_owners_id_pen1_owner_id",
"fk_trigger_owners_id_pen2_owner_id",
"fk_trigger_owners_id_pen3_owner_id",
}

for _, triggerName := range triggerNames {
var count int
DB.Raw("SELECT count(*) FROM user_triggers where trigger_name = ?", triggerName).Scan(&count)
if count != 1 {
t.Errorf("Should find the trigger %s", triggerName)
}
}

// Test 2: Verify the trigger is created using CreateConstraint()
penStructs := []interface{}{&Pen1{}, &Pen2{}, &Pen3{}}
constraintNames := []string{"fk_pen1_owner", "fk_pen2_owner", "fk_pen3_owner"}
for i := range 3 {
if err := DB.Migrator().DropConstraint(penStructs[i], constraintNames[i]); err != nil {
t.Errorf("failed to drop constraint %v, got error %v", constraintNames[i], err)
}

if err := DB.Migrator().CreateConstraint(penStructs[i], constraintNames[i]); err != nil {
t.Errorf("failed to create constraint %v, got error %v", constraintNames[i], err)
}

var count int
DB.Raw("SELECT count(*) FROM user_triggers where trigger_name = ?", triggerNames[i]).Scan(&count)
if count != 1 {
t.Errorf("Should find the trigger %s", triggerNames[i])
}
}

// Test 3: Verify each trigger work
pen1 := Pen1{Owner: Owner{ID: 1, Name: "John"}}
DB.Create(&pen1)
DB.Model(pen1.Owner).Update("id", 100)

var updatedPen1 Pen1
if err := DB.First(&updatedPen1, "\"id\" = ?", pen1.ID).Error; err != nil {
panic(fmt.Errorf("failed to find member, got error: %v", err))
} else if updatedPen1.OwnerID != 100 {
panic(fmt.Errorf("company id is not equal: expects: %v, got: %v", 100, updatedPen1.OwnerID))
}

pen2 := Pen2{Owner: Owner{ID: 2, Name: "Mary"}}
DB.Create(&pen2)
// When the ID in the owners table is updated, the primary key in pen2 (owner_id column)
// is set to its default value (18). To avoid violating the foreign key constraint in pen2,
// we need to insert this record into the owners table in advance.
owner := Owner{ID: 18, Name: "MaryBackup"}
DB.Create(&owner)
DB.Model(pen2.Owner).Update("id", 200)

var updatedPen2 Pen2
if err := DB.First(&updatedPen2, "\"id\" = ?", pen2.ID).Error; err != nil {
panic(fmt.Errorf("failed to find member, got error: %v", err))
} else if updatedPen2.OwnerID != 18 {
panic(fmt.Errorf("company id is not equal: expects: %v, got: %v", 18, updatedPen2.OwnerID))
}

pen3 := Pen3{Owner: Owner{ID: 3, Name: "Jane"}}
DB.Create(&pen3)
DB.Model(pen3.Owner).Update("id", 300)

var updatedPen3 Pen3
if err := DB.First(&updatedPen3, "\"id\" = ?", pen3.ID).Error; err != nil {
panic(fmt.Errorf("failed to find member, got error: %v", err))
} else if updatedPen3.OwnerID != 0 {
panic(fmt.Errorf("company id is not equal: expects: %v, got: %v", 0, updatedPen3.OwnerID))
}
}
Loading