diff --git a/executor/builder.go b/executor/builder.go index 4df7a5f123a1d..f239923202fe2 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -888,9 +888,12 @@ func (b *executorBuilder) buildInsert(v *plannercore.Insert) Executor { b.err = err return nil } - ivs.fkChecks, err = buildFKCheckExecs(b.ctx, ivs.Table, v.FKChecks) - if err != nil { - b.err = err + ivs.fkChecks, b.err = buildFKCheckExecs(b.ctx, ivs.Table, v.FKChecks) + if b.err != nil { + return nil + } + ivs.fkCascades, b.err = b.buildFKCascadeExecs(ivs.Table, v.FKCascades) + if b.err != nil { return nil } diff --git a/executor/fktest/foreign_key_test.go b/executor/fktest/foreign_key_test.go index d1fa78cf5a988..79b36ff7f586f 100644 --- a/executor/fktest/foreign_key_test.go +++ b/executor/fktest/foreign_key_test.go @@ -1999,3 +1999,48 @@ func TestForeignKeyCascadeOnDiffColumnType(t *testing.T) { tk.MustQuery("select cast(id as unsigned) from t1;").Check(testkit.Rows("6")) tk.MustQuery("select id, cast(b as unsigned) from t2;").Check(testkit.Rows("2 6")) } + +func TestForeignKeyOnInsertOnDuplicateUpdate(t *testing.T) { + store := testkit.CreateMockStore(t) + tk := testkit.NewTestKit(t, store) + tk.MustExec("set @@global.tidb_enable_foreign_key=1") + tk.MustExec("set @@foreign_key_checks=1") + tk.MustExec("use test") + tk.MustExec("create table t1 (id int key, name varchar(10));") + tk.MustExec("create table t2 (id int key, pid int, foreign key fk(pid) references t1(id) ON UPDATE CASCADE ON DELETE CASCADE);") + tk.MustExec("insert into t1 values (1, 'a'), (2, 'b')") + tk.MustExec("insert into t2 values (1, 1), (2, 2), (3, 1), (4, 2), (5, null)") + tk.MustExec("insert into t1 values (1, 'aa') on duplicate key update name = 'aa'") + tk.MustQuery("select * from t1 order by id").Check(testkit.Rows("1 aa", "2 b")) + tk.MustQuery("select * from t2 order by id").Check(testkit.Rows("1 1", "2 2", "3 1", "4 2", "5 ")) + tk.MustExec("insert into t1 values (1, 'aaa') on duplicate key update id = 10") + tk.MustQuery("select * from t1 order by id").Check(testkit.Rows("2 b", "10 aa")) + tk.MustQuery("select * from t2 order by id").Check(testkit.Rows("1 10", "2 2", "3 10", "4 2", "5 ")) + // Test in transaction. + tk.MustExec("begin") + tk.MustExec("insert into t1 values (3, 'c')") + tk.MustExec("insert into t2 values (6, 3)") + tk.MustExec("insert into t1 values (2, 'bb'), (3, 'cc') on duplicate key update id =id*10") + tk.MustQuery("select * from t1 order by id").Check(testkit.Rows("10 aa", "20 b", "30 c")) + tk.MustQuery("select * from t2 order by id").Check(testkit.Rows("1 10", "2 20", "3 10", "4 20", "5 ", "6 30")) + tk.MustExec("commit") + tk.MustQuery("select * from t1 order by id").Check(testkit.Rows("10 aa", "20 b", "30 c")) + tk.MustQuery("select * from t2 order by id").Check(testkit.Rows("1 10", "2 20", "3 10", "4 20", "5 ", "6 30")) + tk.MustExec("delete from t1") + tk.MustQuery("select * from t2").Check(testkit.Rows("5 ")) + // Test for cascade update failed. + tk.MustExec("drop table t1, t2") + tk.MustExec("create table t1 (id int key)") + tk.MustExec("create table t2 (id int key, foreign key (id) references t1 (id) on update cascade)") + tk.MustExec("create table t3 (id int key, foreign key (id) references t2(id))") + tk.MustExec("begin") + tk.MustExec("insert into t1 values (1)") + tk.MustExec("insert into t2 values (1)") + tk.MustExec("insert into t3 values (1)") + tk.MustGetDBError("insert into t1 values (1) on duplicate key update id = 2", plannercore.ErrRowIsReferenced2) + require.Equal(t, 0, len(tk.Session().GetSessionVars().TxnCtx.Savepoints)) + tk.MustExec("commit") + tk.MustQuery("select * from t1").Check(testkit.Rows("1")) + tk.MustQuery("select * from t2").Check(testkit.Rows("1")) + tk.MustQuery("select * from t3").Check(testkit.Rows("1")) +} diff --git a/executor/insert.go b/executor/insert.go index 931b8bbf6b480..4f78d0b2809ac 100644 --- a/executor/insert.go +++ b/executor/insert.go @@ -456,10 +456,10 @@ func (e *InsertExec) GetFKChecks() []*FKCheckExec { // GetFKCascades implements WithForeignKeyTrigger interface. func (e *InsertExec) GetFKCascades() []*FKCascadeExec { - return nil + return e.fkCascades } // HasFKCascades implements WithForeignKeyTrigger interface. func (e *InsertExec) HasFKCascades() bool { - return false + return len(e.fkCascades) > 0 } diff --git a/planner/core/common_plans.go b/planner/core/common_plans.go index ea99495705650..39f12a5a1daec 100644 --- a/planner/core/common_plans.go +++ b/planner/core/common_plans.go @@ -362,7 +362,8 @@ type Insert struct { RowLen int - FKChecks []*FKCheck + FKChecks []*FKCheck + FKCascades []*FKCascade } // MemoryUsage return the memory usage of Insert diff --git a/planner/core/foreign_key.go b/planner/core/foreign_key.go index 8c5b03384ae6e..baa271b903443 100644 --- a/planner/core/foreign_key.go +++ b/planner/core/foreign_key.go @@ -82,21 +82,25 @@ func (f *FKCascade) MemoryUsage() (sum int64) { return } -func (p *Insert) buildOnInsertFKChecks(ctx sessionctx.Context, is infoschema.InfoSchema, dbName string) error { +func (p *Insert) buildOnInsertFKTriggers(ctx sessionctx.Context, is infoschema.InfoSchema, dbName string) error { if !ctx.GetSessionVars().ForeignKeyChecks { return nil } tblInfo := p.Table.Meta() fkChecks := make([]*FKCheck, 0, len(tblInfo.ForeignKeys)) + fkCascades := make([]*FKCascade, 0, len(tblInfo.ForeignKeys)) updateCols := p.buildOnDuplicateUpdateColumns() if len(updateCols) > 0 { - referredFKChecks, _, err := buildOnUpdateReferredFKTriggers(is, dbName, tblInfo, updateCols) + referredFKChecks, referredFKCascades, err := buildOnUpdateReferredFKTriggers(is, dbName, tblInfo, updateCols) if err != nil { return err } if len(referredFKChecks) > 0 { fkChecks = append(fkChecks, referredFKChecks...) } + if len(referredFKCascades) > 0 { + fkCascades = append(fkCascades, referredFKCascades...) + } } for _, fk := range tblInfo.ForeignKeys { if fk.Version < 1 { @@ -112,6 +116,7 @@ func (p *Insert) buildOnInsertFKChecks(ctx sessionctx.Context, is infoschema.Inf } } p.FKChecks = fkChecks + p.FKCascades = fkCascades return nil } diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index b9372010b9e2d..a07f3c13d6717 100644 --- a/planner/core/planbuilder.go +++ b/planner/core/planbuilder.go @@ -3631,7 +3631,7 @@ func (b *PlanBuilder) buildInsert(ctx context.Context, insert *ast.InsertStmt) ( if err != nil { return nil, err } - err = insertPlan.buildOnInsertFKChecks(b.ctx, b.is, tn.DBInfo.Name.L) + err = insertPlan.buildOnInsertFKTriggers(b.ctx, b.is, tn.DBInfo.Name.L) return insertPlan, err }