diff --git a/store/driver/txn/batch_getter.go b/store/driver/txn/batch_getter.go index 272ae0fd9843e..19f588373d1cc 100644 --- a/store/driver/txn/batch_getter.go +++ b/store/driver/txn/batch_getter.go @@ -16,11 +16,60 @@ package txn import ( "context" + "unsafe" - "github.com/pingcap/errors" "github.com/pingcap/tidb/kv" + tikverr "github.com/tikv/client-go/v2/error" + "github.com/tikv/client-go/v2/txnkv/transaction" ) +// tikvBatchGetter is the BatchGetter struct for tikv +// In order to directly call NewBufferBatchGetter in client-go +// We need to implement the interface (transaction.BatchGetter) in client-go for tikvBatchGetter +type tikvBatchGetter struct { + tidbBatchGetter BatchGetter +} + +func (b tikvBatchGetter) BatchGet(ctx context.Context, keys [][]byte) (map[string][]byte, error) { + // toTiDBKeys + kvKeys := *(*[]kv.Key)(unsafe.Pointer(&keys)) + vals, err := b.tidbBatchGetter.BatchGet(ctx, kvKeys) + return vals, err +} + +// tikvBatchBufferGetter is the BatchBufferGetter struct for tikv +// In order to directly call NewBufferBatchGetter in client-go +// We need to implement the interface (transaction.BatchBufferGetter) in client-go for tikvBatchBufferGetter +type tikvBatchBufferGetter struct { + tidbMiddleCache Getter + tidbBuffer BatchBufferGetter +} + +func (b tikvBatchBufferGetter) Get(k []byte) ([]byte, error) { + // Get from buffer + val, err := b.tidbBuffer.Get(context.TODO(), k) + if err == nil || !kv.IsErrNotFound(err) || b.tidbMiddleCache == nil { + if kv.IsErrNotFound(err) { + err = tikverr.ErrNotExist + } + return val, err + } + // Get from middle cache + val, err = b.tidbMiddleCache.Get(context.TODO(), k) + if err == nil { + return val, err + } + // TiDB err NotExist to TiKV err NotExist + // The BatchGet method in client-go will call this method + // Therefore, the error needs to convert to TiKV's type, otherwise the error will not be handled properly in client-go + err = tikverr.ErrNotExist + return val, err +} + +func (b tikvBatchBufferGetter) Len() int { + return b.tidbBuffer.Len() +} + // BatchBufferGetter is the interface for BatchGet. type BatchBufferGetter interface { Len() int @@ -42,50 +91,20 @@ type Getter interface { // BufferBatchGetter is the type for BatchGet with MemBuffer. type BufferBatchGetter struct { - buffer BatchBufferGetter - middle Getter - snapshot BatchGetter + tikvBufferBatchGetter transaction.BufferBatchGetter } // NewBufferBatchGetter creates a new BufferBatchGetter. func NewBufferBatchGetter(buffer BatchBufferGetter, middleCache Getter, snapshot BatchGetter) *BufferBatchGetter { - return &BufferBatchGetter{buffer: buffer, middle: middleCache, snapshot: snapshot} + tikvBuffer := tikvBatchBufferGetter{tidbMiddleCache: middleCache, tidbBuffer: buffer} + tikvSnapshot := tikvBatchGetter{snapshot} + return &BufferBatchGetter{tikvBufferBatchGetter: *transaction.NewBufferBatchGetter(tikvBuffer, tikvSnapshot)} } // BatchGet implements the BatchGetter interface. func (b *BufferBatchGetter) BatchGet(ctx context.Context, keys []kv.Key) (map[string][]byte, error) { - if b.buffer.Len() == 0 { - return b.snapshot.BatchGet(ctx, keys) - } - bufferValues := make([][]byte, len(keys)) - shrinkKeys := make([]kv.Key, 0, len(keys)) - for i, key := range keys { - val, err := b.buffer.Get(ctx, key) - if err == nil { - bufferValues[i] = val - continue - } - if !kv.IsErrNotFound(err) { - return nil, errors.Trace(err) - } - if b.middle != nil { - val, err = b.middle.Get(ctx, key) - if err == nil { - bufferValues[i] = val - continue - } - } - shrinkKeys = append(shrinkKeys, key) - } - storageValues, err := b.snapshot.BatchGet(ctx, shrinkKeys) - if err != nil { - return nil, errors.Trace(err) - } - for i, key := range keys { - if len(bufferValues[i]) == 0 { - continue - } - storageValues[string(key)] = bufferValues[i] - } - return storageValues, nil + tikvKeys := toTiKVKeys(keys) + storageValues, err := b.tikvBufferBatchGetter.BatchGet(ctx, tikvKeys) + + return storageValues, err }