diff --git a/ddl/index.go b/ddl/index.go index 6da9d079f0411..47d5286c84ffc 100644 --- a/ddl/index.go +++ b/ddl/index.go @@ -691,7 +691,7 @@ func (w *addIndexWorker) batchCheckUniqueKey(txn kv.Transaction, idxRecords []*i w.distinctCheckFlags = append(w.distinctCheckFlags, distinct) } - batchVals, err := kv.BatchGetValues(txn, w.batchCheckKeys) + batchVals, err := txn.BatchGet(w.batchCheckKeys) if err != nil { return errors.Trace(err) } diff --git a/executor/admin.go b/executor/admin.go index 4fdcc50a02bf5..554b7074e4aba 100644 --- a/executor/admin.go +++ b/executor/admin.go @@ -376,7 +376,7 @@ func (e *RecoverIndexExec) batchMarkDup(txn kv.Transaction, rows []recoverRows) distinctFlags[i] = distinct } - values, err := kv.BatchGetValues(txn, e.batchKeys) + values, err := txn.BatchGet(e.batchKeys) if err != nil { return errors.Trace(err) } @@ -500,7 +500,7 @@ func (e *CleanupIndexExec) batchGetRecord(txn kv.Transaction) (map[string][]byte for handle := range e.idxValues { e.batchKeys = append(e.batchKeys, e.table.RecordKey(handle)) } - values, err := kv.BatchGetValues(txn, e.batchKeys) + values, err := txn.BatchGet(e.batchKeys) if err != nil { return nil, errors.Trace(err) } diff --git a/executor/batch_checker.go b/executor/batch_checker.go index 6baaff8c2f235..d151c7dd85745 100644 --- a/executor/batch_checker.go +++ b/executor/batch_checker.go @@ -56,7 +56,7 @@ func (b *batchChecker) batchGetOldValues(ctx sessionctx.Context, batchKeys []kv. if err != nil { return errors.Trace(err) } - values, err := kv.BatchGetValues(txn, batchKeys) + values, err := txn.BatchGet(batchKeys) if err != nil { return errors.Trace(err) } @@ -213,7 +213,7 @@ func (b *batchChecker) batchGetInsertKeys(ctx sessionctx.Context, t table.Table, if err != nil { return errors.Trace(err) } - b.dupKVs, err = kv.BatchGetValues(txn, batchKeys) + b.dupKVs, err = txn.BatchGet(batchKeys) return errors.Trace(err) } diff --git a/executor/write_test.go b/executor/write_test.go index 05e3687732e03..00a0c26b0cd78 100644 --- a/executor/write_test.go +++ b/executor/write_test.go @@ -285,6 +285,17 @@ func (s *testSuite) TestInsert(c *C) { tk.MustQuery("select * from t").Check(testkit.Rows("1 1")) } +func (s *testSuite) TestMultiBatch(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("create table t0 (i int)") + tk.MustExec("insert into t0 values (1), (1)") + tk.MustExec("create table t (i int unique key)") + tk.MustExec("set @@tidb_dml_batch_size = 1") + tk.MustExec("insert ignore into t select * from t0") + tk.MustExec("admin check table t") +} + func (s *testSuite) TestInsertAutoInc(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") diff --git a/kv/fault_injection.go b/kv/fault_injection.go index 1a6ac23796186..e36755ff92905 100644 --- a/kv/fault_injection.go +++ b/kv/fault_injection.go @@ -98,6 +98,16 @@ func (t *InjectedTransaction) Get(k Key) ([]byte, error) { return t.Transaction.Get(k) } +// BatchGet returns an error if cfg.getError is set. +func (t *InjectedTransaction) BatchGet(keys []Key) (map[string][]byte, error) { + t.cfg.RLock() + defer t.cfg.RUnlock() + if t.cfg.getError != nil { + return nil, t.cfg.getError + } + return t.Transaction.BatchGet(keys) +} + // Commit returns an error if cfg.commitError is set. func (t *InjectedTransaction) Commit(ctx context.Context) error { t.cfg.RLock() @@ -108,14 +118,6 @@ func (t *InjectedTransaction) Commit(ctx context.Context) error { return t.Transaction.Commit(ctx) } -// GetSnapshot implements Transaction GetSnapshot method. -func (t *InjectedTransaction) GetSnapshot() Snapshot { - return &InjectedSnapshot{ - Snapshot: t.Transaction.GetSnapshot(), - cfg: t.cfg, - } -} - // InjectedSnapshot wraps a Snapshot with injections. type InjectedSnapshot struct { Snapshot diff --git a/kv/kv.go b/kv/kv.go index 01ed6a8556eee..c26b463645bc5 100644 --- a/kv/kv.go +++ b/kv/kv.go @@ -150,10 +150,10 @@ type Transaction interface { Valid() bool // GetMemBuffer return the MemBuffer binding to this transaction. GetMemBuffer() MemBuffer - // GetSnapshot returns the snapshot of this transaction. - GetSnapshot() Snapshot // SetVars sets variables to the transaction. SetVars(vars *Variables) + // BatchGet gets kv from the memory buffer of statement and transaction, and the kv storage. + BatchGet(keys []Key) (map[string][]byte, error) } // Client is used to send request to KV layer. diff --git a/kv/mock.go b/kv/mock.go index a84618d264957..5d02a02242320 100644 --- a/kv/mock.go +++ b/kv/mock.go @@ -68,6 +68,10 @@ func (t *mockTxn) Get(k Key) ([]byte, error) { return nil, nil } +func (t *mockTxn) BatchGet(keys []Key) (map[string][]byte, error) { + return nil, nil +} + func (t *mockTxn) Iter(k Key, upperBound Key) (Iterator, error) { return nil, nil } @@ -99,12 +103,6 @@ func (t *mockTxn) GetMemBuffer() MemBuffer { return nil } -func (t *mockTxn) GetSnapshot() Snapshot { - return &mockSnapshot{ - store: NewMemDbBuffer(DefaultTxnMembufCap), - } -} - func (t *mockTxn) SetCap(cap int) { } diff --git a/kv/txn.go b/kv/txn.go index c1a2b3d913299..6136150513f1f 100644 --- a/kv/txn.go +++ b/kv/txn.go @@ -89,37 +89,3 @@ func BackOff(attempts uint) int { time.Sleep(sleep) return int(sleep) } - -// BatchGetValues gets values in batch. -// The values from buffer in transaction and the values from the storage node are merged together. -func BatchGetValues(txn Transaction, keys []Key) (map[string][]byte, error) { - if txn.IsReadOnly() { - return txn.GetSnapshot().BatchGet(keys) - } - bufferValues := make([][]byte, len(keys)) - shrinkKeys := make([]Key, 0, len(keys)) - for i, key := range keys { - val, err := txn.GetMemBuffer().Get(key) - if IsErrNotFound(err) { - shrinkKeys = append(shrinkKeys, key) - continue - } - if err != nil { - return nil, errors.Trace(err) - } - if len(val) != 0 { - bufferValues[i] = val - } - } - storageValues, err := txn.GetSnapshot().BatchGet(shrinkKeys) - if err != nil { - return nil, errors.Trace(err) - } - for i, key := range keys { - if bufferValues[i] == nil { - continue - } - storageValues[string(key)] = bufferValues[i] - } - return storageValues, nil -} diff --git a/session/txn.go b/session/txn.go index 935e80c4a6d03..3b6810b7360bc 100644 --- a/session/txn.go +++ b/session/txn.go @@ -212,6 +212,36 @@ func (st *TxnState) Get(k kv.Key) ([]byte, error) { return val, nil } +// BatchGet overrides the Transaction interface. +func (st *TxnState) BatchGet(keys []kv.Key) (map[string][]byte, error) { + bufferValues := make([][]byte, len(keys)) + shrinkKeys := make([]kv.Key, 0, len(keys)) + for i, key := range keys { + val, err := st.buf.Get(key) + if kv.IsErrNotFound(err) { + shrinkKeys = append(shrinkKeys, key) + continue + } + if err != nil { + return nil, errors.Trace(err) + } + if len(val) != 0 { + bufferValues[i] = val + } + } + storageValues, err := st.Transaction.BatchGet(shrinkKeys) + if err != nil { + return nil, errors.Trace(err) + } + for i, key := range keys { + if bufferValues[i] == nil { + continue + } + storageValues[string(key)] = bufferValues[i] + } + return storageValues, nil +} + // Set overrides the Transaction interface. func (st *TxnState) Set(k kv.Key, v []byte) error { return st.buf.Set(k, v) diff --git a/store/tikv/txn.go b/store/tikv/txn.go index 9f55fc84f7f2a..6a1d39e5d72f2 100644 --- a/store/tikv/txn.go +++ b/store/tikv/txn.go @@ -108,6 +108,38 @@ func (txn *tikvTxn) Get(k kv.Key) ([]byte, error) { return ret, nil } +func (txn *tikvTxn) BatchGet(keys []kv.Key) (map[string][]byte, error) { + if txn.IsReadOnly() { + return txn.snapshot.BatchGet(keys) + } + bufferValues := make([][]byte, len(keys)) + shrinkKeys := make([]kv.Key, 0, len(keys)) + for i, key := range keys { + val, err := txn.GetMemBuffer().Get(key) + if kv.IsErrNotFound(err) { + shrinkKeys = append(shrinkKeys, key) + continue + } + if err != nil { + return nil, errors.Trace(err) + } + if len(val) != 0 { + bufferValues[i] = val + } + } + storageValues, err := txn.snapshot.BatchGet(shrinkKeys) + if err != nil { + return nil, errors.Trace(err) + } + for i, key := range keys { + if bufferValues[i] == nil { + continue + } + storageValues[string(key)] = bufferValues[i] + } + return storageValues, nil +} + func (txn *tikvTxn) Set(k kv.Key, v []byte) error { txn.setCnt++ @@ -287,7 +319,3 @@ func (txn *tikvTxn) Size() int { func (txn *tikvTxn) GetMemBuffer() kv.MemBuffer { return txn.us.GetMemBuffer() } - -func (txn *tikvTxn) GetSnapshot() kv.Snapshot { - return txn.snapshot -}