diff --git a/executor/batch_point_get.go b/executor/batch_point_get.go index 0d8c355319d8a..cf5d919d5c779 100644 --- a/executor/batch_point_get.go +++ b/executor/batch_point_get.go @@ -148,7 +148,9 @@ func (e *BatchPointGetExec) Open(context.Context) error { setResourceGroupTagForTxn(stmtCtx, snapshot) // Avoid network requests for the temporary table. if e.tblInfo.TempTableType == model.TempTableGlobal { - snapshot = globalTemporaryTableSnapshot{snapshot} + snapshot = temporaryTableSnapshot{snapshot, nil} + } else if e.tblInfo.TempTableType == model.TempTableLocal { + snapshot = temporaryTableSnapshot{snapshot, e.ctx.GetSessionVars().TemporaryTableData} } var batchGetter kv.BatchGetter = snapshot if txn.Valid() { @@ -166,14 +168,37 @@ func (e *BatchPointGetExec) Open(context.Context) error { return nil } -// Global temporary table would always be empty, so get the snapshot data of it is meanless. -// globalTemporaryTableSnapshot inherits kv.Snapshot and override the BatchGet methods to return empty. -type globalTemporaryTableSnapshot struct { +// Temporary table would always use memBuffer in session as snapshot. +// temporaryTableSnapshot inherits kv.Snapshot and override the BatchGet methods to return empty. +type temporaryTableSnapshot struct { kv.Snapshot + memBuffer kv.MemBuffer } -func (s globalTemporaryTableSnapshot) BatchGet(ctx context.Context, keys []kv.Key) (map[string][]byte, error) { - return make(map[string][]byte), nil +func (s temporaryTableSnapshot) BatchGet(ctx context.Context, keys []kv.Key) (map[string][]byte, error) { + values := make(map[string][]byte) + if s.memBuffer == nil { + return values, nil + } + + for _, key := range keys { + val, err := s.memBuffer.Get(ctx, key) + if err == kv.ErrNotExist { + continue + } + + if err != nil { + return nil, err + } + + if len(val) == 0 { + continue + } + + values[string(key)] = val + } + + return values, nil } // Close implements the Executor interface. diff --git a/session/session_test.go b/session/session_test.go index 54a6e1bae9bc4..6800ccd3277ed 100644 --- a/session/session_test.go +++ b/session/session_test.go @@ -4986,6 +4986,7 @@ func (s *testSessionSuite) TestLocalTemporaryTablePointGet(c *C) { tk.MustExec("create temporary table tmp1 (id int primary key auto_increment, u int unique, v int)") tk.MustExec("insert into tmp1 values(1, 11, 101)") tk.MustExec("insert into tmp1 values(2, 12, 102)") + tk.MustExec("insert into tmp1 values(4, 14, 104)") // check point get out transaction tk.MustQuery("select * from tmp1 where id=1").Check(testkit.Rows("1 11 101")) @@ -5004,10 +5005,55 @@ func (s *testSessionSuite) TestLocalTemporaryTablePointGet(c *C) { tk.MustQuery("select * from tmp1 where u=13").Check(testkit.Rows("3 13 103")) tk.MustExec("update tmp1 set v=999 where id=2") tk.MustQuery("select * from tmp1 where id=2").Check(testkit.Rows("2 12 999")) + tk.MustExec("delete from tmp1 where id=4") + tk.MustQuery("select * from tmp1 where id=4").Check(testkit.Rows()) + tk.MustQuery("select * from tmp1 where u=14").Check(testkit.Rows()) tk.MustExec("commit") // check point get after transaction tk.MustQuery("select * from tmp1 where id=3").Check(testkit.Rows("3 13 103")) tk.MustQuery("select * from tmp1 where u=13").Check(testkit.Rows("3 13 103")) tk.MustQuery("select * from tmp1 where id=2").Check(testkit.Rows("2 12 999")) + tk.MustQuery("select * from tmp1 where id=4").Check(testkit.Rows()) + tk.MustQuery("select * from tmp1 where u=14").Check(testkit.Rows()) +} + +func (s *testSessionSuite) TestLocalTemporaryTableBatchPointGet(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("set @@tidb_enable_noop_functions=1") + tk.MustExec("use test") + tk.MustExec("create temporary table tmp1 (id int primary key auto_increment, u int unique, v int)") + tk.MustExec("insert into tmp1 values(1, 11, 101)") + tk.MustExec("insert into tmp1 values(2, 12, 102)") + tk.MustExec("insert into tmp1 values(3, 13, 103)") + tk.MustExec("insert into tmp1 values(4, 14, 104)") + + // check point get out transaction + tk.MustQuery("select * from tmp1 where id in (1, 3)").Check(testkit.Rows("1 11 101", "3 13 103")) + tk.MustQuery("select * from tmp1 where u in (11, 13)").Check(testkit.Rows("1 11 101", "3 13 103")) + tk.MustQuery("select * from tmp1 where id in (1, 3, 5)").Check(testkit.Rows("1 11 101", "3 13 103")) + tk.MustQuery("select * from tmp1 where u in (11, 13, 15)").Check(testkit.Rows("1 11 101", "3 13 103")) + + // check point get in transaction + tk.MustExec("begin") + tk.MustQuery("select * from tmp1 where id in (1, 3)").Check(testkit.Rows("1 11 101", "3 13 103")) + tk.MustQuery("select * from tmp1 where u in (11, 13)").Check(testkit.Rows("1 11 101", "3 13 103")) + tk.MustQuery("select * from tmp1 where id in (1, 3, 5)").Check(testkit.Rows("1 11 101", "3 13 103")) + tk.MustQuery("select * from tmp1 where u in (11, 13, 15)").Check(testkit.Rows("1 11 101", "3 13 103")) + tk.MustExec("insert into tmp1 values(6, 16, 106)") + tk.MustQuery("select * from tmp1 where id in (1, 6)").Check(testkit.Rows("1 11 101", "6 16 106")) + tk.MustQuery("select * from tmp1 where u in (11, 16)").Check(testkit.Rows("1 11 101", "6 16 106")) + tk.MustExec("update tmp1 set v=999 where id=3") + tk.MustQuery("select * from tmp1 where id in (1, 3)").Check(testkit.Rows("1 11 101", "3 13 999")) + tk.MustQuery("select * from tmp1 where u in (11, 13)").Check(testkit.Rows("1 11 101", "3 13 999")) + tk.MustExec("delete from tmp1 where id=4") + tk.MustQuery("select * from tmp1 where id in (1, 4)").Check(testkit.Rows("1 11 101")) + tk.MustQuery("select * from tmp1 where u in (11, 14)").Check(testkit.Rows("1 11 101")) + tk.MustExec("commit") + + // check point get after transaction + tk.MustQuery("select * from tmp1 where id in (1, 3, 6)").Check(testkit.Rows("1 11 101", "3 13 999", "6 16 106")) + tk.MustQuery("select * from tmp1 where u in (11, 13, 16)").Check(testkit.Rows("1 11 101", "3 13 999", "6 16 106")) + tk.MustQuery("select * from tmp1 where id in (1, 4)").Check(testkit.Rows("1 11 101")) + tk.MustQuery("select * from tmp1 where u in (11, 14)").Check(testkit.Rows("1 11 101")) }