Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: finish insert/point_get for local temporary table #26053

Merged
merged 9 commits into from
Jul 9, 2021
6 changes: 6 additions & 0 deletions executor/point_get.go
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,12 @@ func (e *PointGetExecutor) get(ctx context.Context, key kv.Key) ([]byte, error)
if e.tblInfo.TempTableType == model.TempTableGlobal {
return nil, nil
}

// Local temporary table always get snapshot value from session
if e.tblInfo.TempTableType == model.TempTableLocal {
return e.ctx.GetSessionVars().GetTemporaryTableSnapshotValue(ctx, key)
}

lock := e.tblInfo.Lock
if lock != nil && (lock.Tp == model.TableLockRead || lock.Tp == model.TableLockReadOnly) {
if e.ctx.GetSessionVars().EnablePointGetCache {
Expand Down
89 changes: 88 additions & 1 deletion session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package session

import (
"bytes"
"context"
"crypto/tls"
"encoding/json"
Expand Down Expand Up @@ -541,7 +542,93 @@ func (s *session) doCommit(ctx context.Context) error {
s.txn.SetOption(kv.KVFilter, temporaryTableKVFilter(tables))
}

return s.txn.Commit(tikvutil.SetSessionID(ctx, sessVars.ConnectionID))
return s.commitTxnWithTemporaryData(tikvutil.SetSessionID(ctx, sessVars.ConnectionID), &s.txn)
}

func (s *session) commitTxnWithTemporaryData(ctx context.Context, txn kv.Transaction) error {
txnTempTables := s.sessionVars.TxnCtx.TemporaryTables
if len(txnTempTables) == 0 {
return txn.Commit(ctx)
}

sessionData := s.sessionVars.TemporaryTableData
var stage kv.StagingHandle

defer func() {
// stage != kv.InvalidStagingHandle means error occurs, we need to cleanup sessionData
if stage != kv.InvalidStagingHandle {
sessionData.Cleanup(stage)
}
}()

for tblID, tbl := range txnTempTables {
if !tbl.GetModified() {
continue
}

if tbl.GetMeta().TempTableType != model.TempTableLocal {
continue
}

if sessionData == nil {
// Create this txn just for getting a MemBuffer. It's a little tricky
bufferTxn, err := s.store.BeginWithOption(tikv.DefaultStartTSOption().SetStartTS(0))
if err != nil {
return err
}

sessionData = bufferTxn.GetMemBuffer()
tiancaiamao marked this conversation as resolved.
Show resolved Hide resolved
}

if stage == kv.InvalidStagingHandle {
stage = sessionData.Staging()
}

tblPrefix := tablecodec.EncodeTablePrefix(tblID)
endKey := tablecodec.EncodeTablePrefix(tblID + 1)

txnMemBuffer := s.txn.GetMemBuffer()
iter, err := txnMemBuffer.Iter(tblPrefix, endKey)
if err != nil {
return err
}

for iter.Valid() {
key := iter.Key()
if !bytes.HasPrefix(key, tblPrefix) {
break
}

value := iter.Value()
if len(value) == 0 {
err = sessionData.Delete(key)
} else {
err = sessionData.Set(key, iter.Value())
}

if err != nil {
return err
tiancaiamao marked this conversation as resolved.
Show resolved Hide resolved
}

err = iter.Next()
if err != nil {
return err
}
}
}

err := txn.Commit(ctx)
if err != nil {
return err
}

if stage != kv.InvalidStagingHandle {
sessionData.Release(stage)
s.sessionVars.TemporaryTableData = sessionData
stage = kv.InvalidStagingHandle
}

return nil
}

type temporaryTableKVFilter map[int64]tableutil.TempTable
Expand Down
95 changes: 95 additions & 0 deletions session/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4859,3 +4859,98 @@ func (s *testSessionSuite) TestAuthPluginForUser(c *C) {
c.Assert(err, IsNil)
c.Assert(plugin, Equals, "")
}

func (s *testSessionSuite) TestLocalTemporaryTableInsert(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("set @@tidb_enable_noop_functions=1")
tk.MustExec("use test")
tk.MustExec("create temporary table tmp1 (id int primary key auto_increment, u int unique, v int)")
tk.MustExec("insert into tmp1 (u, v) values(11, 101)")
tk.MustExec("insert into tmp1 (u, v) values(12, 102)")
tk.MustExec("insert into tmp1 values(3, 13, 102)")

checkRecordOneTwoThreeAndNonExist := func() {
tk.MustQuery("select * from tmp1 where id=1").Check(testkit.Rows("1 11 101"))
tk.MustQuery("select * from tmp1 where id=2").Check(testkit.Rows("2 12 102"))
tk.MustQuery("select * from tmp1 where id=3").Check(testkit.Rows("3 13 102"))
tk.MustQuery("select * from tmp1 where id=99").Check(testkit.Rows())
}

// inserted records exist
checkRecordOneTwoThreeAndNonExist()

// insert dup records out txn must be error
_, err := tk.Exec("insert into tmp1 values(1, 999, 9999)")
c.Assert(kv.ErrKeyExists.Equal(err), IsTrue)
checkRecordOneTwoThreeAndNonExist()

_, err = tk.Exec("insert into tmp1 values(99, 11, 999)")
c.Assert(kv.ErrKeyExists.Equal(err), IsTrue)
checkRecordOneTwoThreeAndNonExist()

// insert dup records in txn must be error
tk.MustExec("begin")
_, err = tk.Exec("insert into tmp1 values(1, 999, 9999)")
c.Assert(kv.ErrKeyExists.Equal(err), IsTrue)
checkRecordOneTwoThreeAndNonExist()

_, err = tk.Exec("insert into tmp1 values(99, 11, 9999)")
c.Assert(kv.ErrKeyExists.Equal(err), IsTrue)
checkRecordOneTwoThreeAndNonExist()

tk.MustExec("insert into tmp1 values(4, 14, 104)")
tk.MustQuery("select * from tmp1 where id=4").Check(testkit.Rows("4 14 104"))

_, err = tk.Exec("insert into tmp1 values(4, 999, 9999)")
c.Assert(kv.ErrKeyExists.Equal(err), IsTrue)

_, err = tk.Exec("insert into tmp1 values(99, 14, 9999)")
c.Assert(kv.ErrKeyExists.Equal(err), IsTrue)

checkRecordOneTwoThreeAndNonExist()
tk.MustExec("commit")

// check committed insert works
checkRecordOneTwoThreeAndNonExist()
tk.MustQuery("select * from tmp1 where id=4").Check(testkit.Rows("4 14 104"))

// check rollback works
tk.MustExec("begin")
tk.MustExec("insert into tmp1 values(5, 15, 105)")
tk.MustQuery("select * from tmp1 where id=5").Check(testkit.Rows("5 15 105"))
tk.MustExec("rollback")
tk.MustQuery("select * from tmp1 where id=5").Check(testkit.Rows())
}

func (s *testSessionSuite) TestLocalTemporaryTablePointGet(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("set @@tidb_enable_noop_functions=1")
tk.MustExec("use test")
tk.MustExec("create temporary table tmp1 (id int primary key auto_increment, u int unique, v int)")
tk.MustExec("insert into tmp1 values(1, 11, 101)")
tk.MustExec("insert into tmp1 values(2, 12, 102)")

// check point get out transaction
tk.MustQuery("select * from tmp1 where id=1").Check(testkit.Rows("1 11 101"))
tk.MustQuery("select * from tmp1 where u=11").Check(testkit.Rows("1 11 101"))
tk.MustQuery("select * from tmp1 where id=2").Check(testkit.Rows("2 12 102"))
tk.MustQuery("select * from tmp1 where u=12").Check(testkit.Rows("2 12 102"))

// check point get in transaction
tk.MustExec("begin")
tk.MustQuery("select * from tmp1 where id=1").Check(testkit.Rows("1 11 101"))
tk.MustQuery("select * from tmp1 where u=11").Check(testkit.Rows("1 11 101"))
tk.MustQuery("select * from tmp1 where id=2").Check(testkit.Rows("2 12 102"))
tk.MustQuery("select * from tmp1 where u=12").Check(testkit.Rows("2 12 102"))
tk.MustExec("insert into tmp1 values(3, 13, 103)")
tk.MustQuery("select * from tmp1 where id=3").Check(testkit.Rows("3 13 103"))
tk.MustQuery("select * from tmp1 where u=13").Check(testkit.Rows("3 13 103"))
tk.MustExec("update tmp1 set v=999 where id=2")
tk.MustQuery("select * from tmp1 where id=2").Check(testkit.Rows("2 12 999"))
tk.MustExec("commit")

// check point get after transaction
tk.MustQuery("select * from tmp1 where id=3").Check(testkit.Rows("3 13 103"))
tk.MustQuery("select * from tmp1 where u=13").Check(testkit.Rows("3 13 103"))
tk.MustQuery("select * from tmp1 where id=2").Check(testkit.Rows("2 12 999"))
}
41 changes: 41 additions & 0 deletions sessionctx/variable/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ package variable

import (
"bytes"
"context"
"crypto/tls"
"encoding/binary"
"fmt"
Expand Down Expand Up @@ -859,6 +860,9 @@ type SessionVars struct {
// LocalTemporaryTables is *infoschema.LocalTemporaryTables, use interface to avoid circle dependency.
// It's nil if there is no local temporary table.
LocalTemporaryTables interface{}

// TemporaryTableData stores committed kv values for temporary table for current session.
TemporaryTableData kv.MemBuffer
}

// AllocMPPTaskID allocates task id for mpp tasks. It will reset the task id if the query's
Expand Down Expand Up @@ -2199,3 +2203,40 @@ func (s *SessionVars) GetSeekFactor(tbl *model.TableInfo) float64 {
}
return s.seekFactor
}

// GetTemporaryTableSnapshotValue get temporary table value from session
func (s *SessionVars) GetTemporaryTableSnapshotValue(ctx context.Context, key kv.Key) ([]byte, error) {
memData := s.TemporaryTableData
if memData == nil {
return nil, kv.ErrNotExist
}

v, err := memData.Get(ctx, key)
if err != nil {
return v, err
}

if len(v) == 0 {
return nil, kv.ErrNotExist
}

return v, nil
}

// GetTemporaryTableTxnValue returns a kv.Getter to fetch temporary table data in txn
func (s *SessionVars) GetTemporaryTableTxnValue(ctx context.Context, txn kv.Transaction, key kv.Key) ([]byte, error) {
v, err := txn.GetMemBuffer().Get(ctx, key)
if err == nil {
if len(v) == 0 {
return nil, kv.ErrNotExist
}

return v, nil
}

if !kv.IsErrNotFound(err) {
return v, err
}

return s.GetTemporaryTableSnapshotValue(ctx, key)
}
5 changes: 4 additions & 1 deletion table/tables/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,10 @@ func (c *index) Create(sctx sessionctx.Context, txn kv.Transaction, indexedValue
}

var value []byte
if sctx.GetSessionVars().LazyCheckKeyNotExists() {
if c.tblInfo.TempTableType != model.TempTableNone {
lcwangchao marked this conversation as resolved.
Show resolved Hide resolved
// Always check key for temporary table because it does not write to TiKV
value, err = sctx.GetSessionVars().GetTemporaryTableTxnValue(ctx, txn, key)
} else if sctx.GetSessionVars().LazyCheckKeyNotExists() {
value, err = txn.GetMemBuffer().Get(ctx, key)
} else {
value, err = txn.Get(ctx, key)
Expand Down
13 changes: 12 additions & 1 deletion table/tables/tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,10 @@ func (t *TableCommon) AddRecord(sctx sessionctx.Context, r []types.Datum, opts .
var setPresume bool
skipCheck := sctx.GetSessionVars().StmtCtx.BatchCheck
if (t.meta.IsCommonHandle || t.meta.PKIsHandle) && !skipCheck && !opt.SkipHandleCheck {
if sctx.GetSessionVars().LazyCheckKeyNotExists() {
if t.meta.TempTableType != model.TempTableNone {
// Always check key for temporary table because it does not write to TiKV
_, err = sctx.GetSessionVars().GetTemporaryTableTxnValue(ctx, txn, key)
} else if sctx.GetSessionVars().LazyCheckKeyNotExists() {
var v []byte
v, err = txn.GetMemBuffer().Get(ctx, key)
if err != nil {
Expand Down Expand Up @@ -1827,6 +1830,8 @@ type TemporaryTable struct {
autoIDAllocator autoid.Allocator
// Table size.
size int64

meta *model.TableInfo
}

// TempTableFromMeta builds a TempTable from model.TableInfo.
Expand All @@ -1835,6 +1840,7 @@ func TempTableFromMeta(tblInfo *model.TableInfo) tableutil.TempTable {
modified: false,
stats: statistics.PseudoTable(tblInfo),
autoIDAllocator: autoid.NewAllocatorFromTempTblInfo(tblInfo),
meta: tblInfo,
}
}

Expand Down Expand Up @@ -1867,3 +1873,8 @@ func (t *TemporaryTable) GetSize() int64 {
func (t *TemporaryTable) SetSize(v int64) {
t.size = v
}

// GetMeta gets the table meta.
func (t *TemporaryTable) GetMeta() *model.TableInfo {
return t.meta
}
2 changes: 2 additions & 0 deletions util/tableutil/tableutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ type TempTable interface {

GetSize() int64
SetSize(int64)

GetMeta() *model.TableInfo
}

// TempTableFromMeta builds a TempTable from *model.TableInfo.
Expand Down