Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

store/tikv: make snapshot thread safe (#17593) #18145

Merged
merged 11 commits into from
Jun 30, 2020
7 changes: 4 additions & 3 deletions executor/batch_point_get.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,8 @@ func (e *BatchPointGetExec) Next(ctx context.Context, req *chunk.Chunk) error {

func (e *BatchPointGetExec) initialize(ctx context.Context) error {
e.snapshotTS = e.startTS
txnCtx := e.ctx.GetSessionVars().TxnCtx
sessVars := e.ctx.GetSessionVars()
txnCtx := sessVars.TxnCtx
if e.lock {
e.snapshotTS = txnCtx.GetForUpdateTS()
}
Expand All @@ -122,7 +123,7 @@ func (e *BatchPointGetExec) initialize(ctx context.Context) error {
}
e.txn = txn
var snapshot kv.Snapshot
if txn.Valid() && txnCtx.StartTS == txnCtx.GetForUpdateTS() {
if sessVars.InTxn() && txnCtx.StartTS == txnCtx.GetForUpdateTS() {
// We can safely reuse the transaction snapshot if startTS is equal to forUpdateTS.
// The snapshot may contains cache that can reduce RPC call.
snapshot = txn.GetSnapshot()
Expand All @@ -137,7 +138,7 @@ func (e *BatchPointGetExec) initialize(ctx context.Context) error {
}
snapshot.SetOption(kv.TaskID, e.ctx.GetSessionVars().StmtCtx.TaskID)
var batchGetter kv.BatchGetter = snapshot
if txn.Valid() {
if sessVars.InTxn() {
batchGetter = kv.NewBufferBatchGetter(txn.GetMemBuffer(), &PessimisticLockCacheGetter{txnCtx: txnCtx}, snapshot)
}

Expand Down
7 changes: 4 additions & 3 deletions executor/point_get.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ func (e *PointGetExecutor) Next(ctx context.Context, req *chunk.Chunk) error {
return nil
}
e.done = true
txnCtx := e.ctx.GetSessionVars().TxnCtx
sessVars := e.ctx.GetSessionVars()
txnCtx := sessVars.TxnCtx
snapshotTS := e.startTS
if e.lock {
snapshotTS = txnCtx.GetForUpdateTS()
Expand All @@ -137,7 +138,7 @@ func (e *PointGetExecutor) Next(ctx context.Context, req *chunk.Chunk) error {
if err != nil {
return err
}
if e.txn.Valid() && txnCtx.StartTS == txnCtx.GetForUpdateTS() {
if sessVars.InTxn() && txnCtx.StartTS == txnCtx.GetForUpdateTS() {
e.snapshot = e.txn.GetSnapshot()
} else {
e.snapshot, err = e.ctx.GetStore().GetSnapshot(kv.Version{Ver: snapshotTS})
Expand Down Expand Up @@ -277,7 +278,7 @@ func (e *PointGetExecutor) lockKeyIfNeeded(ctx context.Context, key []byte) erro
// get will first try to get from txn buffer, then check the pessimistic lock cache,
// then the store. Kv.ErrNotExist will be returned if key is not found
func (e *PointGetExecutor) get(ctx context.Context, key kv.Key) ([]byte, error) {
if e.txn.Valid() && !e.txn.IsReadOnly() {
if e.ctx.GetSessionVars().InTxn() && !e.txn.IsReadOnly() {
// We cannot use txn.Get directly here because the snapshot in txn and the snapshot of e.snapshot may be
// different for pessimistic transaction.
val, err := e.txn.GetMemBuffer().Get(ctx, key)
Expand Down
30 changes: 21 additions & 9 deletions store/tikv/snapshot.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,10 @@ type tikvSnapshot struct {
// cached use len(value)=0 to represent a key-value entry doesn't exist (a reliable truth from TiKV).
// In the BatchGet API, it use no key-value entry to represent non-exist.
// It's OK as long as there are no zero-byte values in the protocol.
cached map[string][]byte
mu struct {
sync.RWMutex
cached map[string][]byte
}
}

// newTiKVSnapshot creates a snapshot of an TiKV store.
Expand All @@ -88,7 +91,9 @@ func newTiKVSnapshot(store *tikvStore, ver kv.Version, replicaReadSeed uint32) *
func (s *tikvSnapshot) setSnapshotTS(ts uint64) {
// Invalidate cache if the snapshotTS change!
s.version.Ver = ts
s.cached = nil
s.mu.Lock()
s.mu.cached = nil
s.mu.Unlock()
// And also the minCommitTS pushed information.
s.minCommitTSPushed.data = make(map[uint64]struct{}, 5)
}
Expand All @@ -98,10 +103,11 @@ func (s *tikvSnapshot) setSnapshotTS(ts uint64) {
func (s *tikvSnapshot) BatchGet(ctx context.Context, keys []kv.Key) (map[string][]byte, error) {
// Check the cached value first.
m := make(map[string][]byte)
if s.cached != nil {
s.mu.RLock()
if s.mu.cached != nil {
tmp := keys[:0]
for _, key := range keys {
if val, ok := s.cached[string(key)]; ok {
if val, ok := s.mu.cached[string(key)]; ok {
if len(val) > 0 {
m[string(key)] = val
}
Expand All @@ -111,6 +117,7 @@ func (s *tikvSnapshot) BatchGet(ctx context.Context, keys []kv.Key) (map[string]
}
keys = tmp
}
s.mu.RUnlock()

if len(keys) == 0 {
return m, nil
Expand Down Expand Up @@ -142,12 +149,14 @@ func (s *tikvSnapshot) BatchGet(ctx context.Context, keys []kv.Key) (map[string]
}

// Update the cache.
if s.cached == nil {
s.cached = make(map[string][]byte, len(m))
s.mu.Lock()
if s.mu.cached == nil {
s.mu.cached = make(map[string][]byte, len(m))
}
for _, key := range keys {
s.cached[string(key)] = m[string(key)]
s.mu.cached[string(key)] = m[string(key)]
}
s.mu.Unlock()

return m, nil
}
Expand Down Expand Up @@ -314,11 +323,14 @@ func (s *tikvSnapshot) Get(ctx context.Context, k kv.Key) ([]byte, error) {

func (s *tikvSnapshot) get(bo *Backoffer, k kv.Key) ([]byte, error) {
// Check the cached values first.
if s.cached != nil {
if value, ok := s.cached[string(k)]; ok {
s.mu.RLock()
if s.mu.cached != nil {
if value, ok := s.mu.cached[string(k)]; ok {
s.mu.RUnlock()
return value, nil
}
}
s.mu.RUnlock()

failpoint.Inject("snapshot-get-cache-fail", func(_ failpoint.Value) {
if bo.ctx.Value("TestSnapshotCache") != nil {
Expand Down
26 changes: 26 additions & 0 deletions store/tikv/snapshot_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package tikv
import (
"context"
"fmt"
"sync"
"time"

. "github.com/pingcap/check"
Expand Down Expand Up @@ -274,3 +275,28 @@ func (s *testSnapshotSuite) TestPointGetSkipTxnLock(c *C) {
c.Assert(value, BytesEquals, []byte("y"))
c.Assert(time.Since(start), Less, 500*time.Millisecond)
}

func (s *testSnapshotSuite) TestSnapshotThreadSafe(c *C) {
txn := s.beginTxn(c)
key := kv.Key("key_test_snapshot_threadsafe")
c.Assert(txn.Set(key, []byte("x")), IsNil)
ctx := context.Background()
err := txn.Commit(context.Background())
c.Assert(err, IsNil)

snapshot := newTiKVSnapshot(s.store, kv.MaxVersion, 0)
var wg sync.WaitGroup
wg.Add(5)
for i := 0; i < 5; i++ {
go func() {
for i := 0; i < 30; i++ {
_, err := snapshot.Get(ctx, key)
c.Assert(err, IsNil)
_, err = snapshot.BatchGet(ctx, []kv.Key{key, kv.Key("key_not_exist")})
c.Assert(err, IsNil)
}
wg.Done()
}()
}
wg.Wait()
}