diff --git a/schema/relationship.go b/schema/relationship.go index 0535bba40..d7c98b21c 100644 --- a/schema/relationship.go +++ b/schema/relationship.go @@ -631,6 +631,7 @@ type Constraint struct { References []*Field OnDelete string OnUpdate string + Deferrable string } func (constraint *Constraint) GetName() string { return constraint.Name } @@ -645,6 +646,10 @@ func (constraint *Constraint) Build() (sql string, vars []interface{}) { sql += " ON UPDATE " + constraint.OnUpdate } + if constraint.Deferrable != "" { + sql += " DEFERRABLE " + constraint.Deferrable + } + foreignKeys := make([]interface{}, 0, len(constraint.ForeignKeys)) for _, field := range constraint.ForeignKeys { foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName}) @@ -700,10 +705,11 @@ func (rel *Relationship) ParseConstraint() *Constraint { } constraint := Constraint{ - Name: name, - Field: rel.Field, - OnUpdate: settings["ONUPDATE"], - OnDelete: settings["ONDELETE"], + Name: name, + Field: rel.Field, + OnUpdate: settings["ONUPDATE"], + OnDelete: settings["ONDELETE"], + Deferrable: strings.ToUpper(settings["DEFERRABLE"]), } for _, ref := range rel.References { diff --git a/schema/relationship_test.go b/schema/relationship_test.go index c706ac84c..9981b4e61 100644 --- a/schema/relationship_test.go +++ b/schema/relationship_test.go @@ -1,6 +1,7 @@ package schema_test import ( + "strings" "sync" "testing" "time" @@ -1041,3 +1042,32 @@ func TestDataRace(t *testing.T) { }() } } + +func TestDeferrable(t *testing.T) { + type Profile struct { + gorm.Model + Name string + User User + UserId uint + } + + type User struct { + gorm.Model + Profile Profile `gorm:"constraint:deferrable:initially deferred"` + } + + s, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{}) + if err != nil { + t.Fatalf("failed to parse schema") + } + + constraint := s.Relationships.Relations["Profile"].ParseConstraint() + if constraint.Deferrable != "INITIALLY DEFERRED" { + t.Fatalf("expected deferrable INITIALLY DEFERRED, got %v", constraint.Deferrable) + } + + sql, _ := constraint.Build() + if !strings.Contains(sql, "DEFERRABLE INITIALLY DEFERRED") { + t.Fatalf("expected deferrable in sql, got %v", sql) + } +}