From 41d133b6b615c83ca62fbbc98e0c0fe42b9d0ea0 Mon Sep 17 00:00:00 2001 From: you06 Date: Fri, 23 Aug 2024 12:46:43 +0900 Subject: [PATCH] membuffer: refactor the memdb to support multi implementations (#1426) ref pingcap/tidb#55287 Signed-off-by: you06 --- .../{memdb_arena.go => arena/arena.go} | 277 +++--- internal/unionstore/arena/arena_test.go | 79 ++ internal/unionstore/art/art.go | 176 ++++ internal/unionstore/art/art_arena.go | 35 + internal/unionstore/art/art_iterator.go | 31 + internal/unionstore/art/art_node.go | 58 ++ internal/unionstore/art/art_snapshot.go | 43 + internal/unionstore/memdb.go | 912 +---------------- internal/unionstore/memdb_art.go | 169 ++++ internal/unionstore/memdb_bench_test.go | 28 +- internal/unionstore/memdb_norace_test.go | 4 +- internal/unionstore/memdb_rbt.go | 176 ++++ internal/unionstore/memdb_test.go | 222 ++--- internal/unionstore/mock.go | 8 +- internal/unionstore/pipelined_memdb.go | 30 +- internal/unionstore/pipelined_memdb_test.go | 2 +- internal/unionstore/rbt/rbt.go | 926 ++++++++++++++++++ internal/unionstore/rbt/rbt_arena.go | 101 ++ .../rbt_iterator.go} | 74 +- .../rbt_snapshot.go} | 55 +- internal/unionstore/rbt/rbt_test.go | 170 ++++ internal/unionstore/union_store.go | 51 +- internal/unionstore/union_store_test.go | 8 +- txnkv/transaction/txn.go | 2 +- 24 files changed, 2307 insertions(+), 1330 deletions(-) rename internal/unionstore/{memdb_arena.go => arena/arena.go} (54%) create mode 100644 internal/unionstore/arena/arena_test.go create mode 100644 internal/unionstore/art/art.go create mode 100644 internal/unionstore/art/art_arena.go create mode 100644 internal/unionstore/art/art_iterator.go create mode 100644 internal/unionstore/art/art_node.go create mode 100644 internal/unionstore/art/art_snapshot.go create mode 100644 internal/unionstore/memdb_art.go create mode 100644 internal/unionstore/memdb_rbt.go create mode 100644 internal/unionstore/rbt/rbt.go create mode 100644 internal/unionstore/rbt/rbt_arena.go 
rename internal/unionstore/{memdb_iterator.go => rbt/rbt_iterator.go} (75%) rename internal/unionstore/{memdb_snapshot.go => rbt/rbt_snapshot.go} (71%) create mode 100644 internal/unionstore/rbt/rbt_test.go diff --git a/internal/unionstore/memdb_arena.go b/internal/unionstore/arena/arena.go similarity index 54% rename from internal/unionstore/memdb_arena.go rename to internal/unionstore/arena/arena.go index 921d80f17f..b6aa8dac9c 100644 --- a/internal/unionstore/memdb_arena.go +++ b/internal/unionstore/arena/arena.go @@ -32,19 +32,18 @@ // See the License for the specific language governing permissions and // limitations under the License. -package unionstore +package arena import ( "encoding/binary" "math" - "unsafe" "github.com/tikv/client-go/v2/kv" "go.uber.org/atomic" ) const ( - alignMask = 1<<32 - 8 // 29 bit 1 and 3 bit 0. + alignMask = 0xFFFFFFF8 // 29 bits of 1 and 3 bits of 0 nullBlockOffset = math.MaxUint32 maxBlockSize = 128 << 20 @@ -52,34 +51,39 @@ const ( ) var ( - nullAddr = memdbArenaAddr{math.MaxUint32, math.MaxUint32} - nullNodeAddr = memdbNodeAddr{nil, nullAddr} - endian = binary.LittleEndian + Tombstone = []byte{} + NullAddr = MemdbArenaAddr{math.MaxUint32, math.MaxUint32} + BadAddr = MemdbArenaAddr{math.MaxUint32 - 1, math.MaxUint32} + endian = binary.LittleEndian ) -type memdbArenaAddr struct { +type MemdbArenaAddr struct { idx uint32 off uint32 } -func (addr memdbArenaAddr) isNull() bool { +func (addr MemdbArenaAddr) IsNull() bool { // Combine all checks into a single condition - return addr == nullAddr || addr.idx == math.MaxUint32 || addr.off == math.MaxUint32 + return addr == NullAddr || addr.idx == math.MaxUint32 || addr.off == math.MaxUint32 +} + +func (addr MemdbArenaAddr) ToHandle() MemKeyHandle { + return MemKeyHandle{idx: uint16(addr.idx), off: addr.off} } // store and load is used by vlog, due to pointer in vlog is not aligned. 
-func (addr memdbArenaAddr) store(dst []byte) { +func (addr MemdbArenaAddr) store(dst []byte) { endian.PutUint32(dst, addr.idx) endian.PutUint32(dst[4:], addr.off) } -func (addr *memdbArenaAddr) load(src []byte) { +func (addr *MemdbArenaAddr) load(src []byte) { addr.idx = endian.Uint32(src) addr.off = endian.Uint32(src[4:]) } -type memdbArena struct { +type MemdbArena struct { blockSize int blocks []memdbArenaBlock // the total size of all blocks, also the approximate memory footprint of the arena. @@ -88,7 +92,7 @@ type memdbArena struct { memChangeHook atomic.Pointer[func()] } -func (a *memdbArena) alloc(size int, align bool) (memdbArenaAddr, []byte) { +func (a *MemdbArena) Alloc(size int, align bool) (MemdbArenaAddr, []byte) { if size > maxBlockSize { panic("alloc size is larger than max block size") } @@ -98,7 +102,7 @@ func (a *memdbArena) alloc(size int, align bool) (memdbArenaAddr, []byte) { } addr, data := a.allocInLastBlock(size, align) - if !addr.isNull() { + if !addr.IsNull() { return addr, data } @@ -106,7 +110,7 @@ func (a *memdbArena) alloc(size int, align bool) (memdbArenaAddr, []byte) { return a.allocInLastBlock(size, align) } -func (a *memdbArena) enlarge(allocSize, blockSize int) { +func (a *MemdbArena) enlarge(allocSize, blockSize int) { a.blockSize = blockSize for a.blockSize <= allocSize { a.blockSize <<= 1 @@ -119,36 +123,59 @@ func (a *memdbArena) enlarge(allocSize, blockSize int) { buf: make([]byte, a.blockSize), }) a.capacity += uint64(a.blockSize) - // We shall not call a.onMemChange() here, since it will make the latest block empty, which breaks a precondition - // for some operations (e.g. revertToCheckpoint) + // We shall not call a.OnMemChange() here, since it will make the latest block empty, which breaks a precondition + // for some operations (e.g. 
RevertToCheckpoint) +} + +func (a *MemdbArena) Blocks() int { + return len(a.blocks) +} + +func (a *MemdbArena) Capacity() uint64 { + return a.capacity +} + +// SetMemChangeHook sets the hook function that will be called when the memory footprint of the arena changes. +func (a *MemdbArena) SetMemChangeHook(hook func()) { + a.memChangeHook.Store(&hook) } -// onMemChange should only be called right before exiting memdb. +// MemHookSet returns whether the memory change hook is set. +func (a *MemdbArena) MemHookSet() bool { + return a.memChangeHook.Load() != nil +} + +// OnMemChange should only be called right before exiting memdb. // This is because the hook can lead to a panic, and leave memdb in an inconsistent state. -func (a *memdbArena) onMemChange() { +func (a *MemdbArena) OnMemChange() { hook := a.memChangeHook.Load() if hook != nil { (*hook)() } } -func (a *memdbArena) allocInLastBlock(size int, align bool) (memdbArenaAddr, []byte) { +func (a *MemdbArena) allocInLastBlock(size int, align bool) (MemdbArenaAddr, []byte) { idx := len(a.blocks) - 1 offset, data := a.blocks[idx].alloc(size, align) if offset == nullBlockOffset { - return nullAddr, nil + return NullAddr, nil } - return memdbArenaAddr{uint32(idx), offset}, data + return MemdbArenaAddr{uint32(idx), offset}, data +} + +// GetData gets data slice of given addr, DO NOT access others data. +func (a *MemdbArena) GetData(addr MemdbArenaAddr) []byte { + return a.blocks[addr.idx].buf[addr.off:] } -func (a *memdbArena) reset() { +func (a *MemdbArena) Reset() { for i := range a.blocks { a.blocks[i].reset() } a.blocks = a.blocks[:0] a.blockSize = 0 a.capacity = 0 - a.onMemChange() + a.OnMemChange() } type memdbArenaBlock struct { @@ -176,18 +203,18 @@ func (a *memdbArenaBlock) reset() { a.length = 0 } -// MemDBCheckpoint is the checkpoint of memory DB. +// MemDBCheckpoint is the Checkpoint of memory DB. 
type MemDBCheckpoint struct { blockSize int blocks int offsetInBlock int } -func (cp *MemDBCheckpoint) isSamePosition(other *MemDBCheckpoint) bool { +func (cp *MemDBCheckpoint) IsSamePosition(other *MemDBCheckpoint) bool { return cp.blocks == other.blocks && cp.offsetInBlock == other.offsetInBlock } -func (a *memdbArena) checkpoint() MemDBCheckpoint { +func (a *MemdbArena) Checkpoint() MemDBCheckpoint { snap := MemDBCheckpoint{ blockSize: a.blockSize, blocks: len(a.blocks), @@ -198,7 +225,7 @@ func (a *memdbArena) checkpoint() MemDBCheckpoint { return snap } -func (a *memdbArena) truncate(snap *MemDBCheckpoint) { +func (a *MemdbArena) Truncate(snap *MemDBCheckpoint) { for i := snap.blocks; i < len(a.blocks); i++ { a.blocks[i] = memdbArenaBlock{} } @@ -212,203 +239,137 @@ func (a *memdbArena) truncate(snap *MemDBCheckpoint) { for _, block := range a.blocks { a.capacity += uint64(block.length) } - // We shall not call a.onMemChange() here, since it may cause a panic and leave memdb in an inconsistent state + // We shall not call a.OnMemChange() here, since it may cause a panic and leave memdb in an inconsistent state } -type nodeAllocator struct { - memdbArena - - // Dummy node, so that we can make X.left.up = X. - // We then use this instead of NULL to mean the top or bottom - // end of the rb tree. It is a black node. - nullNode memdbNode +// KeyFlagsGetter is an interface to get key and key flags, usually a leaf or node. +type KeyFlagsGetter interface { + GetKey() []byte + GetKeyFlags() kv.KeyFlags } -func (a *nodeAllocator) init() { - a.nullNode = memdbNode{ - up: nullAddr, - left: nullAddr, - right: nullAddr, - vptr: nullAddr, - } +// VlogMemDB is the interface of the memory buffer which supports vlog to revert node and inspect node. 
+type VlogMemDB[G KeyFlagsGetter] interface { + RevertNode(hdr *MemdbVlogHdr) + InspectNode(addr MemdbArenaAddr) (G, MemdbArenaAddr) } -func (a *nodeAllocator) getNode(addr memdbArenaAddr) *memdbNode { - if addr.isNull() { - return &a.nullNode - } - - return (*memdbNode)(unsafe.Pointer(&a.blocks[addr.idx].buf[addr.off])) -} - -func (a *nodeAllocator) allocNode(key []byte) (memdbArenaAddr, *memdbNode) { - nodeSize := 8*4 + 2 + kv.FlagBytes + len(key) - prevBlocks := len(a.blocks) - addr, mem := a.alloc(nodeSize, true) - n := (*memdbNode)(unsafe.Pointer(&mem[0])) - n.vptr = nullAddr - n.klen = uint16(len(key)) - copy(n.getKey(), key) - if prevBlocks != len(a.blocks) { - a.onMemChange() - } - return addr, n -} - -var testMode = false - -func (a *nodeAllocator) freeNode(addr memdbArenaAddr) { - if testMode { - // Make it easier for debug. - n := a.getNode(addr) - badAddr := nullAddr - badAddr.idx-- - n.left = badAddr - n.right = badAddr - n.up = badAddr - n.vptr = badAddr - return - } - // TODO: reuse freed nodes. Need to fix lastTraversedNode when implementing this. 
-} - -func (a *nodeAllocator) reset() { - a.memdbArena.reset() - a.init() -} - -type memdbVlog struct { - memdbArena - memdb *MemDB +type MemdbVlog[G KeyFlagsGetter, M VlogMemDB[G]] struct { + MemdbArena } const memdbVlogHdrSize = 8 + 8 + 4 -type memdbVlogHdr struct { - nodeAddr memdbArenaAddr - oldValue memdbArenaAddr - valueLen uint32 +type MemdbVlogHdr struct { + NodeAddr MemdbArenaAddr + OldValue MemdbArenaAddr + ValueLen uint32 } -func (hdr *memdbVlogHdr) store(dst []byte) { +func (hdr *MemdbVlogHdr) store(dst []byte) { cursor := 0 - endian.PutUint32(dst[cursor:], hdr.valueLen) + endian.PutUint32(dst[cursor:], hdr.ValueLen) cursor += 4 - hdr.oldValue.store(dst[cursor:]) + hdr.OldValue.store(dst[cursor:]) cursor += 8 - hdr.nodeAddr.store(dst[cursor:]) + hdr.NodeAddr.store(dst[cursor:]) } -func (hdr *memdbVlogHdr) load(src []byte) { +func (hdr *MemdbVlogHdr) load(src []byte) { cursor := 0 - hdr.valueLen = endian.Uint32(src[cursor:]) + hdr.ValueLen = endian.Uint32(src[cursor:]) cursor += 4 - hdr.oldValue.load(src[cursor:]) + hdr.OldValue.load(src[cursor:]) cursor += 8 - hdr.nodeAddr.load(src[cursor:]) + hdr.NodeAddr.load(src[cursor:]) } -func (l *memdbVlog) appendValue(nodeAddr memdbArenaAddr, oldValue memdbArenaAddr, value []byte) memdbArenaAddr { +// AppendValue appends a value and it's vlog header to the vlog. +func (l *MemdbVlog[G, M]) AppendValue(nodeAddr MemdbArenaAddr, oldValue MemdbArenaAddr, value []byte) MemdbArenaAddr { size := memdbVlogHdrSize + len(value) prevBlocks := len(l.blocks) - addr, mem := l.alloc(size, false) + addr, mem := l.Alloc(size, false) copy(mem, value) - hdr := memdbVlogHdr{nodeAddr, oldValue, uint32(len(value))} + hdr := MemdbVlogHdr{nodeAddr, oldValue, uint32(len(value))} hdr.store(mem[len(value):]) addr.off += uint32(size) if prevBlocks != len(l.blocks) { - l.onMemChange() + l.OnMemChange() } return addr } -// A pure function that gets a value. 
-func (l *memdbVlog) getValue(addr memdbArenaAddr) []byte { +// GetValue is a pure function that gets a value. +func (l *MemdbVlog[G, M]) GetValue(addr MemdbArenaAddr) []byte { lenOff := addr.off - memdbVlogHdrSize block := l.blocks[addr.idx].buf valueLen := endian.Uint32(block[lenOff:]) if valueLen == 0 { - return tombstone + return Tombstone } valueOff := lenOff - valueLen return block[valueOff:lenOff:lenOff] } -func (l *memdbVlog) getSnapshotValue(addr memdbArenaAddr, snap *MemDBCheckpoint) ([]byte, bool) { - result := l.selectValueHistory(addr, func(addr memdbArenaAddr) bool { - return !l.canModify(snap, addr) +func (l *MemdbVlog[G, M]) GetSnapshotValue(addr MemdbArenaAddr, snap *MemDBCheckpoint) ([]byte, bool) { + result := l.SelectValueHistory(addr, func(addr MemdbArenaAddr) bool { + return !l.CanModify(snap, addr) }) - if result.isNull() { + if result.IsNull() { return nil, false } - return l.getValue(result), true + return l.GetValue(result), true } -func (l *memdbVlog) selectValueHistory(addr memdbArenaAddr, predicate func(memdbArenaAddr) bool) memdbArenaAddr { - for !addr.isNull() { +func (l *MemdbVlog[G, M]) SelectValueHistory(addr MemdbArenaAddr, predicate func(MemdbArenaAddr) bool) MemdbArenaAddr { + for !addr.IsNull() { if predicate(addr) { return addr } - var hdr memdbVlogHdr + var hdr MemdbVlogHdr hdr.load(l.blocks[addr.idx].buf[addr.off-memdbVlogHdrSize:]) - addr = hdr.oldValue + addr = hdr.OldValue } - return nullAddr + return NullAddr } -func (l *memdbVlog) revertToCheckpoint(db *MemDB, cp *MemDBCheckpoint) { - cursor := l.checkpoint() - for !cp.isSamePosition(&cursor) { +func (l *MemdbVlog[G, M]) RevertToCheckpoint(m M, cp *MemDBCheckpoint) { + cursor := l.Checkpoint() + for !cp.IsSamePosition(&cursor) { hdrOff := cursor.offsetInBlock - memdbVlogHdrSize block := l.blocks[cursor.blocks-1].buf - var hdr memdbVlogHdr + var hdr MemdbVlogHdr hdr.load(block[hdrOff:]) - node := db.getNode(hdr.nodeAddr) - - node.vptr = hdr.oldValue - db.size -= 
int(hdr.valueLen) - // oldValue.isNull() == true means this is a newly added value. - if hdr.oldValue.isNull() { - // If there are no flags associated with this key, we need to delete this node. - keptFlags := node.getKeyFlags().AndPersistent() - if keptFlags == 0 { - db.deleteNode(node) - } else { - node.setKeyFlags(keptFlags) - db.dirty = true - } - } else { - db.size += len(l.getValue(hdr.oldValue)) - } - + m.RevertNode(&hdr) l.moveBackCursor(&cursor, &hdr) } } -func (l *memdbVlog) inspectKVInLog(db *MemDB, head, tail *MemDBCheckpoint, f func([]byte, kv.KeyFlags, []byte)) { +func (l *MemdbVlog[G, M]) InspectKVInLog(m M, head, tail *MemDBCheckpoint, f func([]byte, kv.KeyFlags, []byte)) { cursor := *tail - for !head.isSamePosition(&cursor) { - cursorAddr := memdbArenaAddr{idx: uint32(cursor.blocks - 1), off: uint32(cursor.offsetInBlock)} + for !head.IsSamePosition(&cursor) { + cursorAddr := MemdbArenaAddr{idx: uint32(cursor.blocks - 1), off: uint32(cursor.offsetInBlock)} hdrOff := cursorAddr.off - memdbVlogHdrSize block := l.blocks[cursorAddr.idx].buf - var hdr memdbVlogHdr + var hdr MemdbVlogHdr hdr.load(block[hdrOff:]) - node := db.allocator.getNode(hdr.nodeAddr) + + node, vptr := m.InspectNode(hdr.NodeAddr) // Skip older versions. 
- if node.vptr == cursorAddr { - value := block[hdrOff-hdr.valueLen : hdrOff] - f(node.getKey(), node.getKeyFlags(), value) + if vptr == cursorAddr { + value := block[hdrOff-hdr.ValueLen : hdrOff] + f(node.GetKey(), node.GetKeyFlags(), value) } l.moveBackCursor(&cursor, &hdr) } } -func (l *memdbVlog) moveBackCursor(cursor *MemDBCheckpoint, hdr *memdbVlogHdr) { - cursor.offsetInBlock -= (memdbVlogHdrSize + int(hdr.valueLen)) +func (l *MemdbVlog[G, M]) moveBackCursor(cursor *MemDBCheckpoint, hdr *MemdbVlogHdr) { + cursor.offsetInBlock -= (memdbVlogHdrSize + int(hdr.ValueLen)) if cursor.offsetInBlock == 0 { cursor.blocks-- if cursor.blocks > 0 { @@ -417,7 +378,7 @@ func (l *memdbVlog) moveBackCursor(cursor *MemDBCheckpoint, hdr *memdbVlogHdr) { } } -func (l *memdbVlog) canModify(cp *MemDBCheckpoint, addr memdbArenaAddr) bool { +func (l *MemdbVlog[G, M]) CanModify(cp *MemDBCheckpoint, addr MemdbArenaAddr) bool { if cp == nil { return true } @@ -429,3 +390,15 @@ func (l *memdbVlog) canModify(cp *MemDBCheckpoint, addr memdbArenaAddr) bool { } return false } + +// MemKeyHandle represents a pointer for key in MemBuffer. +type MemKeyHandle struct { + // Opaque user data + UserData uint16 + idx uint16 + off uint32 +} + +func (h MemKeyHandle) ToAddr() MemdbArenaAddr { + return MemdbArenaAddr{idx: uint32(h.idx), off: h.off} +} diff --git a/internal/unionstore/arena/arena_test.go b/internal/unionstore/arena/arena_test.go new file mode 100644 index 0000000000..5816481bf1 --- /dev/null +++ b/internal/unionstore/arena/arena_test.go @@ -0,0 +1,79 @@ +// Copyright 2021 TiKV Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// NOTE: The code in this file is based on code from the +// TiDB project, licensed under the Apache License v 2.0 +// +// https://github.com/pingcap/tidb/tree/cc5e161ac06827589c4966674597c137cc9e809c/store/tikv/unionstore/memdb_arena.go +// + +// Copyright 2020 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package arena + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +type dummyMemDB struct{} + +func (m *dummyMemDB) RevertNode(hdr *MemdbVlogHdr) {} +func (m *dummyMemDB) InspectNode(addr MemdbArenaAddr) (KeyFlagsGetter, MemdbArenaAddr) { + return nil, NullAddr +} + +func TestBigValue(t *testing.T) { + assert := assert.New(t) + + var vlog MemdbVlog[KeyFlagsGetter, *dummyMemDB] + vlog.AppendValue(MemdbArenaAddr{0, 0}, NullAddr, make([]byte, 80<<20)) + assert.Equal(vlog.blockSize, maxBlockSize) + assert.Equal(len(vlog.blocks), 1) + + cp := vlog.Checkpoint() + vlog.AppendValue(MemdbArenaAddr{0, 1}, NullAddr, make([]byte, 127<<20)) + vlog.RevertToCheckpoint(&dummyMemDB{}, &cp) + + assert.Equal(vlog.blockSize, maxBlockSize) + assert.Equal(len(vlog.blocks), 2) + assert.PanicsWithValue("alloc size is larger than max block size", func() { + vlog.AppendValue(MemdbArenaAddr{0, 2}, NullAddr, make([]byte, maxBlockSize+1)) + }) +} + +func TestValueLargeThanBlock(t *testing.T) { + assert := assert.New(t) + var vlog MemdbVlog[KeyFlagsGetter, *dummyMemDB] + vlog.AppendValue(MemdbArenaAddr{0, 0}, NullAddr, make([]byte, 1)) + vlog.AppendValue(MemdbArenaAddr{0, 1}, NullAddr, make([]byte, 4096)) + assert.Equal(len(vlog.blocks), 2) + vAddr := vlog.AppendValue(MemdbArenaAddr{0, 2}, NullAddr, make([]byte, 3000)) + assert.Equal(len(vlog.blocks), 2) + val := vlog.GetValue(vAddr) + assert.Equal(len(val), 3000) +} diff --git a/internal/unionstore/art/art.go b/internal/unionstore/art/art.go new file mode 100644 index 0000000000..4ed89e5f3c --- /dev/null +++ b/internal/unionstore/art/art.go @@ -0,0 +1,176 @@ +// Copyright 2024 TiKV Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//nolint:unused +package art + +import ( + "math" + + "github.com/tikv/client-go/v2/internal/unionstore/arena" + "github.com/tikv/client-go/v2/kv" +) + +type ART struct { + allocator artAllocator + root artNode + stages []arena.MemDBCheckpoint + vlogInvalid bool + dirty bool + entrySizeLimit uint64 + bufferSizeLimit uint64 + len int + size int +} + +func New() *ART { + var t ART + t.root = nullArtNode + t.stages = make([]arena.MemDBCheckpoint, 0, 2) + t.entrySizeLimit = math.MaxUint64 + t.bufferSizeLimit = math.MaxUint64 + t.allocator.nodeAllocator.freeNode4 = make([]arena.MemdbArenaAddr, 0, 1<<4) + t.allocator.nodeAllocator.freeNode16 = make([]arena.MemdbArenaAddr, 0, 1<<3) + t.allocator.nodeAllocator.freeNode48 = make([]arena.MemdbArenaAddr, 0, 1<<2) + return &t +} + +func (t *ART) Get(key []byte) ([]byte, error) { + panic("unimplemented") +} + +// GetFlags returns the latest flags associated with key. +func (t *ART) GetFlags(key []byte) (kv.KeyFlags, error) { + panic("unimplemented") +} + +func (t *ART) Set(key artKey, value []byte, ops []kv.FlagsOp) error { + panic("unimplemented") +} + +func (t *ART) search(key artKey) (arena.MemdbArenaAddr, *artLeaf) { + panic("unimplemented") +} + +func (t *ART) Dirty() bool { + panic("unimplemented") +} + +// Mem returns the memory usage of MemBuffer. +func (t *ART) Mem() uint64 { + panic("unimplemented") +} + +// Len returns the count of entries in the MemBuffer. +func (t *ART) Len() int { + panic("unimplemented") +} + +// Size returns the size of the MemBuffer. 
+func (t *ART) Size() int { + panic("unimplemented") +} + +func (t *ART) checkpoint() arena.MemDBCheckpoint { + panic("unimplemented") +} + +func (t *ART) RevertNode(hdr *arena.MemdbVlogHdr) { + panic("unimplemented") +} + +func (t *ART) InspectNode(addr arena.MemdbArenaAddr) (*artLeaf, arena.MemdbArenaAddr) { + panic("unimplemented") +} + +// Checkpoint returns a checkpoint of ART. +func (t *ART) Checkpoint() *arena.MemDBCheckpoint { + panic("unimplemented") +} + +// RevertToCheckpoint reverts the ART to the checkpoint. +func (t *ART) RevertToCheckpoint(cp *arena.MemDBCheckpoint) { + panic("unimplemented") +} + +func (t *ART) Stages() []arena.MemDBCheckpoint { + panic("unimplemented") +} + +func (t *ART) Staging() int { + panic("unimplemented") +} + +func (t *ART) Release(h int) { + panic("unimplemented") +} + +func (t *ART) Cleanup(h int) { + panic("unimplemented") +} + +func (t *ART) revertToCheckpoint(cp *arena.MemDBCheckpoint) { + panic("unimplemented") +} + +func (t *ART) moveBackCursor(cursor *arena.MemDBCheckpoint, hdr *arena.MemdbVlogHdr) { + panic("unimplemented") +} + +func (t *ART) truncate(snap *arena.MemDBCheckpoint) { + panic("unimplemented") +} + +// DiscardValues releases the memory used by all values. +// NOTE: any operation need value will panic after this function. +func (t *ART) DiscardValues() { + panic("unimplemented") +} + +// InspectStage used to inspect the value updates in the given stage. +func (t *ART) InspectStage(handle int, f func([]byte, kv.KeyFlags, []byte)) { + panic("unimplemented") +} + +// SelectValueHistory select the latest value which makes `predicate` returns true from the modification history. +func (t *ART) SelectValueHistory(key []byte, predicate func(value []byte) bool) ([]byte, error) { + panic("unimplemented") +} + +func (t *ART) SetMemoryFootprintChangeHook(fn func(uint64)) { + panic("unimplemented") +} + +// MemHookSet implements the MemBuffer interface. 
+func (t *ART) MemHookSet() bool { + panic("unimplemented") +} + +// GetKeyByHandle returns key by handle. +func (t *ART) GetKeyByHandle(handle arena.MemKeyHandle) []byte { + panic("unimplemented") +} + +// GetValueByHandle returns value by handle. +func (t *ART) GetValueByHandle(handle arena.MemKeyHandle) ([]byte, bool) { + panic("unimplemented") +} + +func (t *ART) SetEntrySizeLimit(entryLimit, bufferLimit uint64) { + panic("unimplemented") +} + +func (t *ART) RemoveFromBuffer(key []byte) { + panic("unimplemented") +} diff --git a/internal/unionstore/art/art_arena.go b/internal/unionstore/art/art_arena.go new file mode 100644 index 0000000000..8b29ef4409 --- /dev/null +++ b/internal/unionstore/art/art_arena.go @@ -0,0 +1,35 @@ +// Copyright 2024 TiKV Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//nolint:unused +package art + +import ( + "github.com/tikv/client-go/v2/internal/unionstore/arena" +) + +// fixedSizeArena is a fixed size arena allocator. +// because the size of each type of node is fixed, the discarded nodes can be reused. +// reusing blocks reduces the memory pieces. 
+type nodeArena struct { + arena.MemdbArena + freeNode4 []arena.MemdbArenaAddr + freeNode16 []arena.MemdbArenaAddr + freeNode48 []arena.MemdbArenaAddr +} + +type artAllocator struct { + vlogAllocator arena.MemdbVlog[*artLeaf, *ART] + nodeAllocator nodeArena +} diff --git a/internal/unionstore/art/art_iterator.go b/internal/unionstore/art/art_iterator.go new file mode 100644 index 0000000000..5e032c4e9d --- /dev/null +++ b/internal/unionstore/art/art_iterator.go @@ -0,0 +1,31 @@ +// Copyright 2024 TiKV Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package art + +func (*ART) Iter([]byte, []byte) (*Iterator, error) { + panic("unimplemented") +} + +func (*ART) IterReverse([]byte, []byte) (*Iterator, error) { + panic("unimplemented") +} + +type Iterator struct{} + +func (i *Iterator) Valid() bool { panic("unimplemented") } +func (i *Iterator) Key() []byte { panic("unimplemented") } +func (i *Iterator) Value() []byte { panic("unimplemented") } +func (i *Iterator) Next() error { panic("unimplemented") } +func (i *Iterator) Close() { panic("unimplemented") } diff --git a/internal/unionstore/art/art_node.go b/internal/unionstore/art/art_node.go new file mode 100644 index 0000000000..53c568fe6a --- /dev/null +++ b/internal/unionstore/art/art_node.go @@ -0,0 +1,58 @@ +// Copyright 2024 TiKV Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//nolint:unused +package art + +import ( + "github.com/tikv/client-go/v2/internal/unionstore/arena" + "github.com/tikv/client-go/v2/kv" +) + +type artNodeKind uint16 + +const ( + typeARTInvalid artNodeKind = 0 + //nolint:unused + typeARTNode4 artNodeKind = 1 + typeARTNode16 artNodeKind = 2 + typeARTNode48 artNodeKind = 3 + typeARTNode256 artNodeKind = 4 + typeARTLeaf artNodeKind = 5 +) + +var nullArtNode = artNode{kind: typeARTInvalid, addr: arena.NullAddr} + +type artKey []byte + +type artNode struct { + kind artNodeKind + addr arena.MemdbArenaAddr +} + +type artLeaf struct { + vAddr arena.MemdbArenaAddr + klen uint16 + flags uint16 +} + +// GetKey gets the full key of the leaf +func (l *artLeaf) GetKey() []byte { + panic("unimplemented") +} + +// GetKeyFlags gets the flags of the leaf +func (l *artLeaf) GetKeyFlags() kv.KeyFlags { + panic("unimplemented") +} diff --git a/internal/unionstore/art/art_snapshot.go b/internal/unionstore/art/art_snapshot.go new file mode 100644 index 0000000000..8b714d9ba1 --- /dev/null +++ b/internal/unionstore/art/art_snapshot.go @@ -0,0 +1,43 @@ +// Copyright 2024 TiKV Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package art + +import "context" + +func (*ART) SnapshotGetter() *SnapshotGetter { + panic("unimplemented") +} + +func (*ART) SnapshotIter([]byte, []byte) *SnapshotIter { + panic("unimplemented") +} + +func (*ART) SnapshotIterReverse([]byte, []byte) *SnapshotIter { + panic("unimplemented") +} + +type SnapshotGetter struct{} + +func (s *SnapshotGetter) Get(context.Context, []byte) ([]byte, error) { + panic("unimplemented") +} + +type SnapshotIter struct{} + +func (i *SnapshotIter) Valid() bool { panic("unimplemented") } +func (i *SnapshotIter) Key() []byte { panic("unimplemented") } +func (i *SnapshotIter) Value() []byte { panic("unimplemented") } +func (i *SnapshotIter) Next() error { panic("unimplemented") } +func (i *SnapshotIter) Close() { panic("unimplemented") } diff --git a/internal/unionstore/memdb.go b/internal/unionstore/memdb.go index 4647504321..8284c202b9 100644 --- a/internal/unionstore/memdb.go +++ b/internal/unionstore/memdb.go @@ -35,919 +35,21 @@ package unionstore import ( - "bytes" - "fmt" "math" - "sync" - "sync/atomic" - "unsafe" - tikverr "github.com/tikv/client-go/v2/error" - "github.com/tikv/client-go/v2/kv" + "github.com/tikv/client-go/v2/internal/unionstore/arena" ) -var tombstone = []byte{} +const unlimitedSize = math.MaxUint64 // IsTombstone returns whether the value is a tombstone. func IsTombstone(val []byte) bool { return len(val) == 0 } -// MemKeyHandle represents a pointer for key in MemBuffer. 
-type MemKeyHandle struct { - // Opaque user data - UserData uint16 - idx uint16 - off uint32 -} - -func (h MemKeyHandle) toAddr() memdbArenaAddr { - return memdbArenaAddr{idx: uint32(h.idx), off: h.off} -} - -// MemDB is rollbackable Red-Black Tree optimized for TiDB's transaction states buffer use scenario. -// You can think MemDB is a combination of two separate tree map, one for key => value and another for key => keyFlags. -// -// The value map is rollbackable, that means you can use the `Staging`, `Release` and `Cleanup` API to safely modify KVs. -// -// The flags map is not rollbackable. There are two types of flag, persistent and non-persistent. -// When discarding a newly added KV in `Cleanup`, the non-persistent flags will be cleared. -// If there are persistent flags associated with key, we will keep this key in node without value. -type MemDB struct { - // This RWMutex only used to ensure memdbSnapGetter.Get will not race with - // concurrent memdb.Set, memdb.SetWithFlags, memdb.Delete and memdb.UpdateFlags. - sync.RWMutex - root memdbArenaAddr - allocator nodeAllocator - vlog memdbVlog - - entrySizeLimit uint64 - bufferSizeLimit uint64 - count int - size int - - vlogInvalid bool - dirty bool - stages []MemDBCheckpoint - // when the MemDB is wrapper by upper RWMutex, we can skip the internal mutex. 
- skipMutex bool - - // The lastTraversedNode must exist - lastTraversedNode atomic.Pointer[memdbNodeAddr] - hitCount atomic.Uint64 - missCount atomic.Uint64 -} - -const unlimitedSize = math.MaxUint64 - -func newMemDB() *MemDB { - db := new(MemDB) - db.allocator.init() - db.root = nullAddr - db.stages = make([]MemDBCheckpoint, 0, 2) - db.entrySizeLimit = unlimitedSize - db.bufferSizeLimit = unlimitedSize - db.vlog.memdb = db - db.skipMutex = false - db.lastTraversedNode.Store(&nullNodeAddr) - return db -} - -// updateLastTraversed updates the last traversed node atomically -func (db *MemDB) updateLastTraversed(node memdbNodeAddr) { - db.lastTraversedNode.Store(&node) -} - -// checkKeyInCache retrieves the last traversed node if the key matches -func (db *MemDB) checkKeyInCache(key []byte) (memdbNodeAddr, bool) { - nodePtr := db.lastTraversedNode.Load() - if nodePtr == nil || nodePtr.isNull() { - return nullNodeAddr, false - } - - if bytes.Equal(key, nodePtr.memdbNode.getKey()) { - return *nodePtr, true - } - - return nullNodeAddr, false -} - -// Staging create a new staging buffer inside the MemBuffer. -// Subsequent writes will be temporarily stored in this new staging buffer. -// When you think all modifications looks good, you can call `Release` to public all of them to the upper level buffer. -func (db *MemDB) Staging() int { - if !db.skipMutex { - db.Lock() - defer db.Unlock() - } - - db.stages = append(db.stages, db.vlog.checkpoint()) - return len(db.stages) -} - -// Release publish all modifications in the latest staging buffer to upper level. -func (db *MemDB) Release(h int) { - if !db.skipMutex { - db.Lock() - defer db.Unlock() - } - - if h != len(db.stages) { - // This should never happens in production environment. - // Use panic to make debug easier. 
- panic("cannot release staging buffer") - } - - if h == 1 { - tail := db.vlog.checkpoint() - if !db.stages[0].isSamePosition(&tail) { - db.dirty = true - } - } - db.stages = db.stages[:h-1] -} - -// Cleanup cleanup the resources referenced by the StagingHandle. -// If the changes are not published by `Release`, they will be discarded. -func (db *MemDB) Cleanup(h int) { - if !db.skipMutex { - db.Lock() - defer db.Unlock() - } - - if h > len(db.stages) { - return - } - if h < len(db.stages) { - // This should never happens in production environment. - // Use panic to make debug easier. - panic(fmt.Sprintf("cannot cleanup staging buffer, h=%v, len(db.stages)=%v", h, len(db.stages))) - } - - cp := &db.stages[h-1] - if !db.vlogInvalid { - curr := db.vlog.checkpoint() - if !curr.isSamePosition(cp) { - db.vlog.revertToCheckpoint(db, cp) - db.vlog.truncate(cp) - } - } - db.stages = db.stages[:h-1] - db.vlog.onMemChange() -} - -// Checkpoint returns a checkpoint of MemDB. -func (db *MemDB) Checkpoint() *MemDBCheckpoint { - cp := db.vlog.checkpoint() - return &cp -} - -// RevertToCheckpoint reverts the MemDB to the checkpoint. -func (db *MemDB) RevertToCheckpoint(cp *MemDBCheckpoint) { - db.vlog.revertToCheckpoint(db, cp) - db.vlog.truncate(cp) - db.vlog.onMemChange() -} - -// Reset resets the MemBuffer to initial states. -func (db *MemDB) Reset() { - db.root = nullAddr - db.stages = db.stages[:0] - db.dirty = false - db.vlogInvalid = false - db.size = 0 - db.count = 0 - db.vlog.reset() - db.allocator.reset() -} - -// DiscardValues releases the memory used by all values. -// NOTE: any operation need value will panic after this function. -func (db *MemDB) DiscardValues() { - db.vlogInvalid = true - db.vlog.reset() -} - -// InspectStage used to inspect the value updates in the given stage. 
-func (db *MemDB) InspectStage(handle int, f func([]byte, kv.KeyFlags, []byte)) { - idx := handle - 1 - tail := db.vlog.checkpoint() - head := db.stages[idx] - db.vlog.inspectKVInLog(db, &head, &tail, f) -} - -// Get gets the value for key k from kv store. -// If corresponding kv pair does not exist, it returns nil and ErrNotExist. -func (db *MemDB) Get(key []byte) ([]byte, error) { - if db.vlogInvalid { - // panic for easier debugging. - panic("vlog is resetted") - } - - x := db.traverse(key, false) - if x.isNull() { - return nil, tikverr.ErrNotExist - } - if x.vptr.isNull() { - // A flag only key, act as value not exists - return nil, tikverr.ErrNotExist - } - return db.vlog.getValue(x.vptr), nil -} - -// SelectValueHistory select the latest value which makes `predicate` returns true from the modification history. -func (db *MemDB) SelectValueHistory(key []byte, predicate func(value []byte) bool) ([]byte, error) { - x := db.traverse(key, false) - if x.isNull() { - return nil, tikverr.ErrNotExist - } - if x.vptr.isNull() { - // A flag only key, act as value not exists - return nil, tikverr.ErrNotExist - } - result := db.vlog.selectValueHistory(x.vptr, func(addr memdbArenaAddr) bool { - return predicate(db.vlog.getValue(addr)) - }) - if result.isNull() { - return nil, nil - } - return db.vlog.getValue(result), nil -} - -// GetFlags returns the latest flags associated with key. -func (db *MemDB) GetFlags(key []byte) (kv.KeyFlags, error) { - x := db.traverse(key, false) - if x.isNull() { - return 0, tikverr.ErrNotExist - } - return x.getKeyFlags(), nil -} - -// UpdateFlags update the flags associated with key. -func (db *MemDB) UpdateFlags(key []byte, ops ...kv.FlagsOp) { - err := db.set(key, nil, ops...) - _ = err // set without value will never fail -} - -// Set sets the value for key k as v into kv store. -// v must NOT be nil or empty, otherwise it returns ErrCannotSetNilValue. 
-func (db *MemDB) Set(key []byte, value []byte) error { - if len(value) == 0 { - return tikverr.ErrCannotSetNilValue - } - return db.set(key, value) -} - -// SetWithFlags put key-value into the last active staging buffer with the given KeyFlags. -func (db *MemDB) SetWithFlags(key []byte, value []byte, ops ...kv.FlagsOp) error { - if len(value) == 0 { - return tikverr.ErrCannotSetNilValue - } - return db.set(key, value, ops...) -} - -// Delete removes the entry for key k from kv store. -func (db *MemDB) Delete(key []byte) error { - return db.set(key, tombstone) -} - -// DeleteWithFlags delete key with the given KeyFlags -func (db *MemDB) DeleteWithFlags(key []byte, ops ...kv.FlagsOp) error { - return db.set(key, tombstone, ops...) -} - -// GetKeyByHandle returns key by handle. -func (db *MemDB) GetKeyByHandle(handle MemKeyHandle) []byte { - x := db.getNode(handle.toAddr()) - return x.getKey() -} - -// GetValueByHandle returns value by handle. -func (db *MemDB) GetValueByHandle(handle MemKeyHandle) ([]byte, bool) { - if db.vlogInvalid { - return nil, false - } - x := db.getNode(handle.toAddr()) - if x.vptr.isNull() { - return nil, false - } - return db.vlog.getValue(x.vptr), true -} - -// Len returns the number of entries in the DB. -func (db *MemDB) Len() int { - return db.count -} - -// Size returns sum of keys and values length. -func (db *MemDB) Size() int { - return db.size -} - -// Dirty returns whether the root staging buffer is updated. -func (db *MemDB) Dirty() bool { - return db.dirty -} - -func (db *MemDB) set(key []byte, value []byte, ops ...kv.FlagsOp) error { - if !db.skipMutex { - db.Lock() - defer db.Unlock() - } - - if db.vlogInvalid { - // panic for easier debugging. 
- panic("vlog is reset") - } - - if value != nil { - if size := uint64(len(key) + len(value)); size > db.entrySizeLimit { - return &tikverr.ErrEntryTooLarge{ - Limit: db.entrySizeLimit, - Size: size, - } - } - } - - if len(db.stages) == 0 { - db.dirty = true - } - x := db.traverse(key, true) - - // the NeedConstraintCheckInPrewrite flag is temporary, - // every write to the node removes the flag unless it's explicitly set. - // This set must be in the latest stage so no special processing is needed. - var flags kv.KeyFlags - if value != nil { - flags = kv.ApplyFlagsOps(x.getKeyFlags(), append([]kv.FlagsOp{kv.DelNeedConstraintCheckInPrewrite}, ops...)...) - } else { - // an UpdateFlag operation, do not delete the NeedConstraintCheckInPrewrite flag. - flags = kv.ApplyFlagsOps(x.getKeyFlags(), ops...) - } - if flags.AndPersistent() != 0 { - db.dirty = true - } - x.setKeyFlags(flags) - - if value == nil { - return nil - } - - db.setValue(x, value) - if uint64(db.Size()) > db.bufferSizeLimit { - return &tikverr.ErrTxnTooLarge{Size: db.Size()} - } - return nil -} - -func (db *MemDB) setValue(x memdbNodeAddr, value []byte) { - var activeCp *MemDBCheckpoint - if len(db.stages) > 0 { - activeCp = &db.stages[len(db.stages)-1] - } - - var oldVal []byte - if !x.vptr.isNull() { - oldVal = db.vlog.getValue(x.vptr) - } - - if len(oldVal) > 0 && db.vlog.canModify(activeCp, x.vptr) { - // For easier to implement, we only consider this case. - // It is the most common usage in TiDB's transaction buffers. - if len(oldVal) == len(value) { - copy(oldVal, value) - return - } - } - x.vptr = db.vlog.appendValue(x.addr, x.vptr, value) - db.size = db.size - len(oldVal) + len(value) -} - -// traverse search for and if not found and insert is true, will add a new node in. -// Returns a pointer to the new node, or the node found. 
-func (db *MemDB) traverse(key []byte, insert bool) memdbNodeAddr { - if node, found := db.checkKeyInCache(key); found { - db.hitCount.Add(1) - return node - } - db.missCount.Add(1) - - x := db.getRoot() - y := memdbNodeAddr{nil, nullAddr} - found := false - - // walk x down the tree - for !x.isNull() && !found { - cmp := bytes.Compare(key, x.getKey()) - if cmp < 0 { - if insert && x.left.isNull() { - y = x - } - x = x.getLeft(db) - } else if cmp > 0 { - if insert && x.right.isNull() { - y = x - } - x = x.getRight(db) - } else { - found = true - } - } - - if found { - db.updateLastTraversed(x) - } - - if found || !insert { - return x - } - - z := db.allocNode(key) - z.up = y.addr - - if y.isNull() { - db.root = z.addr - } else { - cmp := bytes.Compare(z.getKey(), y.getKey()) - if cmp < 0 { - y.left = z.addr - } else { - y.right = z.addr - } - } - - z.left = nullAddr - z.right = nullAddr - - // colour this new node red - z.setRed() - - // Having added a red node, we must now walk back up the tree balancing it, - // by a series of rotations and changing of colours - x = z - - // While we are not at the top and our parent node is red - // NOTE: Since the root node is guaranteed black, then we - // are also going to stop if we are the child of the root - - for x.addr != db.root { - xUp := x.getUp(db) - if xUp.isBlack() { - break - } - - xUpUp := xUp.getUp(db) - // if our parent is on the left side of our grandparent - if x.up == xUpUp.left { - // get the right side of our grandparent (uncle?) 
- y = xUpUp.getRight(db) - if y.isRed() { - // make our parent black - xUp.setBlack() - // make our uncle black - y.setBlack() - // make our grandparent red - xUpUp.setRed() - // now consider our grandparent - x = xUp.getUp(db) - } else { - // if we are on the right side of our parent - if x.addr == xUp.right { - // Move up to our parent - x = x.getUp(db) - db.leftRotate(x) - xUp = x.getUp(db) - xUpUp = xUp.getUp(db) - } - - xUp.setBlack() - xUpUp.setRed() - db.rightRotate(xUpUp) - } - } else { - // everything here is the same as above, but exchanging left for right - y = xUpUp.getLeft(db) - if y.isRed() { - xUp.setBlack() - y.setBlack() - xUpUp.setRed() - - x = xUp.getUp(db) - } else { - if x.addr == xUp.left { - x = x.getUp(db) - db.rightRotate(x) - xUp = x.getUp(db) - xUpUp = xUp.getUp(db) - } - - xUp.setBlack() - xUpUp.setRed() - db.leftRotate(xUpUp) - } - } - } - - // Set the root node black - db.getRoot().setBlack() - - db.updateLastTraversed(z) - - return z -} - -// -// Rotate our tree thus:- -// -// X leftRotate(X)---> Y -// / \ / \ -// A Y <---rightRotate(Y) X C -// / \ / \ -// B C A B -// -// NOTE: This does not change the ordering. 
-// -// We assume that neither X nor Y is NULL -// - -func (db *MemDB) leftRotate(x memdbNodeAddr) { - y := x.getRight(db) - - // Turn Y's left subtree into X's right subtree (move B) - x.right = y.left - - // If B is not null, set it's parent to be X - if !y.left.isNull() { - left := y.getLeft(db) - left.up = x.addr - } - - // Set Y's parent to be what X's parent was - y.up = x.up - - // if X was the root - if x.up.isNull() { - db.root = y.addr - } else { - xUp := x.getUp(db) - // Set X's parent's left or right pointer to be Y - if x.addr == xUp.left { - xUp.left = y.addr - } else { - xUp.right = y.addr - } - } - - // Put X on Y's left - y.left = x.addr - // Set X's parent to be Y - x.up = y.addr -} - -func (db *MemDB) rightRotate(y memdbNodeAddr) { - x := y.getLeft(db) - - // Turn X's right subtree into Y's left subtree (move B) - y.left = x.right - - // If B is not null, set it's parent to be Y - if !x.right.isNull() { - right := x.getRight(db) - right.up = y.addr - } - - // Set X's parent to be what Y's parent was - x.up = y.up - - // if Y was the root - if y.up.isNull() { - db.root = x.addr - } else { - yUp := y.getUp(db) - // Set Y's parent's left or right pointer to be X - if y.addr == yUp.left { - yUp.left = x.addr - } else { - yUp.right = x.addr - } - } - - // Put Y on X's right - x.right = y.addr - // Set Y's parent to be X - y.up = x.addr -} - -func (db *MemDB) deleteNode(z memdbNodeAddr) { - var x, y memdbNodeAddr - if db.lastTraversedNode.Load().addr == z.addr { - db.lastTraversedNode.Store(&nullNodeAddr) - } - - db.count-- - db.size -= int(z.klen) - - if z.left.isNull() || z.right.isNull() { - y = z - } else { - y = db.successor(z) - } - - if !y.left.isNull() { - x = y.getLeft(db) - } else { - x = y.getRight(db) - } - x.up = y.up - - if y.up.isNull() { - db.root = x.addr - } else { - yUp := y.getUp(db) - if y.addr == yUp.left { - yUp.left = x.addr - } else { - yUp.right = x.addr - } - } - - needFix := y.isBlack() - - // NOTE: traditional red-black 
tree will copy key from Y to Z and free Y. - // We cannot do the same thing here, due to Y's pointer is stored in vlog and the space in Z may not suitable for Y. - // So we need to copy states from Z to Y, and relink all nodes formerly connected to Z. - if y != z { - db.replaceNode(z, y) - } - - if needFix { - db.deleteNodeFix(x) - } - - db.allocator.freeNode(z.addr) -} - -func (db *MemDB) replaceNode(old memdbNodeAddr, new memdbNodeAddr) { - if !old.up.isNull() { - oldUp := old.getUp(db) - if old.addr == oldUp.left { - oldUp.left = new.addr - } else { - oldUp.right = new.addr - } - } else { - db.root = new.addr - } - new.up = old.up - - left := old.getLeft(db) - left.up = new.addr - new.left = old.left - - right := old.getRight(db) - right.up = new.addr - new.right = old.right - - if old.isBlack() { - new.setBlack() - } else { - new.setRed() - } -} - -func (db *MemDB) deleteNodeFix(x memdbNodeAddr) { - for x.addr != db.root && x.isBlack() { - xUp := x.getUp(db) - if x.addr == xUp.left { - w := xUp.getRight(db) - if w.isRed() { - w.setBlack() - xUp.setRed() - db.leftRotate(xUp) - w = x.getUp(db).getRight(db) - } - - if w.getLeft(db).isBlack() && w.getRight(db).isBlack() { - w.setRed() - x = x.getUp(db) - } else { - if w.getRight(db).isBlack() { - w.getLeft(db).setBlack() - w.setRed() - db.rightRotate(w) - w = x.getUp(db).getRight(db) - } - - xUp := x.getUp(db) - if xUp.isBlack() { - w.setBlack() - } else { - w.setRed() - } - xUp.setBlack() - w.getRight(db).setBlack() - db.leftRotate(xUp) - x = db.getRoot() - } - } else { - w := xUp.getLeft(db) - if w.isRed() { - w.setBlack() - xUp.setRed() - db.rightRotate(xUp) - w = x.getUp(db).getLeft(db) - } - - if w.getRight(db).isBlack() && w.getLeft(db).isBlack() { - w.setRed() - x = x.getUp(db) - } else { - if w.getLeft(db).isBlack() { - w.getRight(db).setBlack() - w.setRed() - db.leftRotate(w) - w = x.getUp(db).getLeft(db) - } - - xUp := x.getUp(db) - if xUp.isBlack() { - w.setBlack() - } else { - w.setRed() - } - 
xUp.setBlack() - w.getLeft(db).setBlack() - db.rightRotate(xUp) - x = db.getRoot() - } - } - } - x.setBlack() -} - -func (db *MemDB) successor(x memdbNodeAddr) (y memdbNodeAddr) { - if !x.right.isNull() { - // If right is not NULL then go right one and - // then keep going left until we find a node with - // no left pointer. - - y = x.getRight(db) - for !y.left.isNull() { - y = y.getLeft(db) - } - return - } - - // Go up the tree until we get to a node that is on the - // left of its parent (or the root) and then return the - // parent. - - y = x.getUp(db) - for !y.isNull() && x.addr == y.right { - x = y - y = y.getUp(db) - } - return y -} - -func (db *MemDB) predecessor(x memdbNodeAddr) (y memdbNodeAddr) { - if !x.left.isNull() { - // If left is not NULL then go left one and - // then keep going right until we find a node with - // no right pointer. - - y = x.getLeft(db) - for !y.right.isNull() { - y = y.getRight(db) - } - return - } - - // Go up the tree until we get to a node that is on the - // right of its parent (or the root) and then return the - // parent. 
- - y = x.getUp(db) - for !y.isNull() && x.addr == y.left { - x = y - y = y.getUp(db) - } - return y -} - -func (db *MemDB) getNode(x memdbArenaAddr) memdbNodeAddr { - return memdbNodeAddr{db.allocator.getNode(x), x} -} - -func (db *MemDB) getRoot() memdbNodeAddr { - return db.getNode(db.root) -} - -func (db *MemDB) allocNode(key []byte) memdbNodeAddr { - db.size += len(key) - db.count++ - x, xn := db.allocator.allocNode(key) - return memdbNodeAddr{xn, x} -} - -type memdbNodeAddr struct { - *memdbNode - addr memdbArenaAddr -} - -func (a *memdbNodeAddr) isNull() bool { - return a.addr.isNull() -} - -func (a memdbNodeAddr) getUp(db *MemDB) memdbNodeAddr { - return db.getNode(a.up) -} - -func (a memdbNodeAddr) getLeft(db *MemDB) memdbNodeAddr { - return db.getNode(a.left) -} - -func (a memdbNodeAddr) getRight(db *MemDB) memdbNodeAddr { - return db.getNode(a.right) -} - -type memdbNode struct { - up memdbArenaAddr - left memdbArenaAddr - right memdbArenaAddr - vptr memdbArenaAddr - klen uint16 - flags uint16 -} - -func (n *memdbNode) isRed() bool { - return n.flags&nodeColorBit != 0 -} - -func (n *memdbNode) isBlack() bool { - return !n.isRed() -} - -func (n *memdbNode) setRed() { - n.flags |= nodeColorBit -} - -func (n *memdbNode) setBlack() { - n.flags &= ^nodeColorBit -} - -func (n *memdbNode) getKey() []byte { - base := unsafe.Add(unsafe.Pointer(&n.flags), kv.FlagBytes) - return unsafe.Slice((*byte)(base), int(n.klen)) -} - -const ( - // bit 1 => red, bit 0 => black - nodeColorBit uint16 = 0x8000 - nodeFlagsMask = ^nodeColorBit -) - -func (n *memdbNode) getKeyFlags() kv.KeyFlags { - return kv.KeyFlags(n.flags & nodeFlagsMask) -} - -func (n *memdbNode) setKeyFlags(f kv.KeyFlags) { - n.flags = (^nodeFlagsMask & n.flags) | uint16(f) -} - -// RemoveFromBuffer removes a record from the mem buffer. It should be only used for test. 
-func (db *MemDB) RemoveFromBuffer(key []byte) { - x := db.traverse(key, false) - if x.isNull() { - return - } - db.size -= len(db.vlog.getValue(x.vptr)) - db.deleteNode(x) -} - -// SetMemoryFootprintChangeHook sets the hook function that is triggered when memdb grows. -func (db *MemDB) SetMemoryFootprintChangeHook(hook func(uint64)) { - innerHook := func() { - hook(db.allocator.capacity + db.vlog.capacity) - } - db.allocator.memChangeHook.Store(&innerHook) - db.vlog.memChangeHook.Store(&innerHook) -} - -// Mem returns the current memory footprint -func (db *MemDB) Mem() uint64 { - return db.allocator.capacity + db.vlog.capacity -} +type MemDBCheckpoint = arena.MemDBCheckpoint -// SetEntrySizeLimit sets the size limit for each entry and total buffer. -func (db *MemDB) SetEntrySizeLimit(entryLimit, bufferLimit uint64) { - db.entrySizeLimit = entryLimit - db.bufferSizeLimit = bufferLimit -} +type MemKeyHandle = arena.MemKeyHandle -func (db *MemDB) setSkipMutex(skip bool) { - db.skipMutex = skip -} +type MemDB = rbtDBWithContext -// MemHookSet implements the MemBuffer interface. -func (db *MemDB) MemHookSet() bool { - return db.allocator.memChangeHook.Load() != nil -} +var NewMemDB = newRbtDBWithContext +var NewMemDBWithContext = newRbtDBWithContext diff --git a/internal/unionstore/memdb_art.go b/internal/unionstore/memdb_art.go new file mode 100644 index 0000000000..d82610cd9a --- /dev/null +++ b/internal/unionstore/memdb_art.go @@ -0,0 +1,169 @@ +// Copyright 2024 TiKV Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +//nolint:unused +package unionstore + +import ( + "context" + "sync" + + tikverr "github.com/tikv/client-go/v2/error" + "github.com/tikv/client-go/v2/internal/unionstore/arena" + "github.com/tikv/client-go/v2/internal/unionstore/art" + "github.com/tikv/client-go/v2/kv" +) + +// artDBWithContext wraps ART to satisfy the MemBuffer interface. +type artDBWithContext struct { + // This RWMutex only used to ensure rbtSnapGetter.Get will not race with + // concurrent MemBuffer.Set, MemBuffer.SetWithFlags, MemBuffer.Delete and MemBuffer.UpdateFlags. + sync.RWMutex + *art.ART + + // when the ART is wrapper by upper RWMutex, we can skip the internal mutex. + skipMutex bool +} + +//nolint:unused +func newArtDBWithContext() *artDBWithContext { + return &artDBWithContext{ART: art.New()} +} + +func (db *artDBWithContext) setSkipMutex(skip bool) { + db.skipMutex = skip +} + +func (db *artDBWithContext) set(key, value []byte, ops []kv.FlagsOp) error { + if !db.skipMutex { + db.Lock() + defer db.Unlock() + } + return db.ART.Set(key, value, ops) +} + +func (db *artDBWithContext) Set(key, value []byte) error { + if len(value) == 0 { + return tikverr.ErrCannotSetNilValue + } + return db.set(key, value, nil) +} + +// SetWithFlags put key-value into the last active staging buffer with the given KeyFlags. 
+func (db *artDBWithContext) SetWithFlags(key []byte, value []byte, ops ...kv.FlagsOp) error { + if len(value) == 0 { + return tikverr.ErrCannotSetNilValue + } + return db.set(key, value, ops) +} + +func (db *artDBWithContext) UpdateFlags(key []byte, ops ...kv.FlagsOp) { + _ = db.set(key, nil, ops) +} + +func (db *artDBWithContext) Delete(key []byte) error { + return db.set(key, arena.Tombstone, nil) +} + +func (db *artDBWithContext) DeleteWithFlags(key []byte, ops ...kv.FlagsOp) error { + return db.set(key, arena.Tombstone, ops) +} + +func (db *artDBWithContext) Staging() int { + if !db.skipMutex { + db.Lock() + defer db.Unlock() + } + return db.ART.Staging() +} + +func (db *artDBWithContext) Cleanup(handle int) { + if !db.skipMutex { + db.Lock() + defer db.Unlock() + } + db.ART.Cleanup(handle) +} + +func (db *artDBWithContext) Release(handle int) { + if !db.skipMutex { + db.Lock() + defer db.Unlock() + } + db.ART.Release(handle) +} + +func (db *artDBWithContext) Get(_ context.Context, k []byte) ([]byte, error) { + return db.ART.Get(k) +} + +func (db *artDBWithContext) GetLocal(_ context.Context, k []byte) ([]byte, error) { + return db.ART.Get(k) +} + +func (db *artDBWithContext) Flush(bool) (bool, error) { return false, nil } + +func (db *artDBWithContext) FlushWait() error { return nil } + +// GetMemDB implements the MemBuffer interface. +func (db *artDBWithContext) GetMemDB() *MemDB { + panic("unimplemented") +} + +// BatchGet returns the values for given keys from the MemBuffer. +func (db *artDBWithContext) BatchGet(ctx context.Context, keys [][]byte) (map[string][]byte, error) { + if db.Len() == 0 { + return map[string][]byte{}, nil + } + m := make(map[string][]byte, len(keys)) + for _, k := range keys { + v, err := db.Get(ctx, k) + if err != nil { + if tikverr.IsErrNotFound(err) { + continue + } + return nil, err + } + m[string(k)] = v + } + return m, nil +} + +// GetMetrics implements the MemBuffer interface. 
+func (db *artDBWithContext) GetMetrics() Metrics { return Metrics{} } + +// Iter implements the Retriever interface. +func (db *artDBWithContext) Iter(lower, upper []byte) (Iterator, error) { + return db.ART.Iter(lower, upper) +} + +// IterReverse implements the Retriever interface. +func (db *artDBWithContext) IterReverse(upper, lower []byte) (Iterator, error) { + return db.ART.IterReverse(upper, lower) +} + +// SnapshotIter returns an Iterator for a snapshot of MemBuffer. +func (db *artDBWithContext) SnapshotIter(lower, upper []byte) Iterator { + return db.ART.SnapshotIter(lower, upper) +} + +// SnapshotIterReverse returns a reversed Iterator for a snapshot of MemBuffer. +func (db *artDBWithContext) SnapshotIterReverse(upper, lower []byte) Iterator { + return db.ART.SnapshotIterReverse(upper, lower) +} + +// SnapshotGetter returns a Getter for a snapshot of MemBuffer. +func (db *artDBWithContext) SnapshotGetter() Getter { + return db.ART.SnapshotGetter() +} diff --git a/internal/unionstore/memdb_bench_test.go b/internal/unionstore/memdb_bench_test.go index 3653b21c68..2b7e6e69a0 100644 --- a/internal/unionstore/memdb_bench_test.go +++ b/internal/unionstore/memdb_bench_test.go @@ -35,6 +35,7 @@ package unionstore import ( + "context" "encoding/binary" "math/rand" "testing" @@ -50,7 +51,7 @@ func BenchmarkLargeIndex(b *testing.B) { for i := range buf { binary.LittleEndian.PutUint32(buf[i][:], uint32(i)) } - db := newMemDB() + db := NewMemDB() b.ResetTimer() for i := range buf { @@ -64,7 +65,7 @@ func BenchmarkPut(b *testing.B) { binary.BigEndian.PutUint32(buf[i][:], uint32(i)) } - p := newMemDB() + p := NewMemDB() b.ResetTimer() for i := range buf { @@ -78,7 +79,7 @@ func BenchmarkPutRandom(b *testing.B) { binary.LittleEndian.PutUint32(buf[i][:], uint32(rand.Int())) } - p := newMemDB() + p := NewMemDB() b.ResetTimer() for i := range buf { @@ -92,14 +93,15 @@ func BenchmarkGet(b *testing.B) { binary.BigEndian.PutUint32(buf[i][:], uint32(i)) } - p := newMemDB() + p :=
NewMemDB() for i := range buf { p.Set(buf[i][:keySize], buf[i][:]) } + ctx := context.Background() b.ResetTimer() for i := range buf { - p.Get(buf[i][:keySize]) + p.Get(ctx, buf[i][:keySize]) } } @@ -109,14 +111,15 @@ func BenchmarkGetRandom(b *testing.B) { binary.LittleEndian.PutUint32(buf[i][:], uint32(rand.Int())) } - p := newMemDB() + p := NewMemDB() for i := range buf { p.Set(buf[i][:keySize], buf[i][:]) } + ctx := context.Background() b.ResetTimer() for i := 0; i < b.N; i++ { - p.Get(buf[i][:keySize]) + p.Get(ctx, buf[i][:keySize]) } } @@ -127,7 +130,7 @@ func BenchmarkMemDbBufferSequential(b *testing.B) { for i := 0; i < opCnt; i++ { data[i] = encodeInt(i) } - buffer := newMemDB() + buffer := NewMemDB() benchmarkSetGet(b, buffer, data) b.ReportAllocs() } @@ -138,20 +141,20 @@ func BenchmarkMemDbBufferRandom(b *testing.B) { data[i] = encodeInt(i) } shuffle(data) - buffer := newMemDB() + buffer := NewMemDB() benchmarkSetGet(b, buffer, data) b.ReportAllocs() } func BenchmarkMemDbIter(b *testing.B) { - buffer := newMemDB() + buffer := NewMemDB() benchIterator(b, buffer) b.ReportAllocs() } func BenchmarkMemDbCreation(b *testing.B) { for i := 0; i < b.N; i++ { - newMemDB() + NewMemDB() } b.ReportAllocs() } @@ -165,13 +168,14 @@ func shuffle(slc [][]byte) { } } func benchmarkSetGet(b *testing.B, buffer *MemDB, data [][]byte) { + ctx := context.Background() b.ResetTimer() for i := 0; i < b.N; i++ { for _, k := range data { buffer.Set(k, k) } for _, k := range data { - buffer.Get(k) + buffer.Get(ctx, k) } } } diff --git a/internal/unionstore/memdb_norace_test.go b/internal/unionstore/memdb_norace_test.go index 25565d3fd8..c8d8e54fe4 100644 --- a/internal/unionstore/memdb_norace_test.go +++ b/internal/unionstore/memdb_norace_test.go @@ -59,7 +59,7 @@ func TestRandom(t *testing.T) { rand2.Read(keys[i]) } - p1 := newMemDB() + p1 := NewMemDB() p2 := leveldb.New(comparer.DefaultComparer, 4*1024) for _, k := range keys { p1.Set(k, k) @@ -88,7 +88,7 @@ func TestRandom(t 
*testing.T) { // The test takes too long under the race detector. func TestRandomDerive(t *testing.T) { - db := newMemDB() + db := NewMemDB() golden := leveldb.New(comparer.DefaultComparer, 4*1024) testRandomDeriveRecur(t, db, golden, 0) } diff --git a/internal/unionstore/memdb_rbt.go b/internal/unionstore/memdb_rbt.go new file mode 100644 index 0000000000..4c4e85c5f9 --- /dev/null +++ b/internal/unionstore/memdb_rbt.go @@ -0,0 +1,176 @@ +// Copyright 2024 TiKV Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package unionstore + +import ( + "context" + "sync" + + tikverr "github.com/tikv/client-go/v2/error" + "github.com/tikv/client-go/v2/internal/unionstore/arena" + "github.com/tikv/client-go/v2/internal/unionstore/rbt" + "github.com/tikv/client-go/v2/kv" +) + +// rbtDBWithContext wraps RBT to satisfy the MemBuffer interface. +type rbtDBWithContext struct { + // This RWMutex only used to ensure rbtSnapGetter.Get will not race with + // concurrent MemBuffer.Set, MemBuffer.SetWithFlags, MemBuffer.Delete and MemBuffer.UpdateFlags. + sync.RWMutex + *rbt.RBT + + // when the RBT is wrapper by upper RWMutex, we can skip the internal mutex. 
+ skipMutex bool +} + +func newRbtDBWithContext() *rbtDBWithContext { + return &rbtDBWithContext{ + skipMutex: false, + RBT: rbt.New(), + } +} + +func (db *rbtDBWithContext) setSkipMutex(skip bool) { + db.skipMutex = skip +} + +func (db *rbtDBWithContext) set(key, value []byte, ops ...kv.FlagsOp) error { + if !db.skipMutex { + db.Lock() + defer db.Unlock() + } + return db.RBT.Set(key, value, ops...) +} + +// UpdateFlags update the flags associated with key. +func (db *rbtDBWithContext) UpdateFlags(key []byte, ops ...kv.FlagsOp) { + err := db.set(key, nil, ops...) + _ = err // set without value will never fail +} + +// Set sets the value for key k as v into kv store. +// v must NOT be nil or empty, otherwise it returns ErrCannotSetNilValue. +func (db *rbtDBWithContext) Set(key []byte, value []byte) error { + if len(value) == 0 { + return tikverr.ErrCannotSetNilValue + } + return db.set(key, value) +} + +// SetWithFlags put key-value into the last active staging buffer with the given KeyFlags. +func (db *rbtDBWithContext) SetWithFlags(key []byte, value []byte, ops ...kv.FlagsOp) error { + if len(value) == 0 { + return tikverr.ErrCannotSetNilValue + } + return db.set(key, value, ops...) +} + +// Delete removes the entry for key k from kv store. +func (db *rbtDBWithContext) Delete(key []byte) error { + return db.set(key, arena.Tombstone) +} + +// DeleteWithFlags delete key with the given KeyFlags +func (db *rbtDBWithContext) DeleteWithFlags(key []byte, ops ...kv.FlagsOp) error { + return db.set(key, arena.Tombstone, ops...) 
+} + +func (db *rbtDBWithContext) Staging() int { + if !db.skipMutex { + db.Lock() + defer db.Unlock() + } + return db.RBT.Staging() +} + +func (db *rbtDBWithContext) Cleanup(handle int) { + if !db.skipMutex { + db.Lock() + defer db.Unlock() + } + db.RBT.Cleanup(handle) +} + +func (db *rbtDBWithContext) Release(handle int) { + if !db.skipMutex { + db.Lock() + defer db.Unlock() + } + db.RBT.Release(handle) +} + +func (db *rbtDBWithContext) Get(_ context.Context, k []byte) ([]byte, error) { + return db.RBT.Get(k) +} + +func (db *rbtDBWithContext) GetLocal(_ context.Context, k []byte) ([]byte, error) { + return db.RBT.Get(k) +} + +func (db *rbtDBWithContext) Flush(bool) (bool, error) { return false, nil } + +func (db *rbtDBWithContext) FlushWait() error { return nil } + +// GetMemDB implements the MemBuffer interface. +func (db *rbtDBWithContext) GetMemDB() *MemDB { + return db +} + +// BatchGet returns the values for given keys from the MemBuffer. +func (db *rbtDBWithContext) BatchGet(ctx context.Context, keys [][]byte) (map[string][]byte, error) { + if db.Len() == 0 { + return map[string][]byte{}, nil + } + m := make(map[string][]byte, len(keys)) + for _, k := range keys { + v, err := db.Get(ctx, k) + if err != nil { + if tikverr.IsErrNotFound(err) { + continue + } + return nil, err + } + m[string(k)] = v + } + return m, nil +} + +// GetMetrics implements the MemBuffer interface. +func (db *rbtDBWithContext) GetMetrics() Metrics { return Metrics{} } + +// Iter implements the Retriever interface. +func (db *rbtDBWithContext) Iter(lower, upper []byte) (Iterator, error) { + return db.RBT.Iter(lower, upper) +} + +// IterReverse implements the Retriever interface. +func (db *rbtDBWithContext) IterReverse(upper, lower []byte) (Iterator, error) { + return db.RBT.IterReverse(upper, lower) +} + +// SnapshotIter returns an Iterator for a snapshot of MemBuffer. 
+func (db *rbtDBWithContext) SnapshotIter(lower, upper []byte) Iterator { + return db.RBT.SnapshotIter(lower, upper) +} + +// SnapshotIterReverse returns a reversed Iterator for a snapshot of MemBuffer. +func (db *rbtDBWithContext) SnapshotIterReverse(upper, lower []byte) Iterator { + return db.RBT.SnapshotIterReverse(upper, lower) +} + +// SnapshotGetter returns a Getter for a snapshot of MemBuffer. +func (db *rbtDBWithContext) SnapshotGetter() Getter { + return db.RBT.SnapshotGetter() +} diff --git a/internal/unionstore/memdb_test.go b/internal/unionstore/memdb_test.go index 0c2852e521..cb2776eb83 100644 --- a/internal/unionstore/memdb_test.go +++ b/internal/unionstore/memdb_test.go @@ -50,43 +50,33 @@ import ( type KeyFlags = kv.KeyFlags -func init() { - testMode = true +func TestGetSet(t *testing.T) { + testGetSet(t, newRbtDBWithContext()) } -func TestGetSet(t *testing.T) { +func testGetSet(t *testing.T, db MemBuffer) { require := require.New(t) const cnt = 10000 - p := fillDB(cnt) + fillDB(db, cnt) var buf [4]byte for i := 0; i < cnt; i++ { binary.BigEndian.PutUint32(buf[:], uint32(i)) - v, err := p.Get(buf[:]) + v, err := db.Get(context.Background(), buf[:]) require.Nil(err) require.Equal(v, buf[:]) } } -func TestBigKV(t *testing.T) { - assert := assert.New(t) - db := newMemDB() - db.Set([]byte{1}, make([]byte, 80<<20)) - assert.Equal(db.vlog.blockSize, maxBlockSize) - assert.Equal(len(db.vlog.blocks), 1) - h := db.Staging() - db.Set([]byte{2}, make([]byte, 127<<20)) - db.Release(h) - assert.Equal(db.vlog.blockSize, maxBlockSize) - assert.Equal(len(db.vlog.blocks), 2) - assert.PanicsWithValue("alloc size is larger than max block size", func() { db.Set([]byte{3}, make([]byte, maxBlockSize+1)) }) +func TestIterator(t *testing.T) { + testIterator(t, newRbtDBWithContext()) } -func TestIterator(t *testing.T) { +func testIterator(t *testing.T, db MemBuffer) { assert := assert.New(t) const cnt = 10000 - db := fillDB(cnt) + fillDB(db, cnt) var buf [4]byte var i int @@ 
-130,14 +120,17 @@ func TestIterator(t *testing.T) { } func TestDiscard(t *testing.T) { + testDiscard(t, newRbtDBWithContext()) +} + +func testDiscard(t *testing.T, db MemBuffer) { assert := assert.New(t) const cnt = 10000 - db := newMemDB() - base := deriveAndFill(0, cnt, 0, db) + base := deriveAndFill(db, 0, cnt, 0) sz := db.Size() - db.Cleanup(deriveAndFill(0, cnt, 1, db)) + db.Cleanup(deriveAndFill(db, 0, cnt, 1)) assert.Equal(db.Len(), cnt) assert.Equal(db.Size(), sz) @@ -145,7 +138,7 @@ func TestDiscard(t *testing.T) { for i := 0; i < cnt; i++ { binary.BigEndian.PutUint32(buf[:], uint32(i)) - v, err := db.Get(buf[:]) + v, err := db.Get(context.Background(), buf[:]) assert.Nil(err) assert.Equal(v, buf[:]) } @@ -171,28 +164,25 @@ func TestDiscard(t *testing.T) { db.Cleanup(base) for i := 0; i < cnt; i++ { binary.BigEndian.PutUint32(buf[:], uint32(i)) - _, err := db.Get(buf[:]) + _, err := db.Get(context.Background(), buf[:]) assert.NotNil(err) } - it1, _ := db.Iter(nil, nil) - it := it1.(*MemdbIterator) - it.seekToFirst() - assert.False(it.Valid()) - it.seekToLast() - assert.False(it.Valid()) - it.seek([]byte{0xff}) + it, _ := db.Iter(nil, nil) assert.False(it.Valid()) } func TestFlushOverwrite(t *testing.T) { + testFlushOverwrite(t, newRbtDBWithContext()) +} + +func testFlushOverwrite(t *testing.T, db MemBuffer) { assert := assert.New(t) const cnt = 10000 - db := newMemDB() - db.Release(deriveAndFill(0, cnt, 0, db)) + db.Release(deriveAndFill(db, 0, cnt, 0)) sz := db.Size() - db.Release(deriveAndFill(0, cnt, 1, db)) + db.Release(deriveAndFill(db, 0, cnt, 1)) assert.Equal(db.Len(), cnt) assert.Equal(db.Size(), sz) @@ -202,7 +192,7 @@ func TestFlushOverwrite(t *testing.T) { for i := 0; i < cnt; i++ { binary.BigEndian.PutUint32(kbuf[:], uint32(i)) binary.BigEndian.PutUint32(vbuf[:], uint32(i+1)) - v, err := db.Get(kbuf[:]) + v, err := db.Get(context.Background(), kbuf[:]) assert.Nil(err) assert.Equal(v, vbuf[:]) } @@ -229,6 +219,10 @@ func TestFlushOverwrite(t 
*testing.T) { } func TestComplexUpdate(t *testing.T) { + testComplexUpdate(t, newRbtDBWithContext()) +} + +func testComplexUpdate(t *testing.T, db MemBuffer) { assert := assert.New(t) const ( @@ -237,10 +231,9 @@ func TestComplexUpdate(t *testing.T) { insert = 9000 ) - db := newMemDB() - db.Release(deriveAndFill(0, overwrite, 0, db)) + db.Release(deriveAndFill(db, 0, overwrite, 0)) assert.Equal(db.Len(), overwrite) - db.Release(deriveAndFill(keep, insert, 1, db)) + db.Release(deriveAndFill(db, keep, insert, 1)) assert.Equal(db.Len(), insert) var kbuf, vbuf [4]byte @@ -251,20 +244,23 @@ func TestComplexUpdate(t *testing.T) { if i >= keep { binary.BigEndian.PutUint32(vbuf[:], uint32(i+1)) } - v, err := db.Get(kbuf[:]) + v, err := db.Get(context.Background(), kbuf[:]) assert.Nil(err) assert.Equal(v, vbuf[:]) } } func TestNestedSandbox(t *testing.T) { + testNestedSandbox(t, newRbtDBWithContext()) +} + +func testNestedSandbox(t *testing.T, db MemBuffer) { assert := assert.New(t) - db := newMemDB() - h0 := deriveAndFill(0, 200, 0, db) - h1 := deriveAndFill(0, 100, 1, db) - h2 := deriveAndFill(50, 150, 2, db) - h3 := deriveAndFill(100, 120, 3, db) - h4 := deriveAndFill(0, 150, 4, db) + h0 := deriveAndFill(db, 0, 200, 0) + h1 := deriveAndFill(db, 0, 100, 1) + h2 := deriveAndFill(db, 50, 150, 2) + h3 := deriveAndFill(db, 100, 120, 3) + h4 := deriveAndFill(db, 0, 150, 4) db.Cleanup(h4) // Discard (0..150 -> 4) db.Release(h3) // Flush (100..120 -> 3) db.Cleanup(h2) // Discard (100..120 -> 3) & (50..150 -> 2) @@ -280,7 +276,7 @@ func TestNestedSandbox(t *testing.T) { if i < 100 { binary.BigEndian.PutUint32(vbuf[:], uint32(i+1)) } - v, err := db.Get(kbuf[:]) + v, err := db.Get(context.Background(), kbuf[:]) assert.Nil(err) assert.Equal(v, vbuf[:]) } @@ -314,10 +310,14 @@ func TestNestedSandbox(t *testing.T) { } func TestOverwrite(t *testing.T) { + testOverwrite(t, newRbtDBWithContext()) +} + +func testOverwrite(t *testing.T, db MemBuffer) { assert := assert.New(t) const cnt = 
10000 - db := fillDB(cnt) + fillDB(db, cnt) var buf [4]byte sz := db.Size() @@ -332,7 +332,7 @@ func TestOverwrite(t *testing.T) { for i := 0; i < cnt; i++ { binary.BigEndian.PutUint32(buf[:], uint32(i)) - val, _ := db.Get(buf[:]) + val, _ := db.Get(context.Background(), buf[:]) v := binary.BigEndian.Uint32(val) if i%3 == 0 { assert.Equal(v, uint32(i*10)) @@ -371,56 +371,32 @@ func TestOverwrite(t *testing.T) { assert.Equal(i, -1) } -func TestKVLargeThanBlock(t *testing.T) { - assert := assert.New(t) - db := newMemDB() - db.Set([]byte{1}, make([]byte, 1)) - db.Set([]byte{2}, make([]byte, 4096)) - assert.Equal(len(db.vlog.blocks), 2) - db.Set([]byte{3}, make([]byte, 3000)) - assert.Equal(len(db.vlog.blocks), 2) - val, err := db.Get([]byte{3}) - assert.Nil(err) - assert.Equal(len(val), 3000) -} - -func TestEmptyDB(t *testing.T) { - assert := assert.New(t) - db := newMemDB() - _, err := db.Get([]byte{0}) - assert.NotNil(err) - it1, _ := db.Iter(nil, nil) - it := it1.(*MemdbIterator) - it.seekToFirst() - assert.False(it.Valid()) - it.seekToLast() - assert.False(it.Valid()) - it.seek([]byte{0xff}) - assert.False(it.Valid()) +func TestReset(t *testing.T) { + testReset(t, newRbtDBWithContext()) } -func TestReset(t *testing.T) { +func testReset(t *testing.T, db interface { + MemBuffer + Reset() +}) { assert := assert.New(t) - db := fillDB(1000) + fillDB(db, 1000) db.Reset() - _, err := db.Get([]byte{0, 0, 0, 0}) + _, err := db.Get(context.Background(), []byte{0, 0, 0, 0}) assert.NotNil(err) - it1, _ := db.Iter(nil, nil) - it := it1.(*MemdbIterator) - it.seekToFirst() - assert.False(it.Valid()) - it.seekToLast() - assert.False(it.Valid()) - it.seek([]byte{0xff}) + it, _ := db.Iter(nil, nil) assert.False(it.Valid()) } func TestInspectStage(t *testing.T) { + testInspectStage(t, newRbtDBWithContext()) +} + +func testInspectStage(t *testing.T, db MemBuffer) { assert := assert.New(t) - db := newMemDB() - h1 := deriveAndFill(0, 1000, 0, db) - h2 := deriveAndFill(500, 1000, 1, db) 
+ h1 := deriveAndFill(db, 0, 1000, 0) + h2 := deriveAndFill(db, 500, 1000, 1) for i := 500; i < 1500; i++ { var kbuf [4]byte // don't update in place @@ -429,7 +405,7 @@ func TestInspectStage(t *testing.T) { binary.BigEndian.PutUint32(vbuf[:], uint32(i+2)) db.Set(kbuf[:], vbuf[:]) } - h3 := deriveAndFill(1000, 2000, 3, db) + h3 := deriveAndFill(db, 1000, 2000, 3) db.InspectStage(h3, func(key []byte, _ KeyFlags, val []byte) { k := int(binary.BigEndian.Uint32(key)) @@ -470,13 +446,17 @@ func TestInspectStage(t *testing.T) { } func TestDirty(t *testing.T) { + testDirty(t, func() MemBuffer { return newRbtDBWithContext() }) +} + +func testDirty(t *testing.T, createDb func() MemBuffer) { assert := assert.New(t) - db := newMemDB() + db := createDb() db.Set([]byte{1}, []byte{1}) assert.True(db.Dirty()) - db = newMemDB() + db = createDb() h := db.Staging() db.Set([]byte{1}, []byte{1}) db.Cleanup(h) @@ -488,14 +468,14 @@ func TestDirty(t *testing.T) { assert.True(db.Dirty()) // persistent flags will make memdb dirty. - db = newMemDB() + db = createDb() h = db.Staging() db.SetWithFlags([]byte{1}, []byte{1}, kv.SetKeyLocked) db.Cleanup(h) assert.True(db.Dirty()) // non-persistent flags will not make memdb dirty. 
- db = newMemDB() + db = createDb() h = db.Staging() db.SetWithFlags([]byte{1}, []byte{1}, kv.SetPresumeKeyNotExists) db.Cleanup(h) @@ -503,10 +483,13 @@ func TestDirty(t *testing.T) { } func TestFlags(t *testing.T) { + testFlags(t, newRbtDBWithContext(), func(db MemBuffer) Iterator { return db.(*rbtDBWithContext).IterWithFlags(nil, nil) }) +} + +func testFlags(t *testing.T, db MemBuffer, iterWithFlags func(db MemBuffer) Iterator) { assert := assert.New(t) const cnt = 10000 - db := newMemDB() h := db.Staging() for i := uint32(0); i < cnt; i++ { var buf [4]byte @@ -522,7 +505,7 @@ func TestFlags(t *testing.T) { for i := uint32(0); i < cnt; i++ { var buf [4]byte binary.BigEndian.PutUint32(buf[:], i) - _, err := db.Get(buf[:]) + _, err := db.Get(context.Background(), buf[:]) assert.NotNil(err) flags, err := db.GetFlags(buf[:]) if i%2 == 0 { @@ -537,13 +520,10 @@ func TestFlags(t *testing.T) { assert.Equal(db.Len(), 5000) assert.Equal(db.Size(), 20000) - it1, _ := db.Iter(nil, nil) - it := it1.(*MemdbIterator) + it, _ := db.Iter(nil, nil) assert.False(it.Valid()) - it.includeFlags = true - it.init() - + it = iterWithFlags(db) for ; it.Valid(); it.Next() { k := binary.BigEndian.Uint32(it.Key()) assert.True(k%2 == 0) @@ -557,7 +537,7 @@ func TestFlags(t *testing.T) { for i := uint32(0); i < cnt; i++ { var buf [4]byte binary.BigEndian.PutUint32(buf[:], i) - _, err := db.Get(buf[:]) + _, err := db.Get(context.Background(), buf[:]) assert.NotNil(err) // UpdateFlags will create missing node. 
@@ -578,7 +558,7 @@ func checkConsist(t *testing.T, p1 *MemDB, p2 *leveldb.DB) { var prevKey, prevVal []byte for it2.First(); it2.Valid(); it2.Next() { - v, err := p1.Get(it2.Key()) + v, err := p1.Get(context.Background(), it2.Key()) assert.Nil(err) assert.Equal(v, it2.Value()) @@ -608,14 +588,12 @@ func checkConsist(t *testing.T, p1 *MemDB, p2 *leveldb.DB) { } } -func fillDB(cnt int) *MemDB { - db := newMemDB() - h := deriveAndFill(0, cnt, 0, db) +func fillDB(db MemBuffer, cnt int) { + h := deriveAndFill(db, 0, cnt, 0) db.Release(h) - return db } -func deriveAndFill(start, end, valueBase int, db *MemDB) int { +func deriveAndFill(db MemBuffer, start, end, valueBase int) int { h := db.Staging() var kbuf, vbuf [4]byte for i := start; i < end; i++ { @@ -704,21 +682,21 @@ func checkNewIterator(t *testing.T, buffer *MemDB) { func mustGet(t *testing.T, buffer *MemDB) { for i := startIndex; i < testCount; i++ { s := encodeInt(i * indexStep) - val, err := buffer.Get(s) + val, err := buffer.Get(context.Background(), s) assert.Nil(t, err) assert.Equal(t, string(val), string(s)) } } func TestKVGetSet(t *testing.T) { - buffer := newMemDB() + buffer := NewMemDB() insertData(t, buffer) mustGet(t, buffer) } func TestNewIterator(t *testing.T) { assert := assert.New(t) - buffer := newMemDB() + buffer := NewMemDB() // should be invalid iter, err := buffer.Iter(nil, nil) assert.Nil(err) @@ -746,7 +724,7 @@ func NextUntil(it Iterator, fn FnKeyCmp) error { func TestIterNextUntil(t *testing.T) { assert := assert.New(t) - buffer := newMemDB() + buffer := NewMemDB() insertData(t, buffer) iter, err := buffer.Iter(nil, nil) @@ -761,7 +739,7 @@ func TestIterNextUntil(t *testing.T) { func TestBasicNewIterator(t *testing.T) { assert := assert.New(t) - buffer := newMemDB() + buffer := NewMemDB() it, err := buffer.Iter([]byte("2"), nil) assert.Nil(err) assert.False(it.Valid()) @@ -780,7 +758,7 @@ func TestNewIteratorMin(t *testing.T) { 
{"DATA_test_main_db_tbl_tbl_test_record__00000000000000000002_0002", "2"}, {"DATA_test_main_db_tbl_tbl_test_record__00000000000000000002_0003", "hello"}, } - buffer := newMemDB() + buffer := NewMemDB() for _, kv := range kvs { err := buffer.Set([]byte(kv.key), []byte(kv.value)) assert.Nil(err) @@ -803,7 +781,7 @@ func TestNewIteratorMin(t *testing.T) { func TestMemDBStaging(t *testing.T) { assert := assert.New(t) - buffer := newMemDB() + buffer := NewMemDB() err := buffer.Set([]byte("x"), make([]byte, 2)) assert.Nil(err) @@ -815,25 +793,27 @@ func TestMemDBStaging(t *testing.T) { err = buffer.Set([]byte("yz"), make([]byte, 1)) assert.Nil(err) - v, _ := buffer.Get([]byte("x")) + v, _ := buffer.Get(context.Background(), []byte("x")) assert.Equal(len(v), 3) buffer.Release(h2) - v, _ = buffer.Get([]byte("yz")) + v, _ = buffer.Get(context.Background(), []byte("yz")) assert.Equal(len(v), 1) buffer.Cleanup(h1) - v, _ = buffer.Get([]byte("x")) + v, _ = buffer.Get(context.Background(), []byte("x")) assert.Equal(len(v), 2) } func TestBufferLimit(t *testing.T) { + testBufferLimit(t, newRbtDBWithContext()) +} + +func testBufferLimit(t *testing.T, buffer MemBuffer) { assert := assert.New(t) - buffer := newMemDB() - buffer.bufferSizeLimit = 1000 - buffer.entrySizeLimit = 500 + buffer.SetEntrySizeLimit(500, 1000) err := buffer.Set([]byte("x"), make([]byte, 500)) assert.NotNil(err) // entry size limit @@ -852,7 +832,7 @@ func TestBufferLimit(t *testing.T) { func TestUnsetTemporaryFlag(t *testing.T) { require := require.New(t) - db := newMemDB() + db := NewMemDB() key := []byte{1} value := []byte{2} db.SetWithFlags(key, value, kv.SetNeedConstraintCheckInPrewrite) @@ -864,7 +844,7 @@ func TestUnsetTemporaryFlag(t *testing.T) { func TestSnapshotGetIter(t *testing.T) { assert := assert.New(t) - buffer := newMemDB() + buffer := NewMemDB() var getters []Getter var iters []Iterator for i := 0; i < 100; i++ { diff --git a/internal/unionstore/mock.go b/internal/unionstore/mock.go index 
6586975116..b946c384c8 100644 --- a/internal/unionstore/mock.go +++ b/internal/unionstore/mock.go @@ -44,18 +44,18 @@ type mockSnapshot struct { store *MemDB } -func (s *mockSnapshot) Get(_ context.Context, k []byte) ([]byte, error) { - return s.store.Get(k) +func (s *mockSnapshot) Get(ctx context.Context, k []byte) ([]byte, error) { + return s.store.Get(ctx, k) } func (s *mockSnapshot) SetPriority(priority int) { } -func (s *mockSnapshot) BatchGet(_ context.Context, keys [][]byte) (map[string][]byte, error) { +func (s *mockSnapshot) BatchGet(ctx context.Context, keys [][]byte) (map[string][]byte, error) { m := make(map[string][]byte, len(keys)) for _, k := range keys { - v, err := s.store.Get(k) + v, err := s.store.Get(ctx, k) if tikverr.IsErrNotFound(err) { continue } diff --git a/internal/unionstore/pipelined_memdb.go b/internal/unionstore/pipelined_memdb.go index 47518aed22..163a289f4c 100644 --- a/internal/unionstore/pipelined_memdb.go +++ b/internal/unionstore/pipelined_memdb.go @@ -24,6 +24,7 @@ import ( "github.com/pingcap/errors" tikverr "github.com/tikv/client-go/v2/error" "github.com/tikv/client-go/v2/internal/logutil" + "github.com/tikv/client-go/v2/internal/unionstore/arena" "github.com/tikv/client-go/v2/kv" "github.com/tikv/client-go/v2/metrics" "github.com/tikv/client-go/v2/util" @@ -35,7 +36,7 @@ import ( // - an immutable onflushing buffer for read // - like MemDB, PipelinedMemDB also CANNOT be used concurrently type PipelinedMemDB struct { - // Like MemDB, this RWMutex only used to ensure memdbSnapGetter.Get will not race with + // Like MemDB, this RWMutex only used to ensure rbtSnapGetter.Get will not race with // concurrent memdb.Set, memdb.SetWithFlags, memdb.Delete and memdb.UpdateFlags. 
sync.RWMutex onFlushing atomic.Bool @@ -102,8 +103,9 @@ type FlushFunc func(uint64, *MemDB) error type BufferBatchGetter func(ctx context.Context, keys [][]byte) (map[string][]byte, error) func NewPipelinedMemDB(bufferBatchGetter BufferBatchGetter, flushFunc FlushFunc) *PipelinedMemDB { - memdb := newMemDB() + memdb := NewMemDB() memdb.setSkipMutex(true) + entryLimit, _ := memdb.GetEntrySizeLimit() flushOpt := newFlushOption() return &PipelinedMemDB{ memDB: memdb, @@ -112,7 +114,7 @@ func NewPipelinedMemDB(bufferBatchGetter BufferBatchGetter, flushFunc FlushFunc) bufferBatchGetter: bufferBatchGetter, generation: 0, // keep entryLimit and bufferLimit same with the memdb's default values. - entryLimit: memdb.entrySizeLimit, + entryLimit: entryLimit, flushOption: flushOpt, startTime: time.Now(), } @@ -129,7 +131,7 @@ func (p *PipelinedMemDB) GetMemDB() *MemDB { } func (p *PipelinedMemDB) get(ctx context.Context, k []byte, skipRemoteBuffer bool) ([]byte, error) { - v, err := p.memDB.Get(k) + v, err := p.memDB.Get(ctx, k) if err == nil { return v, nil } @@ -137,7 +139,7 @@ func (p *PipelinedMemDB) get(ctx context.Context, k []byte, skipRemoteBuffer boo return nil, err } if p.flushingMemDB != nil { - v, err = p.flushingMemDB.Get(k) + v, err = p.flushingMemDB.Get(ctx, k) if err == nil { return v, nil } @@ -288,7 +290,7 @@ func (p *PipelinedMemDB) Flush(force bool) (bool, error) { // invalidate the batch get cache whether the flush is really triggered. 
p.batchGetCache = nil - if len(p.memDB.stages) > 0 { + if p.memDB.IsStaging() { return false, errors.New("there are stages unreleased when Flush is called") } @@ -309,9 +311,9 @@ func (p *PipelinedMemDB) Flush(force bool) (bool, error) { p.flushingMemDB = p.memDB p.len += p.flushingMemDB.Len() p.size += p.flushingMemDB.Size() - p.missCount += p.memDB.missCount.Load() - p.hitCount += p.memDB.hitCount.Load() - p.memDB = newMemDB() + p.missCount += p.memDB.GetCacheMissCount() + p.hitCount += p.memDB.GetCacheHitCount() + p.memDB = NewMemDB() // buffer size is limited by ForceFlushMemSizeThreshold. Do not set bufferLimit p.memDB.SetEntrySizeLimit(p.entryLimit, unlimitedSize) p.memDB.setSkipMutex(true) @@ -384,7 +386,7 @@ func (p *PipelinedMemDB) FlushWait() error { func (p *PipelinedMemDB) handleAlreadyExistErr(err error) error { var existErr *tikverr.ErrKeyExist if stderrors.As(err, &existErr) { - v, err2 := p.flushingMemDB.Get(existErr.GetKey()) + v, err2 := p.flushingMemDB.Get(context.Background(), existErr.GetKey()) if err2 != nil { // TODO: log more info like start_ts, also for other logs logutil.BgLogger().Warn( @@ -518,12 +520,12 @@ func (p *PipelinedMemDB) Release(h int) { } // Checkpoint implements MemBuffer interface. -func (p *PipelinedMemDB) Checkpoint() *MemDBCheckpoint { +func (p *PipelinedMemDB) Checkpoint() *arena.MemDBCheckpoint { panic("Checkpoint is not supported for PipelinedMemDB") } // RevertToCheckpoint implements MemBuffer interface. 
-func (p *PipelinedMemDB) RevertToCheckpoint(*MemDBCheckpoint) { +func (p *PipelinedMemDB) RevertToCheckpoint(*arena.MemDBCheckpoint) { panic("RevertToCheckpoint is not supported for PipelinedMemDB") } @@ -533,8 +535,8 @@ func (p *PipelinedMemDB) GetMetrics() Metrics { hitCount := p.hitCount missCount := p.missCount if p.memDB != nil { - hitCount += p.memDB.hitCount.Load() - missCount += p.memDB.missCount.Load() + hitCount += p.memDB.GetCacheHitCount() + missCount += p.memDB.GetCacheMissCount() } return Metrics{ WaitDuration: p.flushWaitDuration, diff --git a/internal/unionstore/pipelined_memdb_test.go b/internal/unionstore/pipelined_memdb_test.go index 4bbf6846e6..8b11228b2b 100644 --- a/internal/unionstore/pipelined_memdb_test.go +++ b/internal/unionstore/pipelined_memdb_test.go @@ -208,7 +208,7 @@ func TestPipelinedFlushGet(t *testing.T) { require.True(t, memdb.OnFlushing()) // The key is in flushingMemDB memdb instead of current mutable memdb. - _, err = memdb.memDB.Get([]byte("key")) + _, err = memdb.memDB.Get(context.Background(), []byte("key")) require.True(t, tikverr.IsErrNotFound(err)) // But we still can get the value by PipelinedMemDB.Get. value, err = memdb.Get(context.Background(), []byte("key")) diff --git a/internal/unionstore/rbt/rbt.go b/internal/unionstore/rbt/rbt.go new file mode 100644 index 0000000000..e2a8d81df6 --- /dev/null +++ b/internal/unionstore/rbt/rbt.go @@ -0,0 +1,926 @@ +// Copyright 2021 TiKV Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// NOTE: The code in this file is based on code from the +// TiDB project, licensed under the Apache License v 2.0 +// +// https://github.com/pingcap/tidb/tree/cc5e161ac06827589c4966674597c137cc9e809c/store/tikv/unionstore/memdb.go +// + +// Copyright 2020 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rbt + +import ( + "bytes" + "fmt" + "math" + "sync/atomic" + "unsafe" + + tikverr "github.com/tikv/client-go/v2/error" + "github.com/tikv/client-go/v2/internal/unionstore/arena" + "github.com/tikv/client-go/v2/kv" +) + +const unlimitedSize = math.MaxUint64 + +var testMode = false + +// RBT is rollbackable Red-Black Tree optimized for TiDB's transaction states buffer use scenario. +// You can think RBT is a combination of two separate tree map, one for key => value and another for key => keyFlags. +// +// The value map is rollbackable, that means you can use the `Staging`, `Release` and `Cleanup` API to safely modify KVs. +// +// The flags map is not rollbackable. There are two types of flag, persistent and non-persistent. +// When discarding a newly added KV in `Cleanup`, the non-persistent flags will be cleared. +// If there are persistent flags associated with key, we will keep this key in node without value. 
+type RBT struct { + root arena.MemdbArenaAddr + allocator nodeAllocator + vlog arena.MemdbVlog[*memdbNode, *RBT] + + entrySizeLimit uint64 + bufferSizeLimit uint64 + count int + size int + + vlogInvalid bool + dirty bool + stages []arena.MemDBCheckpoint + + // The lastTraversedNode must exist + lastTraversedNode atomic.Pointer[MemdbNodeAddr] + hitCount atomic.Uint64 + missCount atomic.Uint64 +} + +func New() *RBT { + db := new(RBT) + db.allocator.init() + db.root = arena.NullAddr + db.stages = make([]arena.MemDBCheckpoint, 0, 2) + db.entrySizeLimit = unlimitedSize + db.bufferSizeLimit = unlimitedSize + db.lastTraversedNode.Store(&nullNodeAddr) + return db +} + +// updateLastTraversed updates the last traversed node atomically +func (db *RBT) updateLastTraversed(node MemdbNodeAddr) { + db.lastTraversedNode.Store(&node) +} + +// checkKeyInCache retrieves the last traversed node if the key matches +func (db *RBT) checkKeyInCache(key []byte) (MemdbNodeAddr, bool) { + nodePtr := db.lastTraversedNode.Load() + if nodePtr == nil || nodePtr.isNull() { + return nullNodeAddr, false + } + + if bytes.Equal(key, nodePtr.memdbNode.getKey()) { + return *nodePtr, true + } + + return nullNodeAddr, false +} + +func (db *RBT) RevertNode(hdr *arena.MemdbVlogHdr) { + node := db.getNode(hdr.NodeAddr) + node.vptr = hdr.OldValue + db.size -= int(hdr.ValueLen) + // oldValue.isNull() == true means this is a newly added value. + if hdr.OldValue.IsNull() { + // If there are no flags associated with this key, we need to delete this node. + keptFlags := node.getKeyFlags().AndPersistent() + if keptFlags == 0 { + db.deleteNode(node) + } else { + node.setKeyFlags(keptFlags) + db.dirty = true + } + } else { + db.size += len(db.vlog.GetValue(hdr.OldValue)) + } +} + +func (db *RBT) InspectNode(addr arena.MemdbArenaAddr) (*memdbNode, arena.MemdbArenaAddr) { + node := db.allocator.getNode(addr) + return node, node.vptr +} + +// IsStaging returns whether the MemBuffer is in staging status. 
+func (db *RBT) IsStaging() bool { + return len(db.stages) > 0 +} + +// Staging create a new staging buffer inside the MemBuffer. +// Subsequent writes will be temporarily stored in this new staging buffer. +// When you think all modifications looks good, you can call `Release` to public all of them to the upper level buffer. +func (db *RBT) Staging() int { + db.stages = append(db.stages, db.vlog.Checkpoint()) + return len(db.stages) +} + +// Release publish all modifications in the latest staging buffer to upper level. +func (db *RBT) Release(h int) { + if h != len(db.stages) { + // This should never happens in production environment. + // Use panic to make debug easier. + panic("cannot release staging buffer") + } + + if h == 1 { + tail := db.vlog.Checkpoint() + if !db.stages[0].IsSamePosition(&tail) { + db.dirty = true + } + } + db.stages = db.stages[:h-1] +} + +// Cleanup cleanup the resources referenced by the StagingHandle. +// If the changes are not published by `Release`, they will be discarded. +func (db *RBT) Cleanup(h int) { + if h > len(db.stages) { + return + } + if h < len(db.stages) { + // This should never happens in production environment. + // Use panic to make debug easier. + panic(fmt.Sprintf("cannot cleanup staging buffer, h=%v, len(db.stages)=%v", h, len(db.stages))) + } + + cp := &db.stages[h-1] + if !db.vlogInvalid { + curr := db.vlog.Checkpoint() + if !curr.IsSamePosition(cp) { + db.vlog.RevertToCheckpoint(db, cp) + db.vlog.Truncate(cp) + } + } + db.stages = db.stages[:h-1] + db.vlog.OnMemChange() +} + +// Checkpoint returns a checkpoint of RBT. +func (db *RBT) Checkpoint() *arena.MemDBCheckpoint { + cp := db.vlog.Checkpoint() + return &cp +} + +// RevertToCheckpoint reverts the RBT to the checkpoint. +func (db *RBT) RevertToCheckpoint(cp *arena.MemDBCheckpoint) { + db.vlog.RevertToCheckpoint(db, cp) + db.vlog.Truncate(cp) + db.vlog.OnMemChange() +} + +// Reset resets the MemBuffer to initial states. 
+func (db *RBT) Reset() { + db.root = arena.NullAddr + db.lastTraversedNode.Store(&nullNodeAddr) // forget the cached node: it belongs to the tree being dropped, and traverse consults this cache before walking the tree + db.stages = db.stages[:0] + db.dirty = false + db.vlogInvalid = false + db.size, db.count = 0, 0 + db.vlog.Reset() + db.allocator.reset() +} + +// DiscardValues releases the memory used by all values. +// NOTE: any operation need value will panic after this function. +func (db *RBT) DiscardValues() { + db.vlogInvalid = true + db.vlog.Reset() +} + +// InspectStage used to inspect the value updates in the given stage. +func (db *RBT) InspectStage(handle int, f func([]byte, kv.KeyFlags, []byte)) { + idx := handle - 1 + tail := db.vlog.Checkpoint() + head := db.stages[idx] + db.vlog.InspectKVInLog(db, &head, &tail, f) +} + +// Get gets the value for key k from kv store. +// If corresponding kv pair does not exist, it returns nil and ErrNotExist. +func (db *RBT) Get(key []byte) ([]byte, error) { + if db.vlogInvalid { + // panic for easier debugging. + panic("vlog is resetted") + } + + x := db.traverse(key, false) + if x.isNull() { + return nil, tikverr.ErrNotExist + } + if x.vptr.IsNull() { + // A flag only key, act as value not exists + return nil, tikverr.ErrNotExist + } + return db.vlog.GetValue(x.vptr), nil +} + +// SelectValueHistory select the latest value which makes `predicate` returns true from the modification history. +func (db *RBT) SelectValueHistory(key []byte, predicate func(value []byte) bool) ([]byte, error) { + x := db.traverse(key, false) + if x.isNull() { + return nil, tikverr.ErrNotExist + } + if x.vptr.IsNull() { + // A flag only key, act as value not exists + return nil, tikverr.ErrNotExist + } + result := db.vlog.SelectValueHistory(x.vptr, func(addr arena.MemdbArenaAddr) bool { + return predicate(db.vlog.GetValue(addr)) + }) + if result.IsNull() { + return nil, nil + } + return db.vlog.GetValue(result), nil +} + +// GetFlags returns the latest flags associated with key. 
+func (db *RBT) GetFlags(key []byte) (kv.KeyFlags, error) { + x := db.traverse(key, false) + if x.isNull() { + return 0, tikverr.ErrNotExist + } + return x.getKeyFlags(), nil +} + +// GetKeyByHandle returns key by handle. +func (db *RBT) GetKeyByHandle(handle arena.MemKeyHandle) []byte { + x := db.getNode(handle.ToAddr()) + return x.getKey() +} + +// GetValueByHandle returns value by handle. +func (db *RBT) GetValueByHandle(handle arena.MemKeyHandle) ([]byte, bool) { + if db.vlogInvalid { + return nil, false + } + x := db.getNode(handle.ToAddr()) + if x.vptr.IsNull() { + return nil, false + } + return db.vlog.GetValue(x.vptr), true +} + +// Len returns the number of entries in the DB. +func (db *RBT) Len() int { + return db.count +} + +// Size returns sum of keys and values length. +func (db *RBT) Size() int { + return db.size +} + +// Dirty returns whether the root staging buffer is updated. +func (db *RBT) Dirty() bool { + return db.dirty +} + +func (db *RBT) Set(key []byte, value []byte, ops ...kv.FlagsOp) error { + if db.vlogInvalid { + // panic for easier debugging. + panic("vlog is reset") + } + + if value != nil { + if size := uint64(len(key) + len(value)); size > db.entrySizeLimit { + return &tikverr.ErrEntryTooLarge{ + Limit: db.entrySizeLimit, + Size: size, + } + } + } + + if len(db.stages) == 0 { + db.dirty = true + } + x := db.traverse(key, true) + + // the NeedConstraintCheckInPrewrite flag is temporary, + // every write to the node removes the flag unless it's explicitly set. + // This set must be in the latest stage so no special processing is needed. + var flags kv.KeyFlags + if value != nil { + flags = kv.ApplyFlagsOps(x.getKeyFlags(), append([]kv.FlagsOp{kv.DelNeedConstraintCheckInPrewrite}, ops...)...) + } else { + // an UpdateFlag operation, do not delete the NeedConstraintCheckInPrewrite flag. + flags = kv.ApplyFlagsOps(x.getKeyFlags(), ops...) 
+ } + if flags.AndPersistent() != 0 { + db.dirty = true + } + x.setKeyFlags(flags) + + if value == nil { + return nil + } + + db.setValue(x, value) + if uint64(db.Size()) > db.bufferSizeLimit { + return &tikverr.ErrTxnTooLarge{Size: db.Size()} + } + return nil +} + +func (db *RBT) setValue(x MemdbNodeAddr, value []byte) { + var activeCp *arena.MemDBCheckpoint + if len(db.stages) > 0 { + activeCp = &db.stages[len(db.stages)-1] + } + + var oldVal []byte + if !x.vptr.IsNull() { + oldVal = db.vlog.GetValue(x.vptr) + } + + if len(oldVal) > 0 && db.vlog.CanModify(activeCp, x.vptr) { + // For easier to implement, we only consider this case. + // It is the most common usage in TiDB's transaction buffers. + if len(oldVal) == len(value) { + copy(oldVal, value) + return + } + } + x.vptr = db.vlog.AppendValue(x.addr, x.vptr, value) + db.size = db.size - len(oldVal) + len(value) +} + +// traverse search for and if not found and insert is true, will add a new node in. +// Returns a pointer to the new node, or the node found. 
+func (db *RBT) traverse(key []byte, insert bool) MemdbNodeAddr { + if node, found := db.checkKeyInCache(key); found { + db.hitCount.Add(1) + return node + } + db.missCount.Add(1) + + x := db.getRoot() + y := MemdbNodeAddr{nil, arena.NullAddr} + found := false + + // walk x down the tree + for !x.isNull() && !found { + cmp := bytes.Compare(key, x.getKey()) + if cmp < 0 { + if insert && x.left.IsNull() { + y = x + } + x = x.getLeft(db) + } else if cmp > 0 { + if insert && x.right.IsNull() { + y = x + } + x = x.getRight(db) + } else { + found = true + } + } + + if found { + db.updateLastTraversed(x) + } + + if found || !insert { + return x + } + + z := db.allocNode(key) + z.up = y.addr + + if y.isNull() { + db.root = z.addr + } else { + cmp := bytes.Compare(z.getKey(), y.getKey()) + if cmp < 0 { + y.left = z.addr + } else { + y.right = z.addr + } + } + + z.left = arena.NullAddr + z.right = arena.NullAddr + + // colour this new node red + z.setRed() + + // Having added a red node, we must now walk back up the tree balancing it, + // by a series of rotations and changing of colours + x = z + + // While we are not at the top and our parent node is red + // NOTE: Since the root node is guaranteed black, then we + // are also going to stop if we are the child of the root + + for x.addr != db.root { + xUp := x.getUp(db) + if xUp.isBlack() { + break + } + + xUpUp := xUp.getUp(db) + // if our parent is on the left side of our grandparent + if x.up == xUpUp.left { + // get the right side of our grandparent (uncle?) 
+ y = xUpUp.getRight(db) + if y.isRed() { + // make our parent black + xUp.setBlack() + // make our uncle black + y.setBlack() + // make our grandparent red + xUpUp.setRed() + // now consider our grandparent + x = xUp.getUp(db) + } else { + // if we are on the right side of our parent + if x.addr == xUp.right { + // Move up to our parent + x = x.getUp(db) + db.leftRotate(x) + xUp = x.getUp(db) + xUpUp = xUp.getUp(db) + } + + xUp.setBlack() + xUpUp.setRed() + db.rightRotate(xUpUp) + } + } else { + // everything here is the same as above, but exchanging left for right + y = xUpUp.getLeft(db) + if y.isRed() { + xUp.setBlack() + y.setBlack() + xUpUp.setRed() + + x = xUp.getUp(db) + } else { + if x.addr == xUp.left { + x = x.getUp(db) + db.rightRotate(x) + xUp = x.getUp(db) + xUpUp = xUp.getUp(db) + } + + xUp.setBlack() + xUpUp.setRed() + db.leftRotate(xUpUp) + } + } + } + + // Set the root node black + db.getRoot().setBlack() + + db.updateLastTraversed(z) + + return z +} + +// +// Rotate our tree thus:- +// +// X leftRotate(X)---> Y +// / \ / \ +// A Y <---rightRotate(Y) X C +// / \ / \ +// B C A B +// +// NOTE: This does not change the ordering. 
+// +// We assume that neither X nor Y is NULL +// + +func (db *RBT) leftRotate(x MemdbNodeAddr) { + y := x.getRight(db) + + // Turn Y's left subtree into X's right subtree (move B) + x.right = y.left + + // If B is not null, set it's parent to be X + if !y.left.IsNull() { + left := y.getLeft(db) + left.up = x.addr + } + + // Set Y's parent to be what X's parent was + y.up = x.up + + // if X was the root + if x.up.IsNull() { + db.root = y.addr + } else { + xUp := x.getUp(db) + // Set X's parent's left or right pointer to be Y + if x.addr == xUp.left { + xUp.left = y.addr + } else { + xUp.right = y.addr + } + } + + // Put X on Y's left + y.left = x.addr + // Set X's parent to be Y + x.up = y.addr +} + +func (db *RBT) rightRotate(y MemdbNodeAddr) { + x := y.getLeft(db) + + // Turn X's right subtree into Y's left subtree (move B) + y.left = x.right + + // If B is not null, set it's parent to be Y + if !x.right.IsNull() { + right := x.getRight(db) + right.up = y.addr + } + + // Set X's parent to be what Y's parent was + x.up = y.up + + // if Y was the root + if y.up.IsNull() { + db.root = x.addr + } else { + yUp := y.getUp(db) + // Set Y's parent's left or right pointer to be X + if y.addr == yUp.left { + yUp.left = x.addr + } else { + yUp.right = x.addr + } + } + + // Put Y on X's right + x.right = y.addr + // Set Y's parent to be X + y.up = x.addr +} + +func (db *RBT) deleteNode(z MemdbNodeAddr) { + var x, y MemdbNodeAddr + if db.lastTraversedNode.Load().addr == z.addr { + db.lastTraversedNode.Store(&nullNodeAddr) + } + + db.count-- + db.size -= int(z.klen) + + if z.left.IsNull() || z.right.IsNull() { + y = z + } else { + y = db.successor(z) + } + + if !y.left.IsNull() { + x = y.getLeft(db) + } else { + x = y.getRight(db) + } + x.up = y.up + + if y.up.IsNull() { + db.root = x.addr + } else { + yUp := y.getUp(db) + if y.addr == yUp.left { + yUp.left = x.addr + } else { + yUp.right = x.addr + } + } + + needFix := y.isBlack() + + // NOTE: traditional red-black tree will 
copy key from Y to Z and free Y. + // We cannot do the same thing here, due to Y's pointer is stored in vlog and the space in Z may not suitable for Y. + // So we need to copy states from Z to Y, and relink all nodes formerly connected to Z. + if y != z { + db.replaceNode(z, y) + } + + if needFix { + db.deleteNodeFix(x) + } + + db.allocator.freeNode(z.addr) +} + +func (db *RBT) replaceNode(old MemdbNodeAddr, new MemdbNodeAddr) { + if !old.up.IsNull() { + oldUp := old.getUp(db) + if old.addr == oldUp.left { + oldUp.left = new.addr + } else { + oldUp.right = new.addr + } + } else { + db.root = new.addr + } + new.up = old.up + + left := old.getLeft(db) + left.up = new.addr + new.left = old.left + + right := old.getRight(db) + right.up = new.addr + new.right = old.right + + if old.isBlack() { + new.setBlack() + } else { + new.setRed() + } +} + +func (db *RBT) deleteNodeFix(x MemdbNodeAddr) { + for x.addr != db.root && x.isBlack() { + xUp := x.getUp(db) + if x.addr == xUp.left { + w := xUp.getRight(db) + if w.isRed() { + w.setBlack() + xUp.setRed() + db.leftRotate(xUp) + w = x.getUp(db).getRight(db) + } + + if w.getLeft(db).isBlack() && w.getRight(db).isBlack() { + w.setRed() + x = x.getUp(db) + } else { + if w.getRight(db).isBlack() { + w.getLeft(db).setBlack() + w.setRed() + db.rightRotate(w) + w = x.getUp(db).getRight(db) + } + + xUp := x.getUp(db) + if xUp.isBlack() { + w.setBlack() + } else { + w.setRed() + } + xUp.setBlack() + w.getRight(db).setBlack() + db.leftRotate(xUp) + x = db.getRoot() + } + } else { + w := xUp.getLeft(db) + if w.isRed() { + w.setBlack() + xUp.setRed() + db.rightRotate(xUp) + w = x.getUp(db).getLeft(db) + } + + if w.getRight(db).isBlack() && w.getLeft(db).isBlack() { + w.setRed() + x = x.getUp(db) + } else { + if w.getLeft(db).isBlack() { + w.getRight(db).setBlack() + w.setRed() + db.leftRotate(w) + w = x.getUp(db).getLeft(db) + } + + xUp := x.getUp(db) + if xUp.isBlack() { + w.setBlack() + } else { + w.setRed() + } + xUp.setBlack() + 
w.getLeft(db).setBlack() + db.rightRotate(xUp) + x = db.getRoot() + } + } + } + x.setBlack() +} + +func (db *RBT) successor(x MemdbNodeAddr) (y MemdbNodeAddr) { + if !x.right.IsNull() { + // If right is not NULL then go right one and + // then keep going left until we find a node with + // no left pointer. + + y = x.getRight(db) + for !y.left.IsNull() { + y = y.getLeft(db) + } + return + } + + // Go up the tree until we get to a node that is on the + // left of its parent (or the root) and then return the + // parent. + + y = x.getUp(db) + for !y.isNull() && x.addr == y.right { + x = y + y = y.getUp(db) + } + return y +} + +func (db *RBT) predecessor(x MemdbNodeAddr) (y MemdbNodeAddr) { + if !x.left.IsNull() { + // If left is not NULL then go left one and + // then keep going right until we find a node with + // no right pointer. + + y = x.getLeft(db) + for !y.right.IsNull() { + y = y.getRight(db) + } + return + } + + // Go up the tree until we get to a node that is on the + // right of its parent (or the root) and then return the + // parent. 
+ + y = x.getUp(db) + for !y.isNull() && x.addr == y.left { + x = y + y = y.getUp(db) + } + return y +} + +func (db *RBT) getNode(x arena.MemdbArenaAddr) MemdbNodeAddr { + return MemdbNodeAddr{db.allocator.getNode(x), x} +} + +func (db *RBT) getRoot() MemdbNodeAddr { + return db.getNode(db.root) +} + +func (db *RBT) allocNode(key []byte) MemdbNodeAddr { + db.size += len(key) + db.count++ + x, xn := db.allocator.allocNode(key) + return MemdbNodeAddr{xn, x} +} + +var nullNodeAddr = MemdbNodeAddr{nil, arena.NullAddr} + +type MemdbNodeAddr struct { + *memdbNode + addr arena.MemdbArenaAddr +} + +func (a *MemdbNodeAddr) isNull() bool { + return a.addr.IsNull() +} + +func (a MemdbNodeAddr) getUp(db *RBT) MemdbNodeAddr { + return db.getNode(a.up) +} + +func (a MemdbNodeAddr) getLeft(db *RBT) MemdbNodeAddr { + return db.getNode(a.left) +} + +func (a MemdbNodeAddr) getRight(db *RBT) MemdbNodeAddr { + return db.getNode(a.right) +} + +type memdbNode struct { + up arena.MemdbArenaAddr + left arena.MemdbArenaAddr + right arena.MemdbArenaAddr + vptr arena.MemdbArenaAddr + klen uint16 + flags uint16 +} + +func (n *memdbNode) isRed() bool { + return n.flags&nodeColorBit != 0 +} + +func (n *memdbNode) isBlack() bool { + return !n.isRed() +} + +func (n *memdbNode) setRed() { + n.flags |= nodeColorBit +} + +func (n *memdbNode) setBlack() { + n.flags &= ^nodeColorBit +} + +func (n *memdbNode) GetKey() []byte { + return n.getKey() +} + +func (n *memdbNode) getKey() []byte { + base := unsafe.Add(unsafe.Pointer(&n.flags), kv.FlagBytes) + return unsafe.Slice((*byte)(base), int(n.klen)) +} + +const ( + // bit 1 => red, bit 0 => black + nodeColorBit uint16 = 0x8000 + nodeFlagsMask = ^nodeColorBit +) + +func (n *memdbNode) GetKeyFlags() kv.KeyFlags { + return n.getKeyFlags() +} + +func (n *memdbNode) getKeyFlags() kv.KeyFlags { + return kv.KeyFlags(n.flags & nodeFlagsMask) +} + +func (n *memdbNode) setKeyFlags(f kv.KeyFlags) { + n.flags = (^nodeFlagsMask & n.flags) | uint16(f) +} + +// 
RemoveFromBuffer removes a record from the mem buffer. It should be only used for test. +func (db *RBT) RemoveFromBuffer(key []byte) { + x := db.traverse(key, false) + if x.isNull() { + return + } + db.size -= len(db.vlog.GetValue(x.vptr)) + db.deleteNode(x) +} + +// SetMemoryFootprintChangeHook sets the hook function that is triggered when memdb grows. +func (db *RBT) SetMemoryFootprintChangeHook(hook func(uint64)) { + innerHook := func() { + hook(db.allocator.Capacity() + db.vlog.Capacity()) + } + db.allocator.SetMemChangeHook(innerHook) + db.vlog.SetMemChangeHook(innerHook) +} + +// Mem returns the current memory footprint +func (db *RBT) Mem() uint64 { + return db.allocator.Capacity() + db.vlog.Capacity() +} + +// GetEntrySizeLimit gets the size limit for each entry and total buffer. +func (db *RBT) GetEntrySizeLimit() (uint64, uint64) { + return db.entrySizeLimit, db.bufferSizeLimit +} + +// SetEntrySizeLimit sets the size limit for each entry and total buffer. +func (db *RBT) SetEntrySizeLimit(entryLimit, bufferLimit uint64) { + db.entrySizeLimit = entryLimit + db.bufferSizeLimit = bufferLimit +} + +// MemHookSet implements the MemBuffer interface. +func (db *RBT) MemHookSet() bool { + return db.allocator.MemHookSet() +} + +func (db *RBT) GetCacheHitCount() uint64 { + return db.hitCount.Load() +} + +func (db *RBT) GetCacheMissCount() uint64 { + return db.missCount.Load() +} diff --git a/internal/unionstore/rbt/rbt_arena.go b/internal/unionstore/rbt/rbt_arena.go new file mode 100644 index 0000000000..67ba62e1da --- /dev/null +++ b/internal/unionstore/rbt/rbt_arena.go @@ -0,0 +1,101 @@ +// Copyright 2021 TiKV Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// NOTE: The code in this file is based on code from the +// TiDB project, licensed under the Apache License v 2.0 +// +// https://github.com/pingcap/tidb/tree/cc5e161ac06827589c4966674597c137cc9e809c/store/tikv/unionstore/memdb_arena.go +// + +// Copyright 2020 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rbt + +import ( + "unsafe" + + "github.com/tikv/client-go/v2/internal/unionstore/arena" +) + +type nodeAllocator struct { + arena.MemdbArena + + // Dummy node, so that we can make X.left.up = X. + // We then use this instead of NULL to mean the top or bottom + // end of the rb tree. It is a black node. 
+ nullNode memdbNode +} + +func (a *nodeAllocator) init() { + a.nullNode = memdbNode{ + up: arena.NullAddr, + left: arena.NullAddr, + right: arena.NullAddr, + vptr: arena.NullAddr, + } +} + +func (a *nodeAllocator) getNode(addr arena.MemdbArenaAddr) *memdbNode { + if addr.IsNull() { + return &a.nullNode + } + data := a.GetData(addr) + return (*memdbNode)(unsafe.Pointer(&data[0])) +} + +const memdbNodeSize = int(unsafe.Sizeof(memdbNode{})) + +func (a *nodeAllocator) allocNode(key []byte) (arena.MemdbArenaAddr, *memdbNode) { + nodeSize := memdbNodeSize + len(key) + prevBlocks := a.Blocks() + addr, mem := a.Alloc(nodeSize, true) + n := (*memdbNode)(unsafe.Pointer(&mem[0])) + n.vptr = arena.NullAddr + n.klen = uint16(len(key)) + copy(n.getKey(), key) + if prevBlocks != a.Blocks() { + a.OnMemChange() + } + return addr, n +} + +func (a *nodeAllocator) freeNode(addr arena.MemdbArenaAddr) { + if testMode { + // Make it easier for debug. + n := a.getNode(addr) + n.left = arena.BadAddr + n.right = arena.BadAddr + n.up = arena.BadAddr + n.vptr = arena.BadAddr + return + } + // TODO: reuse freed nodes. Need to fix lastTraversedNode when implementing this. +} + +func (a *nodeAllocator) reset() { + a.MemdbArena.Reset() + a.init() +} diff --git a/internal/unionstore/memdb_iterator.go b/internal/unionstore/rbt/rbt_iterator.go similarity index 75% rename from internal/unionstore/memdb_iterator.go rename to internal/unionstore/rbt/rbt_iterator.go index 3b4bdfd8f0..dea270b4f8 100644 --- a/internal/unionstore/memdb_iterator.go +++ b/internal/unionstore/rbt/rbt_iterator.go @@ -32,18 +32,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -package unionstore +package rbt import ( "bytes" + "github.com/tikv/client-go/v2/internal/unionstore/arena" "github.com/tikv/client-go/v2/kv" ) -// MemdbIterator is an Iterator with KeyFlags related functions. 
-type MemdbIterator struct { - db *MemDB - curr memdbNodeAddr +// RBTIterator is an Iterator with KeyFlags related functions. +type RBTIterator struct { + db *RBT + curr MemdbNodeAddr start []byte end []byte reverse bool @@ -54,8 +55,8 @@ type MemdbIterator struct { // If such entry is not found, it returns an invalid Iterator with no error. // It yields only keys that < upperBound. If upperBound is nil, it means the upperBound is unbounded. // The Iterator must be Closed after use. -func (db *MemDB) Iter(k []byte, upperBound []byte) (Iterator, error) { - i := &MemdbIterator{ +func (db *RBT) Iter(k []byte, upperBound []byte) (*RBTIterator, error) { + i := &RBTIterator{ db: db, start: k, end: upperBound, @@ -68,8 +69,8 @@ func (db *MemDB) Iter(k []byte, upperBound []byte) (Iterator, error) { // The returned iterator will iterate from greater key to smaller key. // If k is nil, the returned iterator will be positioned at the last key. // It yields only keys that >= lowerBound. If lowerBound is nil, it means the lowerBound is unbounded. -func (db *MemDB) IterReverse(k []byte, lowerBound []byte) (Iterator, error) { - i := &MemdbIterator{ +func (db *RBT) IterReverse(k []byte, lowerBound []byte) (*RBTIterator, error) { + i := &RBTIterator{ db: db, start: lowerBound, end: k, @@ -79,9 +80,9 @@ func (db *MemDB) IterReverse(k []byte, lowerBound []byte) (Iterator, error) { return i, nil } -// IterWithFlags returns a MemdbIterator. -func (db *MemDB) IterWithFlags(k []byte, upperBound []byte) *MemdbIterator { - i := &MemdbIterator{ +// IterWithFlags returns a RBTIterator. +func (db *RBT) IterWithFlags(k []byte, upperBound []byte) *RBTIterator { + i := &RBTIterator{ db: db, start: k, end: upperBound, @@ -91,9 +92,9 @@ func (db *MemDB) IterWithFlags(k []byte, upperBound []byte) *MemdbIterator { return i } -// IterReverseWithFlags returns a reversed MemdbIterator. 
-func (db *MemDB) IterReverseWithFlags(k []byte) *MemdbIterator { - i := &MemdbIterator{ +// IterReverseWithFlags returns a reversed RBTIterator. +func (db *RBT) IterReverseWithFlags(k []byte) *RBTIterator { + i := &RBTIterator{ db: db, end: k, reverse: true, @@ -103,7 +104,7 @@ func (db *MemDB) IterReverseWithFlags(k []byte) *MemdbIterator { return i } -func (i *MemdbIterator) init() { +func (i *RBTIterator) init() { if i.reverse { if len(i.end) == 0 { i.seekToLast() @@ -125,7 +126,7 @@ func (i *MemdbIterator) init() { } // Valid returns true if the current iterator is valid. -func (i *MemdbIterator) Valid() bool { +func (i *RBTIterator) Valid() bool { if !i.reverse { return !i.curr.isNull() && (i.end == nil || bytes.Compare(i.Key(), i.end) < 0) } @@ -133,42 +134,39 @@ func (i *MemdbIterator) Valid() bool { } // Flags returns flags belong to current iterator. -func (i *MemdbIterator) Flags() kv.KeyFlags { +func (i *RBTIterator) Flags() kv.KeyFlags { return i.curr.getKeyFlags() } // UpdateFlags updates and apply with flagsOp. -func (i *MemdbIterator) UpdateFlags(ops ...kv.FlagsOp) { +func (i *RBTIterator) UpdateFlags(ops ...kv.FlagsOp) { origin := i.curr.getKeyFlags() n := kv.ApplyFlagsOps(origin, ops...) i.curr.setKeyFlags(n) } // HasValue returns false if it is flags only. -func (i *MemdbIterator) HasValue() bool { +func (i *RBTIterator) HasValue() bool { return !i.isFlagsOnly() } // Key returns current key. -func (i *MemdbIterator) Key() []byte { +func (i *RBTIterator) Key() []byte { return i.curr.getKey() } // Handle returns MemKeyHandle with the current position. -func (i *MemdbIterator) Handle() MemKeyHandle { - return MemKeyHandle{ - idx: uint16(i.curr.addr.idx), - off: i.curr.addr.off, - } +func (i *RBTIterator) Handle() arena.MemKeyHandle { + return i.curr.addr.ToHandle() } // Value returns the value. 
-func (i *MemdbIterator) Value() []byte { - return i.db.vlog.getValue(i.curr.vptr) +func (i *RBTIterator) Value() []byte { + return i.db.vlog.GetValue(i.curr.vptr) } // Next goes the next position. -func (i *MemdbIterator) Next() error { +func (i *RBTIterator) Next() error { for { if i.reverse { i.curr = i.db.predecessor(i.curr) @@ -185,10 +183,10 @@ func (i *MemdbIterator) Next() error { } // Close closes the current iterator. -func (i *MemdbIterator) Close() {} +func (i *RBTIterator) Close() {} -func (i *MemdbIterator) seekToFirst() { - y := memdbNodeAddr{nil, nullAddr} +func (i *RBTIterator) seekToFirst() { + y := MemdbNodeAddr{nil, arena.NullAddr} x := i.db.getNode(i.db.root) for !x.isNull() { @@ -199,8 +197,8 @@ func (i *MemdbIterator) seekToFirst() { i.curr = y } -func (i *MemdbIterator) seekToLast() { - y := memdbNodeAddr{nil, nullAddr} +func (i *RBTIterator) seekToLast() { + y := MemdbNodeAddr{nil, arena.NullAddr} x := i.db.getNode(i.db.root) for !x.isNull() { @@ -211,8 +209,8 @@ func (i *MemdbIterator) seekToLast() { i.curr = y } -func (i *MemdbIterator) seek(key []byte) { - y := memdbNodeAddr{nil, nullAddr} +func (i *RBTIterator) seek(key []byte) { + y := MemdbNodeAddr{nil, arena.NullAddr} x := i.db.getNode(i.db.root) var cmp int @@ -246,6 +244,6 @@ func (i *MemdbIterator) seek(key []byte) { i.curr = y } -func (i *MemdbIterator) isFlagsOnly() bool { - return !i.curr.isNull() && i.curr.vptr.isNull() +func (i *RBTIterator) isFlagsOnly() bool { + return !i.curr.isNull() && i.curr.vptr.IsNull() } diff --git a/internal/unionstore/memdb_snapshot.go b/internal/unionstore/rbt/rbt_snapshot.go similarity index 71% rename from internal/unionstore/memdb_snapshot.go rename to internal/unionstore/rbt/rbt_snapshot.go index c1aec9541a..0432ec22d8 100644 --- a/internal/unionstore/memdb_snapshot.go +++ b/internal/unionstore/rbt/rbt_snapshot.go @@ -32,26 +32,27 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-package unionstore +package rbt import ( "context" tikverr "github.com/tikv/client-go/v2/error" + "github.com/tikv/client-go/v2/internal/unionstore/arena" ) // SnapshotGetter returns a Getter for a snapshot of MemBuffer. -func (db *MemDB) SnapshotGetter() Getter { - return &memdbSnapGetter{ +func (db *RBT) SnapshotGetter() *rbtSnapGetter { + return &rbtSnapGetter{ db: db, cp: db.getSnapshot(), } } -// SnapshotIter returns a Iterator for a snapshot of MemBuffer. -func (db *MemDB) SnapshotIter(start, end []byte) Iterator { - it := &memdbSnapIter{ - MemdbIterator: &MemdbIterator{ +// SnapshotIter returns an Iterator for a snapshot of MemBuffer. +func (db *RBT) SnapshotIter(start, end []byte) *rbtSnapIter { + it := &rbtSnapIter{ + RBTIterator: &RBTIterator{ db: db, start: start, end: end, @@ -63,9 +64,9 @@ func (db *MemDB) SnapshotIter(start, end []byte) Iterator { } // SnapshotIterReverse returns a reverse Iterator for a snapshot of MemBuffer. -func (db *MemDB) SnapshotIterReverse(k, lowerBound []byte) Iterator { - it := &memdbSnapIter{ - MemdbIterator: &MemdbIterator{ +func (db *RBT) SnapshotIterReverse(k, lowerBound []byte) *rbtSnapIter { + it := &rbtSnapIter{ + RBTIterator: &RBTIterator{ db: db, start: lowerBound, end: k, @@ -77,48 +78,48 @@ func (db *MemDB) SnapshotIterReverse(k, lowerBound []byte) Iterator { return it } -func (db *MemDB) getSnapshot() MemDBCheckpoint { +func (db *RBT) getSnapshot() arena.MemDBCheckpoint { if len(db.stages) > 0 { return db.stages[0] } - return db.vlog.checkpoint() + return db.vlog.Checkpoint() } -type memdbSnapGetter struct { - db *MemDB - cp MemDBCheckpoint +type rbtSnapGetter struct { + db *RBT + cp arena.MemDBCheckpoint } -func (snap *memdbSnapGetter) Get(ctx context.Context, key []byte) ([]byte, error) { +func (snap *rbtSnapGetter) Get(ctx context.Context, key []byte) ([]byte, error) { x := snap.db.traverse(key, false) if x.isNull() { return nil, tikverr.ErrNotExist } - if x.vptr.isNull() { + if x.vptr.IsNull() { // A flag 
only key, act as value not exists return nil, tikverr.ErrNotExist } - v, ok := snap.db.vlog.getSnapshotValue(x.vptr, &snap.cp) + v, ok := snap.db.vlog.GetSnapshotValue(x.vptr, &snap.cp) if !ok { return nil, tikverr.ErrNotExist } return v, nil } -type memdbSnapIter struct { - *MemdbIterator +type rbtSnapIter struct { + *RBTIterator value []byte - cp MemDBCheckpoint + cp arena.MemDBCheckpoint } -func (i *memdbSnapIter) Value() []byte { +func (i *rbtSnapIter) Value() []byte { return i.value } -func (i *memdbSnapIter) Next() error { +func (i *rbtSnapIter) Next() error { i.value = nil for i.Valid() { - if err := i.MemdbIterator.Next(); err != nil { + if err := i.RBTIterator.Next(); err != nil { return err } if i.setValue() { @@ -128,18 +129,18 @@ func (i *memdbSnapIter) Next() error { return nil } -func (i *memdbSnapIter) setValue() bool { +func (i *rbtSnapIter) setValue() bool { if !i.Valid() { return false } - if v, ok := i.db.vlog.getSnapshotValue(i.curr.vptr, &i.cp); ok { + if v, ok := i.db.vlog.GetSnapshotValue(i.curr.vptr, &i.cp); ok { i.value = v return true } return false } -func (i *memdbSnapIter) init() { +func (i *rbtSnapIter) init() { if i.reverse { if len(i.end) == 0 { i.seekToLast() diff --git a/internal/unionstore/rbt/rbt_test.go b/internal/unionstore/rbt/rbt_test.go new file mode 100644 index 0000000000..8142adaf02 --- /dev/null +++ b/internal/unionstore/rbt/rbt_test.go @@ -0,0 +1,170 @@ +// Copyright 2024 TiKV Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package rbt + +import ( + "encoding/binary" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/tikv/client-go/v2/kv" +) + +func init() { + testMode = true +} + +func deriveAndFill(start, end, valueBase int, db *RBT) int { + h := db.Staging() + var kbuf, vbuf [4]byte + for i := start; i < end; i++ { + binary.BigEndian.PutUint32(kbuf[:], uint32(i)) + binary.BigEndian.PutUint32(vbuf[:], uint32(i+valueBase)) + db.Set(kbuf[:], vbuf[:]) + } + return h +} + +func TestDiscard(t *testing.T) { + assert := assert.New(t) + + const cnt = 10000 + db := New() + base := deriveAndFill(0, cnt, 0, db) + sz := db.Size() + + db.Cleanup(deriveAndFill(0, cnt, 1, db)) + assert.Equal(db.Len(), cnt) + assert.Equal(db.Size(), sz) + + var buf [4]byte + + for i := 0; i < cnt; i++ { + binary.BigEndian.PutUint32(buf[:], uint32(i)) + v, err := db.Get(buf[:]) + assert.Nil(err) + assert.Equal(v, buf[:]) + } + + var i int + for it, _ := db.Iter(nil, nil); it.Valid(); it.Next() { + binary.BigEndian.PutUint32(buf[:], uint32(i)) + assert.Equal(it.Key(), buf[:]) + assert.Equal(it.Value(), buf[:]) + i++ + } + assert.Equal(i, cnt) + + i-- + for it, _ := db.IterReverse(nil, nil); it.Valid(); it.Next() { + binary.BigEndian.PutUint32(buf[:], uint32(i)) + assert.Equal(it.Key(), buf[:]) + assert.Equal(it.Value(), buf[:]) + i-- + } + assert.Equal(i, -1) + + db.Cleanup(base) + for i := 0; i < cnt; i++ { + binary.BigEndian.PutUint32(buf[:], uint32(i)) + _, err := db.Get(buf[:]) + assert.NotNil(err) + } + it, _ := db.Iter(nil, nil) + it.seekToFirst() + assert.False(it.Valid()) + it.seekToLast() + assert.False(it.Valid()) + it.seek([]byte{0xff}) + assert.False(it.Valid()) +} + +func TestEmptyDB(t *testing.T) { + assert := assert.New(t) + db := New() + _, err := db.Get([]byte{0}) + assert.NotNil(err) + it, _ := db.Iter(nil, nil) + it.seekToFirst() + assert.False(it.Valid()) + it.seekToLast() 
+ assert.False(it.Valid()) + it.seek([]byte{0xff}) + assert.False(it.Valid()) +} + +func TestFlags(t *testing.T) { + assert := assert.New(t) + + const cnt = 10000 + db := New() + h := db.Staging() + for i := uint32(0); i < cnt; i++ { + var buf [4]byte + binary.BigEndian.PutUint32(buf[:], i) + if i%2 == 0 { + db.Set(buf[:], buf[:], kv.SetPresumeKeyNotExists, kv.SetKeyLocked) + } else { + db.Set(buf[:], buf[:], kv.SetPresumeKeyNotExists) + } + } + db.Cleanup(h) + + for i := uint32(0); i < cnt; i++ { + var buf [4]byte + binary.BigEndian.PutUint32(buf[:], i) + _, err := db.Get(buf[:]) + assert.NotNil(err) + flags, err := db.GetFlags(buf[:]) + if i%2 == 0 { + assert.Nil(err) + assert.True(flags.HasLocked()) + assert.False(flags.HasPresumeKeyNotExists()) + } else { + assert.NotNil(err) + } + } + + assert.Equal(db.Len(), 5000) + assert.Equal(db.Size(), 20000) + + it, _ := db.Iter(nil, nil) + assert.False(it.Valid()) + + it.includeFlags = true + it.init() + + for ; it.Valid(); it.Next() { + k := binary.BigEndian.Uint32(it.Key()) + assert.True(k%2 == 0) + } + + for i := uint32(0); i < cnt; i++ { + var buf [4]byte + binary.BigEndian.PutUint32(buf[:], i) + db.Set(buf[:], nil, kv.DelKeyLocked) + } + for i := uint32(0); i < cnt; i++ { + var buf [4]byte + binary.BigEndian.PutUint32(buf[:], i) + _, err := db.Get(buf[:]) + assert.NotNil(err) + + // UpdateFlags will create missing node. + flags, err := db.GetFlags(buf[:]) + assert.Nil(err) + assert.False(flags.HasLocked()) + } +} diff --git a/internal/unionstore/union_store.go b/internal/unionstore/union_store.go index 19c100936a..1a5f1a36b9 100644 --- a/internal/unionstore/union_store.go +++ b/internal/unionstore/union_store.go @@ -250,54 +250,7 @@ type Metrics struct { } var ( - _ MemBuffer = &MemDBWithContext{} _ MemBuffer = &PipelinedMemDB{} + _ MemBuffer = &rbtDBWithContext{} + _ MemBuffer = &artDBWithContext{} ) - -// MemDBWithContext wraps MemDB to satisfy the MemBuffer interface. 
-type MemDBWithContext struct { - *MemDB -} - -func NewMemDBWithContext() *MemDBWithContext { - return &MemDBWithContext{MemDB: newMemDB()} -} - -func (db *MemDBWithContext) Get(_ context.Context, k []byte) ([]byte, error) { - return db.MemDB.Get(k) -} - -func (db *MemDBWithContext) GetLocal(_ context.Context, k []byte) ([]byte, error) { - return db.MemDB.Get(k) -} - -func (db *MemDBWithContext) Flush(bool) (bool, error) { return false, nil } - -func (db *MemDBWithContext) FlushWait() error { return nil } - -// GetMemDB returns the inner MemDB -func (db *MemDBWithContext) GetMemDB() *MemDB { - return db.MemDB -} - -// BatchGet returns the values for given keys from the MemBuffer. -func (db *MemDBWithContext) BatchGet(ctx context.Context, keys [][]byte) (map[string][]byte, error) { - if db.Len() == 0 { - return map[string][]byte{}, nil - } - m := make(map[string][]byte, len(keys)) - for _, k := range keys { - v, err := db.Get(ctx, k) - if err != nil { - if tikverr.IsErrNotFound(err) { - continue - } - return nil, err - } - m[string(k)] = v - } - return m, nil -} - -// GetFlushMetrisc implements the MemBuffer interface. 
-func (db *MemDBWithContext) GetMetrics() Metrics { return Metrics{} } diff --git a/internal/unionstore/union_store_test.go b/internal/unionstore/union_store_test.go index 7eff6db5fe..24bfdb9538 100644 --- a/internal/unionstore/union_store_test.go +++ b/internal/unionstore/union_store_test.go @@ -44,7 +44,7 @@ import ( func TestUnionStoreGetSet(t *testing.T) { assert := assert.New(t) - store := newMemDB() + store := NewMemDB() us := NewUnionStore(NewMemDBWithContext(), &mockSnapshot{store}) err := store.Set([]byte("1"), []byte("1")) @@ -63,7 +63,7 @@ func TestUnionStoreGetSet(t *testing.T) { func TestUnionStoreDelete(t *testing.T) { assert := assert.New(t) - store := newMemDB() + store := NewMemDB() us := NewUnionStore(NewMemDBWithContext(), &mockSnapshot{store}) err := store.Set([]byte("1"), []byte("1")) @@ -82,7 +82,7 @@ func TestUnionStoreDelete(t *testing.T) { func TestUnionStoreSeek(t *testing.T) { assert := assert.New(t) - store := newMemDB() + store := NewMemDB() us := NewUnionStore(NewMemDBWithContext(), &mockSnapshot{store}) err := store.Set([]byte("1"), []byte("1")) @@ -115,7 +115,7 @@ func TestUnionStoreSeek(t *testing.T) { func TestUnionStoreIterReverse(t *testing.T) { assert := assert.New(t) - store := newMemDB() + store := NewMemDB() us := NewUnionStore(NewMemDBWithContext(), &mockSnapshot{store}) err := store.Set([]byte("1"), []byte("1")) diff --git a/txnkv/transaction/txn.go b/txnkv/transaction/txn.go index 2496d3f37d..674a242a1b 100644 --- a/txnkv/transaction/txn.go +++ b/txnkv/transaction/txn.go @@ -186,7 +186,7 @@ func NewTiKVTxn(store kvstore, snapshot *txnsnapshot.KVSnapshot, startTS uint64, RequestSource: snapshot.RequestSource, } if !options.PipelinedMemDB { - newTiKVTxn.us = unionstore.NewUnionStore(unionstore.NewMemDBWithContext(), snapshot) + newTiKVTxn.us = unionstore.NewUnionStore(unionstore.NewMemDB(), snapshot) return newTiKVTxn, nil } if err := newTiKVTxn.InitPipelinedMemDB(); err != nil {