diff --git a/executor/batch_point_get.go b/executor/batch_point_get.go
index 83423669380af..122a858e9631d 100644
--- a/executor/batch_point_get.go
+++ b/executor/batch_point_get.go
@@ -112,7 +112,8 @@ func (e *BatchPointGetExec) Next(ctx context.Context, req *chunk.Chunk) error {
 
 func (e *BatchPointGetExec) initialize(ctx context.Context) error {
 	e.snapshotTS = e.startTS
-	txnCtx := e.ctx.GetSessionVars().TxnCtx
+	sessVars := e.ctx.GetSessionVars()
+	txnCtx := sessVars.TxnCtx
 	if e.lock {
 		e.snapshotTS = txnCtx.GetForUpdateTS()
 	}
@@ -122,7 +123,7 @@ func (e *BatchPointGetExec) initialize(ctx context.Context) error {
 	}
 	e.txn = txn
 	var snapshot kv.Snapshot
-	if txn.Valid() && txnCtx.StartTS == txnCtx.GetForUpdateTS() {
+	if sessVars.InTxn() && txnCtx.StartTS == txnCtx.GetForUpdateTS() {
 		// We can safely reuse the transaction snapshot if startTS is equal to forUpdateTS.
 		// The snapshot may contains cache that can reduce RPC call.
 		snapshot = txn.GetSnapshot()
@@ -137,7 +138,7 @@ func (e *BatchPointGetExec) initialize(ctx context.Context) error {
 	}
 	snapshot.SetOption(kv.TaskID, e.ctx.GetSessionVars().StmtCtx.TaskID)
 	var batchGetter kv.BatchGetter = snapshot
-	if txn.Valid() {
+	if sessVars.InTxn() {
 		batchGetter = kv.NewBufferBatchGetter(txn.GetMemBuffer(), &PessimisticLockCacheGetter{txnCtx: txnCtx}, snapshot)
 	}
 
diff --git a/executor/point_get.go b/executor/point_get.go
index aa0e6a2b37552..a12b2633ee3aa 100644
--- a/executor/point_get.go
+++ b/executor/point_get.go
@@ -127,7 +127,8 @@ func (e *PointGetExecutor) Next(ctx context.Context, req *chunk.Chunk) error {
 		return nil
 	}
 	e.done = true
-	txnCtx := e.ctx.GetSessionVars().TxnCtx
+	sessVars := e.ctx.GetSessionVars()
+	txnCtx := sessVars.TxnCtx
 	snapshotTS := e.startTS
 	if e.lock {
 		snapshotTS = txnCtx.GetForUpdateTS()
@@ -137,7 +138,7 @@ func (e *PointGetExecutor) Next(ctx context.Context, req *chunk.Chunk) error {
 	if err != nil {
 		return err
 	}
-	if e.txn.Valid() && txnCtx.StartTS == txnCtx.GetForUpdateTS() {
+	if sessVars.InTxn() && txnCtx.StartTS == txnCtx.GetForUpdateTS() {
 		e.snapshot = e.txn.GetSnapshot()
 	} else {
 		e.snapshot, err = e.ctx.GetStore().GetSnapshot(kv.Version{Ver: snapshotTS})
@@ -277,7 +278,7 @@ func (e *PointGetExecutor) lockKeyIfNeeded(ctx context.Context, key []byte) erro
 // get will first try to get from txn buffer, then check the pessimistic lock cache,
 // then the store. Kv.ErrNotExist will be returned if key is not found
 func (e *PointGetExecutor) get(ctx context.Context, key kv.Key) ([]byte, error) {
-	if e.txn.Valid() && !e.txn.IsReadOnly() {
+	if e.ctx.GetSessionVars().InTxn() && !e.txn.IsReadOnly() {
 		// We cannot use txn.Get directly here because the snapshot in txn and the snapshot of e.snapshot may be
 		// different for pessimistic transaction.
 		val, err := e.txn.GetMemBuffer().Get(ctx, key)
diff --git a/store/tikv/snapshot.go b/store/tikv/snapshot.go
index 2376c82adde5a..b17d8eb72b21d 100644
--- a/store/tikv/snapshot.go
+++ b/store/tikv/snapshot.go
@@ -68,7 +68,10 @@ type tikvSnapshot struct {
 	// cached use len(value)=0 to represent a key-value entry doesn't exist (a reliable truth from TiKV).
 	// In the BatchGet API, it use no key-value entry to represent non-exist.
 	// It's OK as long as there are no zero-byte values in the protocol.
-	cached map[string][]byte
+	mu struct {
+		sync.RWMutex
+		cached map[string][]byte
+	}
 }
 
 // newTiKVSnapshot creates a snapshot of an TiKV store.
@@ -88,7 +91,9 @@ func newTiKVSnapshot(store *tikvStore, ver kv.Version, replicaReadSeed uint32) *
 func (s *tikvSnapshot) setSnapshotTS(ts uint64) {
 	// Invalidate cache if the snapshotTS change!
 	s.version.Ver = ts
-	s.cached = nil
+	s.mu.Lock()
+	s.mu.cached = nil
+	s.mu.Unlock()
 	// And also the minCommitTS pushed information.
 	s.minCommitTSPushed.data = make(map[uint64]struct{}, 5)
 }
@@ -98,10 +103,11 @@ func (s *tikvSnapshot) setSnapshotTS(ts uint64) {
 func (s *tikvSnapshot) BatchGet(ctx context.Context, keys []kv.Key) (map[string][]byte, error) {
 	// Check the cached value first.
 	m := make(map[string][]byte)
-	if s.cached != nil {
+	s.mu.RLock()
+	if s.mu.cached != nil {
 		tmp := keys[:0]
 		for _, key := range keys {
-			if val, ok := s.cached[string(key)]; ok {
+			if val, ok := s.mu.cached[string(key)]; ok {
 				if len(val) > 0 {
 					m[string(key)] = val
 				}
@@ -111,6 +117,7 @@ func (s *tikvSnapshot) BatchGet(ctx context.Context, keys []kv.Key) (map[string]
 		}
 		keys = tmp
 	}
+	s.mu.RUnlock()
 
 	if len(keys) == 0 {
 		return m, nil
@@ -142,12 +149,14 @@ func (s *tikvSnapshot) BatchGet(ctx context.Context, keys []kv.Key) (map[string]
 	}
 
 	// Update the cache.
-	if s.cached == nil {
-		s.cached = make(map[string][]byte, len(m))
+	s.mu.Lock()
+	if s.mu.cached == nil {
+		s.mu.cached = make(map[string][]byte, len(m))
 	}
 	for _, key := range keys {
-		s.cached[string(key)] = m[string(key)]
+		s.mu.cached[string(key)] = m[string(key)]
 	}
+	s.mu.Unlock()
 
 	return m, nil
 }
@@ -314,11 +323,14 @@ func (s *tikvSnapshot) Get(ctx context.Context, k kv.Key) ([]byte, error) {
 
 func (s *tikvSnapshot) get(bo *Backoffer, k kv.Key) ([]byte, error) {
 	// Check the cached values first.
-	if s.cached != nil {
-		if value, ok := s.cached[string(k)]; ok {
+	s.mu.RLock()
+	if s.mu.cached != nil {
+		if value, ok := s.mu.cached[string(k)]; ok {
+			s.mu.RUnlock()
 			return value, nil
 		}
 	}
+	s.mu.RUnlock()
 
 	failpoint.Inject("snapshot-get-cache-fail", func(_ failpoint.Value) {
 		if bo.ctx.Value("TestSnapshotCache") != nil {
diff --git a/store/tikv/snapshot_test.go b/store/tikv/snapshot_test.go
index 024617ce7802a..38a5a8101a2cd 100644
--- a/store/tikv/snapshot_test.go
+++ b/store/tikv/snapshot_test.go
@@ -16,6 +16,7 @@ package tikv
 import (
 	"context"
 	"fmt"
+	"sync"
 	"time"
 
 	. "github.com/pingcap/check"
@@ -274,3 +275,28 @@ func (s *testSnapshotSuite) TestPointGetSkipTxnLock(c *C) {
 	c.Assert(value, BytesEquals, []byte("y"))
 	c.Assert(time.Since(start), Less, 500*time.Millisecond)
 }
+
+func (s *testSnapshotSuite) TestSnapshotThreadSafe(c *C) {
+	txn := s.beginTxn(c)
+	key := kv.Key("key_test_snapshot_threadsafe")
+	c.Assert(txn.Set(key, []byte("x")), IsNil)
+	ctx := context.Background()
+	err := txn.Commit(context.Background())
+	c.Assert(err, IsNil)
+
+	snapshot := newTiKVSnapshot(s.store, kv.MaxVersion, 0)
+	var wg sync.WaitGroup
+	wg.Add(5)
+	for i := 0; i < 5; i++ {
+		go func() {
+			for i := 0; i < 30; i++ {
+				_, err := snapshot.Get(ctx, key)
+				c.Assert(err, IsNil)
+				_, err = snapshot.BatchGet(ctx, []kv.Key{key, kv.Key("key_not_exist")})
+				c.Assert(err, IsNil)
+			}
+			wg.Done()
+		}()
+	}
+	wg.Wait()
+}
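
Note (illustrative, not part of the patch): the snapshot.go change guards the point-get cache with a sync.RWMutex embedded in an anonymous "mu" struct, taking the read lock on cache hits and the write lock when invalidating or updating the map. A minimal standalone sketch of the same pattern follows; the type and method names (snapshotCache, get, set) are hypothetical and exist only for this example.

// snapshotcache_sketch.go - simplified model of the RWMutex-guarded cache pattern.
package main

import (
	"fmt"
	"sync"
)

type snapshotCache struct {
	// mu bundles the lock with the data it protects, mirroring tikvSnapshot.mu.
	mu struct {
		sync.RWMutex
		cached map[string][]byte
	}
}

// get takes the read lock and releases it before returning, mirroring the
// read path added to tikvSnapshot.get.
func (c *snapshotCache) get(key string) ([]byte, bool) {
	c.mu.RLock()
	if c.mu.cached != nil {
		if v, ok := c.mu.cached[key]; ok {
			c.mu.RUnlock()
			return v, true
		}
	}
	c.mu.RUnlock()
	return nil, false
}

// set takes the write lock and lazily allocates the map, mirroring the
// cache-update path in BatchGet.
func (c *snapshotCache) set(key string, val []byte) {
	c.mu.Lock()
	if c.mu.cached == nil {
		c.mu.cached = make(map[string][]byte)
	}
	c.mu.cached[key] = val
	c.mu.Unlock()
}

func main() {
	var c snapshotCache
	var wg sync.WaitGroup
	// Concurrent readers and writers are safe, which is what the new
	// TestSnapshotThreadSafe exercises against the real snapshot.
	for i := 0; i < 5; i++ {
		wg.Add(1)
		go func(n int) {
			defer wg.Done()
			k := fmt.Sprintf("k%d", n)
			c.set(k, []byte("v"))
			if v, ok := c.get(k); ok {
				fmt.Println(k, string(v))
			}
		}(i)
	}
	wg.Wait()
}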