diff --git a/session/session.go b/session/session.go index 84ff7e4eec424..e1581d5ed4074 100644 --- a/session/session.go +++ b/session/session.go @@ -473,6 +473,9 @@ func (s *session) doCommit(ctx context.Context) error { if err != nil { return err } + if err = s.removeTempTableFromBuffer(); err != nil { + return err + } // mockCommitError and mockGetTSErrorInRetry use to test PR #8743. failpoint.Inject("mockCommitError", func(val failpoint.Value) { @@ -526,29 +529,40 @@ func (s *session) doCommit(ctx context.Context) error { s.GetSessionVars().TxnCtx.IsExplicit && s.GetSessionVars().GuaranteeLinearizability) } - // Filter out the temporary table key-values. - if tables := s.sessionVars.TxnCtx.GlobalTemporaryTables; tables != nil { - memBuffer := s.txn.GetMemBuffer() - for tid := range tables { - seekKey := tablecodec.EncodeTablePrefix(tid) - endKey := tablecodec.EncodeTablePrefix(tid + 1) - iter, err := memBuffer.Iter(seekKey, endKey) - if err != nil { + return s.txn.Commit(tikvutil.SetSessionID(ctx, s.GetSessionVars().ConnectionID)) +} + +// removeTempTableFromBuffer filters out the temporary table key-values. +func (s *session) removeTempTableFromBuffer() error { + tables := s.GetSessionVars().TxnCtx.GlobalTemporaryTables + if len(tables) == 0 { + return nil + } + memBuffer := s.txn.GetMemBuffer() + // Reset and new an empty stage buffer. + defer func() { + s.txn.cleanup() + }() + for tid := range tables { + seekKey := tablecodec.EncodeTablePrefix(tid) + endKey := tablecodec.EncodeTablePrefix(tid + 1) + iter, err := memBuffer.Iter(seekKey, endKey) + if err != nil { + return err + } + for iter.Valid() && iter.Key().HasPrefix(seekKey) { + if err = memBuffer.Delete(iter.Key()); err != nil { return err } - for iter.Valid() && iter.Key().HasPrefix(seekKey) { - if err = memBuffer.Delete(iter.Key()); err != nil { - return errors.Trace(err) - } - s.txn.UpdateEntriesCountAndSize() - if err = iter.Next(); err != nil { - return errors.Trace(err) - } + s.txn.UpdateEntriesCountAndSize() + if err = iter.Next(); err != nil { + return err } } } - - return s.txn.Commit(tikvutil.SetSessionID(ctx, s.GetSessionVars().ConnectionID)) + // Flush to the root membuffer. + s.txn.flushStmtBuf() + return nil } // errIsNoisy is used to filter DUPLCATE KEY errors. diff --git a/session/session_test.go b/session/session_test.go index a897cb7db07f3..b7cfecc9c5f4c 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -789,6 +789,49 @@ func (s *testSessionSuite) TestRetryUnion(c *C) { c.Assert(err, ErrorMatches, ".*can not retry select for update statement") } +func (s *testSessionSuite) TestRetryGlobalTempTable(c *C) { + tk := testkit.NewTestKitWithInit(c, s.store) + tk.MustExec("drop table if exists normal_table") + tk.MustExec("create table normal_table(a int primary key, b int)") + defer tk.MustExec("drop table if exists normal_table") + tk.MustExec("drop table if exists temp_table") + tk.MustExec("create global temporary table temp_table(a int primary key, b int) on commit delete rows") + defer tk.MustExec("drop table if exists temp_table") + + // insert select + tk.MustExec("set tidb_disable_txn_auto_retry = 0") + tk.MustExec("insert normal_table value(100, 100)") + tk.MustExec("set @@autocommit = 0") + // used to make conflicts + tk.MustExec("update normal_table set b=b+1 where a=100") + tk.MustExec("insert temp_table value(1, 1)") + tk.MustExec("insert normal_table select * from temp_table") + c.Assert(session.GetHistory(tk.Se).Count(), Equals, 3) + + // try to conflict with tk + tk1 := testkit.NewTestKitWithInit(c, s.store) + tk1.MustExec("update normal_table set b=b+1 where a=100") + + // It will retry internally. + tk.MustExec("commit") + tk.MustQuery("select a, b from normal_table order by a").Check(testkit.Rows("1 1", "100 102")) + tk.MustQuery("select a, b from temp_table order by a").Check(testkit.Rows()) + + // update multi-tables + tk.MustExec("update normal_table set b=b+1 where a=100") + tk.MustExec("insert temp_table value(1, 2)") + // before update: normal_table=(1 1) (100 102), temp_table=(1 2) + tk.MustExec("update normal_table, temp_table set normal_table.b=temp_table.b where normal_table.a=temp_table.a") + c.Assert(session.GetHistory(tk.Se).Count(), Equals, 3) + + // try to conflict with tk + tk1.MustExec("update normal_table set b=b+1 where a=100") + + // It will retry internally. + tk.MustExec("commit") + tk.MustQuery("select a, b from normal_table order by a").Check(testkit.Rows("1 2", "100 104")) +} + func (s *testSessionSuite) TestRetryShow(c *C) { tk := testkit.NewTestKitWithInit(c, s.store) tk.MustExec("set @@autocommit = 0")