From 104c0216bf8e55c9c5901ab89d899924f67d1969 Mon Sep 17 00:00:00 2001 From: Harshil Goel Date: Fri, 27 Sep 2024 07:53:05 +0530 Subject: [PATCH] added sharded map --- edgraph/server.go | 1 + go.mod | 2 + go.sum | 2 - posting/index.go | 21 +- posting/index_test.go | 1 + posting/list.go | 29 ++- posting/list_test.go | 6 + posting/lists.go | 12 +- posting/mvcc.go | 517 +++++++++++++++++++++++++++++++-------- posting/mvcc_test.go | 62 ++++- posting/oracle.go | 2 +- worker/draft.go | 2 +- worker/online_restore.go | 1 + x/keys.go | 2 +- 14 files changed, 546 insertions(+), 114 deletions(-) diff --git a/edgraph/server.go b/edgraph/server.go index 6561146c1a6..1f3cfa6f1eb 100644 --- a/edgraph/server.go +++ b/edgraph/server.go @@ -1411,6 +1411,7 @@ func (s *Server) doQuery(ctx context.Context, req *Request) (resp *api.Response, EncodingNs: uint64(l.Json.Nanoseconds()), TotalNs: uint64((time.Since(l.Start)).Nanoseconds()), } + //fmt.Println("====Query Resp", qc.req.Query, qc.req.StartTs, qc.req, string(resp.Json)) return resp, gqlErrs } diff --git a/go.mod b/go.mod index a458004f47a..151965ade27 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,7 @@ module github.com/dgraph-io/dgraph/v24 +replace github.com/dgraph-io/ristretto => /home/harshil/Projects/ristretto/ + go 1.22.6 require ( diff --git a/go.sum b/go.sum index 746ebf867b6..1ff5a4a7b14 100644 --- a/go.sum +++ b/go.sum @@ -151,8 +151,6 @@ github.com/dgraph-io/gqlparser/v2 v2.2.2 h1:CnxXOKL4EPguKqcGV/z4u4VoW5izUkOTIsNM github.com/dgraph-io/gqlparser/v2 v2.2.2/go.mod h1:MYS4jppjyx8b9tuUtjV7jU1UFZK6P9fvO8TsIsQtRKU= github.com/dgraph-io/graphql-transport-ws v0.0.0-20210511143556-2cef522f1f15 h1:X2NRsgAtVUAp2nmTPCq+x+wTcRRrj74CEpy7E0Unsl4= github.com/dgraph-io/graphql-transport-ws v0.0.0-20210511143556-2cef522f1f15/go.mod h1:7z3c/5w0sMYYZF5bHsrh8IH4fKwG5O5Y70cPH1ZLLRQ= -github.com/dgraph-io/ristretto v1.0.0 h1:SYG07bONKMlFDUYu5pEu3DGAh8c2OFNzKm6G9J4Si84= -github.com/dgraph-io/ristretto v1.0.0/go.mod h1:jTi2FiYEhQ1NsMmA7DeBykizjOuY88NhKBkepyu1jPc= github.com/dgraph-io/simdjson-go v0.3.0 h1:h71LO7vR4LHMPUhuoGN8bqGm1VNfGOlAG8BI6iDUKw0= github.com/dgraph-io/simdjson-go v0.3.0/go.mod h1:Otpysdjaxj9OGaJusn4pgQV7OFh2bELuHANq0I78uvY= github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= diff --git a/posting/index.go b/posting/index.go index 6ab2dbd91d8..e18cd41e788 100644 --- a/posting/index.go +++ b/posting/index.go @@ -217,11 +217,28 @@ func (txn *Txn) addIndexMutation(ctx context.Context, edge *pb.DirectedEdge, tok return err } - x.AssertTrue(plist != nil) if err = plist.addMutation(ctx, txn, edge); err != nil { return err } - ostats.Record(ctx, x.NumEdges.M(1)) + + //mpost := NewPosting(edge) + //mpost.StartTs = txn.StartTs + //if mpost.PostingType != pb.Posting_REF { + // edge.ValueId = fingerprintEdge(edge) + // mpost.Uid = edge.ValueId + //} + + ////fmt.Println("ADDING MUTATION", plist.mutationMap, key, edge) + //txn.addConflictKey(indexConflicKey(key, edge)) + + //plist.Lock() + //defer plist.Unlock() + //if err != plist.updateMutationLayer(mpost, false) { + // return errors.Wrapf(err, "cannot update mutation layer of key %s with value %+v", + // hex.EncodeToString(plist.key), mpost) + //} + + //ostats.Record(ctx, x.NumEdges.M(1)) return nil } diff --git a/posting/index_test.go b/posting/index_test.go index 249b14e4c21..0bacf3db6e2 100644 --- a/posting/index_test.go +++ b/posting/index_test.go @@ -157,6 +157,7 @@ func addMutation(t *testing.T, l *List, edge *pb.DirectedEdge, op uint32, } txn.Update() + 
txn.UpdateCachedKeys(commitTs) writer := NewTxnWriter(pstore) require.NoError(t, txn.CommitToDisk(writer, commitTs)) require.NoError(t, writer.Flush()) diff --git a/posting/list.go b/posting/list.go index c639ca5b19c..9e33a836c79 100644 --- a/posting/list.go +++ b/posting/list.go @@ -172,9 +172,10 @@ func (mm *MutableMap) listLen(readTs uint64) int { return 0 } - if mm.length == math.MaxInt { + if mm.length == math.MaxInt || readTs < mm.uidsHTime { count := 0 mm.iterate(func(ts uint64, pl *pb.PostingList) { + //fmt.Println(ts, pl) for _, mpost := range pl.Postings { if mpost.Op == Del { count -= 1 @@ -186,9 +187,11 @@ func (mm *MutableMap) listLen(readTs uint64) int { return count } + //fmt.Println("here", mm.length) count := mm.length if mm.curList != nil { for _, mpost := range mm.curList.Postings { + fmt.Println(mpost) if mpost.Op == Del { count -= 1 } else { @@ -254,6 +257,7 @@ func (mm *MutableMap) iterate(f func(ts uint64, pl *pb.PostingList), readTs uint deleteMarker := mm.populateDeleteAll(readTs) mm._iterate(func(ts uint64, pl *pb.PostingList) { + //fmt.Println("********MUTABLE MAP ITERATE", ts, readTs, pl) if ts >= deleteMarker && ts <= readTs { f(ts, pl) } @@ -719,6 +723,17 @@ func (l *List) addMutation(ctx context.Context, txn *Txn, t *pb.DirectedEdge) er return l.addMutationInternal(ctx, txn, t) } +func indexConflicKey(key []byte, t *pb.DirectedEdge) uint64 { + getKey := func(key []byte, uid uint64) uint64 { + // Instead of creating a string first and then doing a fingerprint, let's do a fingerprint + // here to save memory allocations. + // Not entirely sure about effect on collision chances due to this simple XOR with uid. + return farm.Fingerprint64(key) ^ uid + } + + return getKey(key, t.ValueId) +} + func GetConflictKey(pk x.ParsedKey, key []byte, t *pb.DirectedEdge) uint64 { getKey := func(key []byte, uid uint64) uint64 { // Instead of creating a string first and then doing a fingerprint, let's do a fingerprint @@ -888,7 +903,6 @@ func (l *List) setMutation(startTs uint64, data []byte) { l.mutationMap = newMutableMap() } l.mutationMap.set(startTs, pl) - if pl.CommitTs != 0 { l.maxTs = x.Max(l.maxTs, pl.CommitTs) } @@ -919,6 +933,7 @@ func (l *List) Iterate(readTs uint64, afterUid uint64, f func(obj *pb.Posting) e // If greater than zero, this timestamp must thus be greater than l.minTs. func (l *List) pickPostings(readTs uint64) (uint64, []*pb.Posting) { // This function would return zero ts for entries above readTs. + // either way, effective ts is returning ts. effective := func(start, commit uint64) uint64 { if commit > 0 && commit <= readTs { // Has been committed and below the readTs. 
@@ -958,6 +973,7 @@ func (l *List) iterate(readTs uint64, afterUid uint64, f func(obj *pb.Posting) e // mposts is the list of mutable postings deleteBelowTs, mposts := l.pickPostings(readTs) + //fmt.Println(mposts, deleteBelowTs, l.plist) if readTs < l.minTs { return errors.Errorf("readTs: %d less than minTs: %d for key: %q", readTs, l.minTs, l.key) } @@ -1115,6 +1131,7 @@ func (l *List) getPostingAndLength(readTs, afterUid, uid uint64) (int, bool, *pb var count int var found bool var post *pb.Posting + err := l.iterate(readTs, afterUid, func(p *pb.Posting) error { if p.Uid == uid { post = p @@ -1258,6 +1275,7 @@ func (l *List) Rollup(alloc *z.Allocator, readTs uint64) ([]*bpb.KV, error) { return bytes.Compare(kvs[i].Key, kvs[j].Key) <= 0 }) + //fmt.Println("ROLLING UP", l.key, out.plist, kv.Version) x.PrintRollup(out.plist, out.parts, l.key, kv.Version) x.VerifyPostingSplits(kvs, out.plist, out.parts, l.key) return kvs, nil @@ -1677,19 +1695,23 @@ func (l *List) findStaticValue(readTs uint64) *pb.PostingList { if l.mutationMap == nil { // If mutation map is empty, check if there is some data, and return it. if l.plist != nil && len(l.plist.Postings) > 0 { + //fmt.Println("nil map plist") return l.plist } + //fmt.Println("nil map nil") return nil } // Return readTs is if it's present in the mutation. It's going to be the latest value. if l.mutationMap.curList != nil && l.mutationMap.curTime == readTs { + //fmt.Println("curlist", l.mutationMap.curList) return l.mutationMap.curList } // If maxTs < readTs then we need to read maxTs if l.maxTs <= readTs { if mutation := l.mutationMap.get(l.maxTs); mutation != nil { + //fmt.Println("mutation", mutation) return mutation } } @@ -1705,11 +1727,13 @@ func (l *List) findStaticValue(readTs uint64) *pb.PostingList { } }, readTs) if mutation != nil { + //fmt.Println("iterate", mutation) return mutation } // If we reach here, that means that there was no entry in mutation map which is less than readTs. That // means we need to return l.plist + //fmt.Println("got nothing", l.plist, l.plist != nil) return l.plist } @@ -1871,6 +1895,7 @@ func (l *List) findPosting(readTs uint64, uid uint64) (found bool, pos *pb.Posti // Iterate starts iterating after the given argument, so we pass UID - 1 // TODO Find what happens when uid = math.MaxUint64 searchFurther, pos := l.mutationMap.findPosting(readTs, uid) + //fmt.Println("FIND POSTING", readTs, "key", l.key, "mutationMap:", l.mutationMap, "plist:", l.plist, uid, searchFurther, pos) if pos != nil { return true, pos, nil } diff --git a/posting/list_test.go b/posting/list_test.go index 3066052a8a2..0649f7073a2 100644 --- a/posting/list_test.go +++ b/posting/list_test.go @@ -134,6 +134,7 @@ func TestGetSinglePosting(t *testing.T) { res, err := l.StaticValue(1) require.NoError(t, err) + //fmt.Println(res, res == nil) require.Equal(t, res == nil, true) l.plist = create_pl(1, 1) @@ -225,6 +226,7 @@ func TestAddMutation(t *testing.T) { func getFirst(t *testing.T, l *List, readTs uint64) (res pb.Posting) { require.NoError(t, l.Iterate(readTs, 0, func(p *pb.Posting) error { + //fmt.Println("INSIDE ITERATE", p) res = *p return ErrStopIteration })) @@ -233,6 +235,7 @@ func getFirst(t *testing.T, l *List, readTs uint64) (res pb.Posting) { func checkValue(t *testing.T, ol *List, val string, readTs uint64) { p := getFirst(t, ol, readTs) + //fmt.Println("HERE", val, string(p.Value), p, ol, p.Uid) require.Equal(t, uint64(math.MaxUint64), p.Uid) // Cast to prevent overflow. 
require.EqualValues(t, val, p.Value) } @@ -532,6 +535,8 @@ func TestReadSingleValue(t *testing.T) { kvs, err := ol.Rollup(nil, txn.StartTs-3) require.NoError(t, err) require.NoError(t, writePostingListToDisk(kvs)) + // Delete item from global cache before reading, as we are not updating the cache in the test + memoryLayer.Del(z.MemHash(key)) ol, err = getNew(key, ps, math.MaxUint64) require.NoError(t, err) } @@ -541,6 +546,7 @@ func TestReadSingleValue(t *testing.T) { j = ol.minTs } for ; j < i+6; j++ { + ResetCache() tx := NewTxn(j) k, err := tx.cache.GetSinglePosting(key) require.NoError(t, err) diff --git a/posting/lists.go b/posting/lists.go index 9a20a917900..d1a398036d2 100644 --- a/posting/lists.go +++ b/posting/lists.go @@ -143,7 +143,7 @@ func (vc *viLocalCache) GetWithLockHeld(key []byte) (rval index.Value, rerr erro func (vc *viLocalCache) GetValueFromPostingList(pl *List) (rval index.Value, rerr error) { value := pl.findStaticValue(vc.delegate.startTs) - if value == nil { + if value == nil || len(value.Postings) == 0 { return nil, ErrNoValue } @@ -312,8 +312,9 @@ func (lc *LocalCache) getInternal(key []byte, readFromDisk bool) (*List, error) } } else { pl = &List{ - key: key, - plist: new(pb.PostingList), + key: key, + plist: new(pb.PostingList), + mutationMap: newMutableMap(), } } @@ -336,6 +337,7 @@ func (lc *LocalCache) GetSinglePosting(key []byte) (*pb.PostingList, error) { pl := &pb.PostingList{} if delta, ok := lc.deltas[string(key)]; ok && len(delta) > 0 { + //fmt.Println("GETTING FROM DELTAS") err := pl.Unmarshal(delta) lc.RUnlock() return pl, err @@ -356,6 +358,7 @@ func (lc *LocalCache) GetSinglePosting(key []byte) (*pb.PostingList, error) { // If both pl and err are empty, that means that there was no data in local cache, hence we should // read the data from badger. 
if pl != nil || err != nil { + //fmt.Println("GETTING POSTING1", lc.startTs, pl) return pl, err } @@ -372,6 +375,7 @@ func (lc *LocalCache) GetSinglePosting(key []byte) (*pb.PostingList, error) { return pl.Unmarshal(val) }) + //fmt.Println("GETTING POSTING FROM BADGER", lc.startTs, pl) return pl, err } @@ -395,6 +399,8 @@ func (lc *LocalCache) GetSinglePosting(key []byte) (*pb.PostingList, error) { } } pl.Postings = pl.Postings[:idx] + //pk, _ := x.Parse([]byte(key)) + //fmt.Println("====Getting single posting", lc.startTs, pk, pl.Postings) return pl, nil } diff --git a/posting/mvcc.go b/posting/mvcc.go index 3fc11ad8cc1..440ef39bf40 100644 --- a/posting/mvcc.go +++ b/posting/mvcc.go @@ -19,13 +19,16 @@ package posting import ( "bytes" "encoding/hex" + "fmt" + "math" + "runtime" "strconv" - "strings" "sync" "sync/atomic" "time" "github.com/golang/glog" + "github.com/golang/protobuf/proto" "github.com/pkg/errors" "github.com/dgraph-io/badger/v4" @@ -33,6 +36,7 @@ import ( "github.com/dgraph-io/dgo/v240/protos/api" "github.com/dgraph-io/dgraph/v24/protos/pb" "github.com/dgraph-io/dgraph/v24/x" + "github.com/dgraph-io/ristretto" "github.com/dgraph-io/ristretto/z" ) @@ -62,12 +66,7 @@ type CachePL struct { count int list *List lastUpdate uint64 -} - -type GlobalCache struct { - sync.RWMutex - - items map[string]*CachePL + lastRead time.Time } var ( @@ -84,10 +83,12 @@ var ( priorityKeys: make([]*pooledKeys, 2), } - globalCache = &GlobalCache{items: make(map[string]*CachePL, 100)} + memoryLayer = initMemoryLayer() + numShards = 256 ) func init() { + runtime.SetCPUProfileRate(200) x.AssertTrue(len(IncrRollup.priorityKeys) == 2) for i := range IncrRollup.priorityKeys { IncrRollup.priorityKeys[i] = &pooledKeys{ @@ -133,13 +134,9 @@ func (ir *incrRollupi) rollUpKey(writer *TxnWriter, key []byte) error { } RemoveCacheFor(key) - - globalCache.Lock() - val, ok := globalCache.items[string(key)] - if ok { - val.list = nil - } - globalCache.Unlock() + //pk, _ := x.Parse(key) + //fmt.Println("====Setting cache delete rollup", ts, pk) + memoryLayer.Del(z.MemHash(key)) // TODO Update cache with rolled up results // If we do a rollup, we typically won't need to update the key in cache. // The only caveat is that the key written by rollup would be written at +1 @@ -174,6 +171,26 @@ func (ir *incrRollupi) addKeyToBatch(key []byte, priority int) { } } +func (ir *incrRollupi) mem(getNewTs func(bool) uint64) { + forceRollupTick := time.NewTicker(500 * time.Millisecond) + defer forceRollupTick.Stop() + deleteCacheTick := time.NewTicker(1 * time.Second) + defer deleteCacheTick.Stop() + + for { + select { + case <-deleteCacheTick.C: + memoryLayer.deleteOldItems(ir.getNewTs(false)) + case <-forceRollupTick.C: + t := 10000 + memoryLayer.insert += t + if memoryLayer.insert > 5*t { + memoryLayer.insert = 5 * t + } + } + } +} + // Process will rollup batches of 64 keys in a go routine. 
func (ir *incrRollupi) Process(closer *z.Closer, getNewTs func(bool) uint64) { ir.getNewTs = getNewTs @@ -192,6 +209,8 @@ func (ir *incrRollupi) Process(closer *z.Closer, getNewTs func(bool) uint64) { forceRollupTick := time.NewTicker(500 * time.Millisecond) defer forceRollupTick.Stop() + ir.mem(getNewTs) + doRollup := func(batch *[][]byte, priority int) { currTs := time.Now().Unix() for _, key := range *batch { @@ -317,6 +336,7 @@ func (txn *Txn) CommitToDisk(writer *TxnWriter, commitTs uint64) error { for ; idx < len(keys); idx++ { key := keys[idx] data := cache.deltas[key] + //fmt.Println("--------------------------------------------------------HUI") if len(data) == 0 { continue } @@ -348,9 +368,7 @@ func ResetCache() { if lCache != nil { lCache.Clear() } - globalCache.Lock() - globalCache.items = make(map[string]*CachePL) - globalCache.Unlock() + memoryLayer.Clear() } // RemoveCacheFor will delete the list corresponding to the given key. @@ -361,6 +379,149 @@ func RemoveCacheFor(key []byte) { } } +type setItems struct { + keyHash uint64 + list *List + readTs uint64 +} + +type MemoryLayer struct { + shards []*lockedMap + setBuf chan *setItems + skull *ristretto.Skull[uint64, struct{}] + + insert int + numCacheRead int + numCacheReadFails int + numDisksRead int + numCacheSave int +} + +func initMemoryLayer() *MemoryLayer { + sm := &MemoryLayer{ + shards: make([]*lockedMap, numShards), + setBuf: make(chan *setItems, 32*1024), + skull: ristretto.GetSkull[uint64, struct{}](), + } + for i := range sm.shards { + sm.shards[i] = newLockedMap() + } + return sm +} + +func (sm *MemoryLayer) get(key uint64) (*CachePL, bool) { + return sm.shards[key%uint64(numShards)].get(key) +} + +func (sm *MemoryLayer) set(key uint64, i *CachePL) { + if i == nil { + // If item is nil make this Set a no-op. + return + } + + sm.shards[key%uint64(numShards)].set(key, i) +} + +func (sm *MemoryLayer) del(key uint64) { + sm.shards[key%uint64(numShards)].del(key) +} + +func (sm *MemoryLayer) Get(key uint64) (*CachePL, bool) { + return sm.shards[key%uint64(numShards)].Get(key) +} + +func (sm *MemoryLayer) Set(key uint64, i *CachePL) { + if i == nil { + // If item is nil make this Set a no-op. + return + } + + sm.shards[key%uint64(numShards)].Set(key, i) +} + +func (sm *MemoryLayer) Del(key uint64) { + sm.shards[key%uint64(numShards)].Del(key) + sm.skull.Del(key) +} + +func (sm *MemoryLayer) UnlockKey(key uint64) { + sm.shards[key%uint64(numShards)].Unlock() +} + +func (sm *MemoryLayer) LockKey(key uint64) { + sm.shards[key%uint64(numShards)].Lock() +} + +func (sm *MemoryLayer) RLockKey(key uint64) { + sm.shards[key%uint64(numShards)].RLock() +} + +func (sm *MemoryLayer) RUnlockKey(key uint64) { + sm.shards[key%uint64(numShards)].RUnlock() +} + +func (sm *MemoryLayer) Clear() { + for i := 0; i < numShards; i++ { + sm.shards[i].Clear() + } +} + +type lockedMap struct { + sync.RWMutex + data map[uint64]*CachePL +} + +func newLockedMap() *lockedMap { + return &lockedMap{ + data: make(map[uint64]*CachePL), + } +} + +func (m *lockedMap) get(key uint64) (*CachePL, bool) { + item, ok := m.data[key] + return item, ok +} + +func (m *lockedMap) Get(key uint64) (*CachePL, bool) { + m.RLock() + defer m.RUnlock() + item, ok := m.data[key] + return item, ok +} + +func (m *lockedMap) set(key uint64, i *CachePL) { + if i == nil { + // If the item is nil make this Set a no-op. 
+ return + } + + m.data[key] = i +} + +func (m *lockedMap) Set(key uint64, i *CachePL) { + m.Lock() + defer m.Unlock() + m.set(key, i) +} + +func (m *lockedMap) del(key uint64) { + delete(m.data, key) +} + +func (m *lockedMap) Del(key uint64) { + m.Lock() + if l, ok := m.data[key]; ok && l != nil { + l.list = nil + } + m.Unlock() +} + +func (m *lockedMap) Clear() { + m.Lock() + m.data = make(map[uint64]*CachePL) + m.Unlock() +} + func NewCachePL() *CachePL { return &CachePL{ count: 0, @@ -369,6 +530,54 @@ func NewCachePL() *CachePL { } } +func checkForRollup(key []byte, l *List) { + deltaCount := l.mutationMap.len() + // If deltaCount is high, send it to high priority channel instead. + if deltaCount > 500 { + IncrRollup.addKeyToBatch(key, 0) + } +} + +func (ml *MemoryLayer) updateItemInCache(key string, pk x.ParsedKey, delta []byte, startTs, commitTs uint64) { + if commitTs == 0 { + return + } + + p := new(pb.PostingList) + x.Check(p.Unmarshal(delta)) + //fmt.Println("======COMMITTING", startTs, commitTs, pk, p) + + keyHash := z.MemHash([]byte(key)) + // TODO under the same lock + ml.LockKey(keyHash) + defer ml.UnlockKey(keyHash) + + a := 1 + if a == 1 { + delete(ml.shards[keyHash%uint64(numShards)].data, keyHash) + return + } + + val, ok := ml.get(keyHash) + if !ok { + val = NewCachePL() + val.lastUpdate = commitTs + ml.set(keyHash, val) + return + } + + val.lastUpdate = commitTs + val.count -= 1 + + if val.list != nil { + p := new(pb.PostingList) + x.Check(p.Unmarshal(delta)) + val.list.setMutationAfterCommit(startTs, commitTs, p, true) + checkForRollup([]byte(key), val.list) + //fmt.Println("====Setting cache list", commitTs, pk, p, val.list.mutationMap, val.list.key) + } +} + // RemoveCachedKeys will delete the cached list by this txn. func (txn *Txn) UpdateCachedKeys(commitTs uint64) { if txn == nil || txn.cache == nil { @@ -377,34 +586,12 @@ func (txn *Txn) UpdateCachedKeys(commitTs uint64) { for key, delta := range txn.cache.deltas { RemoveCacheFor([]byte(key)) + pk, _ := x.Parse([]byte(key)) if !ShouldGoInCache(pk) { continue } - globalCache.Lock() - val, ok := globalCache.items[key] - if !ok { - val = NewCachePL() - val.lastUpdate = commitTs - globalCache.items[key] = val - } - if commitTs != 0 { - // TODO Delete this if the values are too old in an async thread - val.lastUpdate = commitTs - } - if !ok { - globalCache.Unlock() - continue - } - - val.count -= 1 - - if commitTs != 0 && val.list != nil { - p := new(pb.PostingList) - x.Check(p.Unmarshal(delta)) - val.list.setMutationAfterCommit(txn.StartTs, commitTs, p, true) - } - globalCache.Unlock() + memoryLayer.updateItemInCache(key, pk, delta, txn.StartTs, commitTs) } } @@ -432,6 +619,7 @@ func ReadPostingList(key []byte, it *badger.Iterator) (*List, error) { // lists ended up being rolled-up multiple times. This issue was caught by the // uid-set Jepsen test. pk, err := x.Parse(key) + //fmt.Println("READING KEY", key, pk) if err != nil { return nil, errors.Wrapf(err, "while reading posting list with key [%v]", key) } @@ -448,6 +636,7 @@ func ReadPostingList(key []byte, it *badger.Iterator) (*List, error) { l := new(List) l.key = key l.plist = new(pb.PostingList) + l.minTs = 0 // We use the following block of code to trigger incremental rollup on this key. 
deltaCount := 0 @@ -476,19 +665,21 @@ func ReadPostingList(key []byte, it *badger.Iterator) (*List, error) { switch item.UserMeta() { case BitEmptyPosting: - l.minTs = item.Version() return l, nil case BitCompletePosting: if err := unmarshalOrCopy(l.plist, item); err != nil { return nil, err } - l.minTs = item.Version() + l.minTs = item.Version() // No need to do Next here. The outer loop can take care of skipping // more versions of the same key. return l, nil case BitDeltaPosting: err := item.Value(func(val []byte) error { + if l.mutationMap == nil { + l.mutationMap = newMutableMap() + } pl := &pb.PostingList{} if err := pl.Unmarshal(val); err != nil { return err @@ -540,61 +731,133 @@ func (c *CachePL) Set(l *List, readTs uint64) { } func ShouldGoInCache(pk x.ParsedKey) bool { - return (!pk.IsData() && strings.HasSuffix(pk.Attr, "dgraph.type")) + //return !pk.IsData() + return true + //return false } func PostingListCacheEnabled() bool { - return lCache != nil + return false + //return lCache != nil } -func getNew(key []byte, pstore *badger.DB, readTs uint64) (*List, error) { - if PostingListCacheEnabled() { - l, ok := lCache.Get(key) - if ok && l != nil { - // No need to clone the immutable layer or the key since mutations will not modify it. - lCopy := &List{ - minTs: l.minTs, - maxTs: l.maxTs, - key: key, - plist: l.plist, - } - l.RLock() - lCopy.mutationMap = l.mutationMap.clone() - l.RUnlock() - return lCopy, nil - } +func (ml *MemoryLayer) Process(i *setItems) { + if ml.insert < 0 { + return } + ml.insert -= 1 + victims, add := ml.skull.Set(i.keyHash, int64(i.list.ApproxLen()/100)) + //fmt.Println(add, i.list.ApproxLen()) - if pstore.IsClosed() { - return nil, badger.ErrDBClosed + if add { + ml.LockKey(i.keyHash) + i.list.RLock() + cacheItem, ok := ml.get(i.keyHash) + if !ok { + cacheItemNew := NewCachePL() + cacheItemNew.count = 1 + cacheItemNew.list = copyList(i.list) + cacheItemNew.lastUpdate = i.list.maxTs + ml.set(i.keyHash, cacheItemNew) + } else { + // Only set l to the cache if readTs >= latestTs, which implies that l is + // the latest version of the PL. We also check that we're reading a version + // from Badger, which is higher than the write registered by the cache. 
+ + //fmt.Println("====Setting cache", readTs, pk, l.mutationMap) + cacheItem.Set(copyList(i.list), i.readTs) + } + ml.numCacheSave += 1 + //allV, _ := i.list.AllValues(i.readTs) + //uids, _ := i.list.Uids(ListOptions{ReadTs: i.readTs}) + //fmt.Println("====Setting into cache", i.readTs, i.list.key, i.list.mutationMap, allV, uids) + i.list.RUnlock() + + //idx := int(i.keyHash % uint64(numShards)) + //if len(ml.shards[idx].data) > 500 { + // for keyHash, pl := range ml.shards[idx].data { + // if pl.lastRead < i.readTs-100 { + // delete(ml.shards[idx].data, keyHash) + // } + // } + //} + ml.UnlockKey(i.keyHash) + } + + for _, vic := range victims { + ml.LockKey(vic.Key) + delete(ml.shards[vic.Key%uint64(numShards)].data, vic.Key) + ml.UnlockKey(vic.Key) } +} - pk, _ := x.Parse(key) +func (ml *MemoryLayer) deleteOldItems(ts uint64) { + fmt.Println("Deleting old items", ml.numCacheRead, ml.numDisksRead, ml.numCacheSave, ml.numCacheReadFails, float64(ml.numCacheRead)/float64(ml.numDisksRead), ml.insert) + lb := 0 + la := 0 + t1 := time.Now() + defer func() { + fmt.Println("Done deleting old items", lb, la, time.Since(t1)) + }() - if ShouldGoInCache(pk) { - globalCache.Lock() - cacheItem, ok := globalCache.items[string(key)] - if !ok { - cacheItem = NewCachePL() - globalCache.items[string(key)] = cacheItem + //ml.skull.ExpiryMap.SkullCleanup(func(keyHash uint64) { + // ml.LockKey(keyHash) + // delete(ml.shards[keyHash%uint64(numShards)].data, keyHash) + // ml.UnlockKey(keyHash) + //}, ml.skull.CachePolicy) + + for i := 0; i < numShards; i++ { + ml.shards[i].Lock() + lb += len(ml.shards[i].data) + for keyHash, pl := range ml.shards[i].data { + if time.Since(pl.lastRead) > 5*time.Second { + delete(ml.shards[i].data, keyHash) + ml.skull.Del(keyHash) + } + if len(ml.shards[i].data) < 500 { // Keeps like 200k entries after this + break + } } - cacheItem.count += 1 + la += len(ml.shards[i].data) + ml.shards[i].Unlock() + } +} - // We use badger subscription to invalidate the cache. For every write we make the value - // corresponding to the key in the cache to nil. So, if we get some non-nil value from the cache - // then it means that no writes have happened after the last set of this key in the cache. 
- if ok { - if cacheItem.list != nil && cacheItem.list.minTs <= readTs { - cacheItem.list.RLock() - lCopy := copyList(cacheItem.list) - cacheItem.list.RUnlock() - globalCache.Unlock() - return lCopy, nil - } +func (ml *MemoryLayer) saveInCache(keyHash, readTs uint64, l *List) { + ml.Process(&setItems{ + keyHash: keyHash, + readTs: readTs, + list: l, + }) +} + +func (ml *MemoryLayer) readFromCache(key []byte, keyHash, readTs uint64) *List { + ml.RLockKey(keyHash) + + ml.skull.Get(keyHash) + cacheItem, ok := ml.get(keyHash) + + if ok { + cacheItem.count += 1 + cacheItem.lastRead = time.Now() + if cacheItem.list != nil && cacheItem.list.minTs <= readTs { + cacheItem.list.RLock() + lCopy := copyList(cacheItem.list) + cacheItem.list.RUnlock() + ml.RUnlockKey(keyHash) + checkForRollup(key, lCopy) + //allV, _ := lCopy.AllValues(readTs) + //uids, _ := lCopy.Uids(ListOptions{ReadTs: readTs}) + //fmt.Println("====Getting cache", readTs, lCopy.key, lCopy.mutationMap, allV, uids) + return lCopy } - globalCache.Unlock() } + ml.RUnlockKey(keyHash) + return nil +} +func (ml *MemoryLayer) readFromDisk(key []byte, pstore *badger.DB, readTs uint64) (*List, error) { + ml.numDisksRead += 1 txn := pstore.NewTransactionAt(readTs, false) defer txn.Discard() @@ -607,31 +870,85 @@ func getNew(key []byte, pstore *badger.DB, readTs uint64) (*List, error) { defer itr.Close() itr.Seek(key) l, err := ReadPostingList(key, itr) + //fmt.Println("=============GETTING DISK", key, l.mutationMap, l.plist) if err != nil { return l, err } + return l, nil +} - // Only set l to the cache if readTs >= latestTs, which implies that l is - // the latest version of the PL. We also check that we're reading a version - // from Badger, which is higher than the write registered by the cache. - if ShouldGoInCache(pk) { - globalCache.Lock() - l.RLock() - cacheItem, ok := globalCache.items[string(key)] - if !ok { - cacheItemNew := NewCachePL() - cacheItemNew.count = 1 - cacheItemNew.list = copyList(l) - cacheItemNew.lastUpdate = l.maxTs - globalCache.items[string(key)] = cacheItemNew +func (ml *MemoryLayer) ReadData(key []byte, pstore *badger.DB, readTs uint64) (*List, error) { + pk, _ := x.Parse(key) + var keyHash uint64 + + gic := ShouldGoInCache(pk) + if gic { + keyHash = z.MemHash(key) + l := ml.readFromCache(key, keyHash, readTs) + if l != nil { + //fmt.Println(pk, pk.IsData()) + ml.numCacheRead += 1 + return l, nil } else { - cacheItem.Set(copyList(l), readTs) + ml.numCacheReadFails += 1 } - l.RUnlock() - globalCache.Unlock() + l, err := ml.readFromDisk(key, pstore, math.MaxUint64) + //fmt.Println("READING FROM DISK", l.minTs, readTs) + if err != nil { + return nil, err + } + ml.saveInCache(keyHash, readTs, l) + if l.minTs == 0 || readTs >= l.minTs { + return l, nil + } + } + + l, err := ml.readFromDisk(key, pstore, readTs) + if err != nil { + return nil, err + } + + return l, nil +} + +func getNew(key []byte, pstore *badger.DB, readTs uint64) (*List, error) { + //fmt.Println("Get new", key) + if PostingListCacheEnabled() { + l, ok := lCache.Get(key) + if ok && l != nil { + // No need to clone the immutable layer or the key since mutations will not modify it. 
+ memoryLayer.numCacheRead += 1 + lCopy := &List{ + minTs: l.minTs, + maxTs: l.maxTs, + key: key, + plist: l.plist, + } + l.RLock() + if l.mutationMap != nil { + lCopy.mutationMap = newMutableMap() + for ts, pl := range l.mutationMap.oldList { + lCopy.mutationMap.oldList[ts] = proto.Clone(pl).(*pb.PostingList) + } + } + l.RUnlock() + return lCopy, nil + } else { + memoryLayer.numCacheReadFails += 1 + } + } + + if pstore.IsClosed() { + return nil, badger.ErrDBClosed + } + + l, err := memoryLayer.ReadData(key, pstore, readTs) + if err != nil { + return l, err } if PostingListCacheEnabled() { + memoryLayer.numCacheSave += 1 lCache.Set(key, l, 0) } diff --git a/posting/mvcc_test.go b/posting/mvcc_test.go index 0291d93947c..02b35579a1f 100644 --- a/posting/mvcc_test.go +++ b/posting/mvcc_test.go @@ -19,6 +19,7 @@ package posting import ( "context" "math" + "math/rand" "testing" "time" @@ -89,12 +90,66 @@ func TestCacheAfterDeltaUpdateRecieved(t *testing.T) { // Read key at timestamp 10. Make sure cache is not updated by this, as there is a later read. l, err := GetNoStore(key, 10) require.NoError(t, err) - require.Equal(t, l.mutationMap.len(), 0) + require.Equal(t, l.mutationMap.listLen(10), 0) // Read at 20 should show the value l1, err := GetNoStore(key, 20) require.NoError(t, err) - require.Equal(t, l1.mutationMap.len(), 1) + require.Equal(t, l1.mutationMap.listLen(20), 1) +} + +func BenchmarkTestCache(b *testing.B) { + //lCache, _ = ristretto.NewCache[[]byte, *List](&ristretto.Config[[]byte, *List]{ + // // Use 5% of cache memory for storing counters. + // NumCounters: int64(1000 * (1 << 20) * 0.05 * 2), + // MaxCost: int64(1000 * (1 << 20) * 0.95), + // BufferItems: 64, + // Metrics: true, + // Cost: func(val *List) int64 { + // return 0 + // }, + //}) + + attr := x.GalaxyAttr("cache") + keys := make([][]byte, 0) + N := 10000 + txn := Oracle().RegisterStartTs(1) + + for i := 1; i < N; i++ { + key := x.DataKey(attr, uint64(i)) + keys = append(keys, key) + edge := &pb.DirectedEdge{ + ValueId: 2, + Attr: attr, + Entity: 1, + Op: pb.DirectedEdge_SET, + } + l, _ := GetNoStore(key, 1) + // No index entries added here as we do not call AddMutationWithIndex. 
+ txn.cache.SetIfAbsent(string(l.key), l) + err := l.addMutation(context.Background(), txn, edge) + if err != nil { + panic(err) + } + } + txn.Update() + writer := NewTxnWriter(pstore) + err := txn.CommitToDisk(writer, 2) + if err != nil { + panic(err) + } + writer.Flush() + + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + key := keys[rand.Intn(N-1)] + _, err = getNew(key, pstore, math.MaxUint64) + if err != nil { + panic(err) + } + } + }) + } func TestRollupTimestamp(t *testing.T) { @@ -153,6 +208,8 @@ func TestPostingListRead(t *testing.T) { writer := NewTxnWriter(pstore) require.NoError(t, writer.SetAt(key, []byte{}, BitEmptyPosting, 6)) require.NoError(t, writer.Flush()) + // Delete the key from cache as we have just updated it + memoryLayer.Del(z.MemHash(key)) assertLength(7, 0) addEdgeToUID(t, attr, 1, 4, 7, 8) @@ -165,6 +222,7 @@ func TestPostingListRead(t *testing.T) { writer = NewTxnWriter(pstore) require.NoError(t, writer.SetAt(key, data, BitCompletePosting, 10)) require.NoError(t, writer.Flush()) + memoryLayer.Del(z.MemHash(key)) assertLength(10, 0) addEdgeToUID(t, attr, 1, 5, 11, 12) diff --git a/posting/oracle.go b/posting/oracle.go index bc03dc9a861..c74190da8ce 100644 --- a/posting/oracle.go +++ b/posting/oracle.go @@ -105,7 +105,7 @@ func (vt *viTxn) GetWithLockHeld(key []byte) (rval index.Value, rerr error) { func (vt *viTxn) GetValueFromPostingList(pl *List) (rval index.Value, rerr error) { value := pl.findStaticValue(vt.delegate.StartTs) - if value == nil { + if value == nil || len(value.Postings) == 0 { //fmt.Println("DIFF", val, err, nil, badger.ErrKeyNotFound) return nil, ErrNoValue } diff --git a/worker/draft.go b/worker/draft.go index c50fa0b818b..9232fde0e21 100644 --- a/worker/draft.go +++ b/worker/draft.go @@ -358,7 +358,7 @@ func (n *node) applyMutations(ctx context.Context, proposal *pb.Proposal) (rerr // TODO: Revisit this when we work on posting cache. Clear entire cache. // We don't want to drop entire cache, just due to one namespace. - // posting.ResetCache() + posting.ResetCache() return nil } diff --git a/worker/online_restore.go b/worker/online_restore.go index 78c07337844..14bf367ba63 100644 --- a/worker/online_restore.go +++ b/worker/online_restore.go @@ -410,6 +410,7 @@ func handleRestoreProposal(ctx context.Context, req *pb.RestoreRequest, pidx uin return errors.Wrapf(err, "cannot load schema after restore") } + posting.ResetCache() ResetAclCache() // Reset gql schema only when the restore is not partial, so that after this restore diff --git a/x/keys.go b/x/keys.go index b55f0830e45..f297f535f68 100644 --- a/x/keys.go +++ b/x/keys.go @@ -307,7 +307,7 @@ func (p ParsedKey) String() string { } else if p.IsCountOrCountRev() { return fmt.Sprintf("UID: %v, Attr: %v, IsCount/Ref: true, Count: %v", p.Uid, p.Attr, p.Count) } else { - return fmt.Sprintf("UID: %v, Attr: %v, Data key", p.Uid, p.Attr) + return fmt.Sprintf("UID: %v, Attr: %v, Data key, prefix; %v, byte: %v", p.Uid, p.Attr, p.bytePrefix, p.ByteType) } }