From f823dc4a5031a0404648c7e004c65d37bdf3f38d Mon Sep 17 00:00:00 2001 From: Karl McGuire Date: Fri, 4 Oct 2019 15:56:02 -0400 Subject: [PATCH] clean and update tests to 100% coverage (#87) * cleaned up metrics * readme fix * documentation * start of test revamp * policy tests * finished policy tests * metrics testing * full test coverage * updated readme * stress testing * ci fixes * fix unused variable * remove unused field --- .github/CODEOWNERS | 2 +- .github/workflows/ci.yml | 5 +- README.md | 8 +- cache.go | 134 ++++---- cache_test.go | 669 +++++++++++++-------------------------- policy.go | 154 +-------- policy_test.go | 362 +++++++++++++++++---- ring.go | 97 ++---- ring_test.go | 164 +++------- sim/sim_test.go | 13 +- sketch.go | 84 ++--- sketch_test.go | 151 ++++----- store_test.go | 173 +++++----- stress_test.go | 156 +++++++++ 14 files changed, 1001 insertions(+), 1171 deletions(-) create mode 100644 stress_test.go diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 8525df51..7eaa4c88 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,4 +1,4 @@ # CODEOWNERS info: https://help.github.com/en/articles/about-code-owners # Owners are automatically requested for review for PRs that changes code # that they own. -* @manishrjain @jarifibrahim +* @manishrjain @jarifibrahim @karlmcguire diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 553f8020..9bfc360b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,4 +1,4 @@ -name: CI Workflow +name: tests on: [push, pull_request] jobs: @@ -15,4 +15,5 @@ jobs: with: go-version: ${{ matrix.go-version }} - run: go fmt ./... - - run: go test -v -race ./... + - run: go test -race ./... + - run: go test -v ./... diff --git a/README.md b/README.md index 799eda1a..67891530 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,9 @@ # Ristretto -[![Go Doc](https://img.shields.io/badge/godoc-reference-blue.svg)](http://godoc.org/github.com/golang-standards/project-layout) -[![Go Report Card](https://img.shields.io/badge/go%20report-A%2B-green)](https://goreportcard.com/report/github.com/dgraph-io/ristretto) -[![Coverage](https://img.shields.io/badge/coverage-79%25-lightgrey)](https://gocover.io/github.com/dgraph-io/ristretto) +[![Go Doc](https://img.shields.io/badge/godoc-reference-blue.svg)](http://godoc.org/github.com/dgraph-io/ristretto) +[![Go Report Card](https://img.shields.io/badge/go%20report-A%2B-brightgreen)](https://goreportcard.com/report/github.com/dgraph-io/ristretto) +[![Coverage](https://img.shields.io/badge/coverage-100%25-brightgreen)](https://gocover.io/github.com/dgraph-io/ristretto) +![Tests](https://github.com/dgraph-io/ristretto/workflows/tests/badge.svg) + Ristretto is a fast, concurrent cache library built with a focus on performance and correctness. 
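The user-visible change in this patch is that `Cache.Metrics()` now returns an exported `*Metrics` snapshot instead of the internal `*metrics` struct (see the cache.go diff below). A minimal usage sketch against the post-patch API; the short sleep mirrors what the new tests do, since Sets flow through a buffer before they land:

```go
package main

import (
	"fmt"
	"time"

	"github.com/dgraph-io/ristretto"
)

func main() {
	cache, err := ristretto.NewCache(&ristretto.Config{
		NumCounters: 1e4, // counters for frequency tracking
		MaxCost:     1e3, // total cost budget
		BufferItems: 64,  // per-stripe size of the Get buffer
		Metrics:     true,
	})
	if err != nil {
		panic(err)
	}
	cache.Set("key", "value", 1)
	time.Sleep(10 * time.Millisecond) // Sets are applied asynchronously
	cache.Get("key")                  // hit
	cache.Get("missing")              // miss
	m := cache.Metrics()
	fmt.Printf("hits=%d misses=%d ratio=%.2f\n", m.Hits, m.Misses, m.Ratio)
}
```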
diff --git a/cache.go b/cache.go
index 647cb9e3..a2147cbc 100644
--- a/cache.go
+++ b/cache.go
@@ -20,9 +20,7 @@
 package ristretto
 
 import (
-	"bytes"
 	"errors"
-	"fmt"
 	"sync/atomic"
 
 	"github.com/dgraph-io/ristretto/z"
@@ -133,12 +131,9 @@ func NewCache(config *Config) (*Cache, error) {
 	}
 	policy := newPolicy(config.NumCounters, config.MaxCost)
 	cache := &Cache{
-		store:  newStore(),
-		policy: policy,
-		getBuf: newRingBuffer(ringLossy, &ringConfig{
-			Consumer: policy,
-			Capacity: config.BufferItems,
-		}),
+		store:     newStore(),
+		policy:    policy,
+		getBuf:    newRingBuffer(policy, config.BufferItems),
 		setBuf:    make(chan *item, setBufSize),
 		onEvict:   config.OnEvict,
 		keyToHash: config.KeyToHash,
@@ -285,17 +280,75 @@ func (c *Cache) processItems() {
 	}
 }
 
+// collectMetrics creates a new *metrics instance and hands the same pointer
+// to both the cache and the policy.
 func (c *Cache) collectMetrics() {
-	c.stats = newMetrics()
-	c.policy.CollectMetrics(c.stats)
+	stats := newMetrics()
+	c.stats = stats
+	c.policy.CollectMetrics(stats)
 }
 
 // Metrics returns statistics about cache performance.
-func (c *Cache) Metrics() *metrics {
+func (c *Cache) Metrics() *Metrics {
 	if c == nil {
 		return nil
 	}
-	return c.stats
+	return exportMetrics(c.stats)
+}
+
+// exportMetrics converts an internal metrics struct into a friendlier Metrics
+// struct.
+func exportMetrics(stats *metrics) *Metrics {
+	return &Metrics{
+		Hits:         stats.Get(hit),
+		Misses:       stats.Get(miss),
+		Ratio:        stats.Ratio(),
+		KeysAdded:    stats.Get(keyAdd),
+		KeysUpdated:  stats.Get(keyUpdate),
+		KeysEvicted:  stats.Get(keyEvict),
+		CostAdded:    stats.Get(costAdd),
+		CostEvicted:  stats.Get(costEvict),
+		SetsDropped:  stats.Get(dropSets),
+		SetsRejected: stats.Get(rejectSets),
+		GetsDropped:  stats.Get(dropGets),
+		GetsKept:     stats.Get(keepGets),
+	}
+}
+
+// Metrics is a snapshot of performance statistics for the lifetime of a cache
+// instance.
+type Metrics struct {
+	// Hits is the number of Get calls where a value was found for the
+	// corresponding key.
+	Hits uint64 `json:"hits"`
+	// Misses is the number of Get calls where a value was not found for the
+	// corresponding key.
+	Misses uint64 `json:"misses"`
+	// Ratio is the number of Hits over all accesses (Hits + Misses). This is
+	// the fraction of Get calls that were successful.
+	Ratio float64 `json:"ratio"`
+	// KeysAdded is the total number of Set calls where a new key-value item was
+	// added.
+	KeysAdded uint64 `json:"keysAdded"`
+	// KeysUpdated is the total number of Set calls where the value was updated.
+	KeysUpdated uint64 `json:"keysUpdated"`
+	// KeysEvicted is the total number of keys evicted.
+	KeysEvicted uint64 `json:"keysEvicted"`
+	// CostAdded is the sum of all costs that have been added (successful Set
+	// calls).
+	CostAdded uint64 `json:"costAdded"`
+	// CostEvicted is the sum of all costs that have been evicted.
+	CostEvicted uint64 `json:"costEvicted"`
+	// SetsDropped is the number of Set calls that don't make it into internal
+	// buffers (due to contention or some other reason).
+	SetsDropped uint64 `json:"setsDropped"`
+	// SetsRejected is the number of Set calls rejected by the policy (TinyLFU).
+	SetsRejected uint64 `json:"setsRejected"`
+	// GetsDropped is the number of Get counter increments that are dropped
+	// internally.
+	GetsDropped uint64 `json:"getsDropped"`
+	// GetsKept is the number of Get counter increments that are kept.
+	GetsKept uint64 `json:"getsKept"`
 }
 
 type metricType int
@@ -304,60 +357,27 @@
 const (
 	// The following 2 keep track of hits and misses.
hit = iota miss - // The following 3 keep track of number of keys added, updated and evicted. keyAdd keyUpdate keyEvict - // The following 2 keep track of cost of keys added and evicted. costAdd costEvict - // The following keep track of how many sets were dropped or rejected later. dropSets rejectSets - - // The following 2 keep track of how many gets were kept and dropped on the floor. + // The following 2 keep track of how many gets were kept and dropped on the + // floor. dropGets keepGets - // This should be the final enum. Other enums should be set before this. doNotUse ) -func stringFor(t metricType) string { - switch t { - case hit: - return "hit" - case miss: - return "miss" - case keyAdd: - return "keys-added" - case keyUpdate: - return "keys-updated" - case keyEvict: - return "keys-evicted" - case costAdd: - return "cost-added" - case costEvict: - return "cost-evicted" - case dropSets: - return "sets-dropped" - case rejectSets: - return "sets-rejected" // by policy. - case dropGets: - return "gets-dropped" - case keepGets: - return "gets-kept" - default: - return "unidentified" - } -} - -// metrics is the struct for hit ratio statistics. Note that there is some -// cost to maintaining the counters, so it's best to wrap Policies via the -// Recorder type when hit ratio analysis is needed. +// metrics is the struct for hit ratio statistics. Padding is used to avoid +// false sharing in order to minimize the performance cost for those who track +// metrics outside of testing scenarios. type metrics struct { all [doNotUse][]*uint64 } @@ -407,17 +427,3 @@ func (p *metrics) Ratio() float64 { } return float64(hits) / float64(hits+misses) } - -func (p *metrics) String() string { - if p == nil { - return "" - } - var buf bytes.Buffer - for i := 0; i < doNotUse; i++ { - t := metricType(i) - fmt.Fprintf(&buf, "%s: %d ", stringFor(t), p.Get(t)) - } - fmt.Fprintf(&buf, "gets-total: %d ", p.Get(hit)+p.Get(miss)) - fmt.Fprintf(&buf, "hit-ratio: %.2f", p.Ratio()) - return buf.String() -} diff --git a/cache_test.go b/cache_test.go index fe1c2c94..160149e8 100644 --- a/cache_test.go +++ b/cache_test.go @@ -1,519 +1,296 @@ -/* - * Copyright 2019 Dgraph Labs, Inc. and Contributors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - package ristretto import ( - "container/heap" - "fmt" - "math/rand" - "runtime" "sync" - "sync/atomic" "testing" "time" - - "github.com/dgraph-io/ristretto/sim" - "github.com/dgraph-io/ristretto/z" ) -// TestCache is used to pass instances of Ristretto and Clairvoyant around and -// compare their performance. 
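The new comment on the internal metrics struct above attributes its layout to false-sharing avoidance: each counter lives on its own cache line, so concurrent `Add` calls from different cores don't invalidate each other's lines. A generic illustration of that padding trick — this shows the idea only, not Ristretto's exact slot layout:

```go
package main

import (
	"fmt"
	"sync/atomic"
)

// stripedCounter spreads one logical counter across padded slots so that
// writers on different cores touch different cache lines.
type stripedCounter struct {
	slots [8]struct {
		n uint64
		_ [7]uint64 // pad the slot out to a full 64-byte cache line
	}
}

func (c *stripedCounter) add(hash, delta uint64) {
	atomic.AddUint64(&c.slots[hash%8].n, delta)
}

func (c *stripedCounter) get() uint64 {
	var sum uint64
	for i := range c.slots {
		sum += atomic.LoadUint64(&c.slots[i].n)
	}
	return sum
}

func main() {
	var c stripedCounter
	c.add(1, 1)
	c.add(9, 1) // a different goroutine would likely hit a different slot
	fmt.Println(c.get()) // 2
}
```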
-type TestCache interface { - Get(interface{}) (interface{}, bool) - Set(interface{}, interface{}, int64) bool - Metrics() *metrics +func TestCache(t *testing.T) { + if _, err := NewCache(&Config{ + NumCounters: 0, + }); err == nil { + t.Fatal("numCounters can't be 0") + } + if _, err := NewCache(&Config{ + NumCounters: 100, + MaxCost: 0, + }); err == nil { + t.Fatal("maxCost can't be 0") + } + if _, err := NewCache(&Config{ + NumCounters: 100, + MaxCost: 10, + BufferItems: 0, + }); err == nil { + t.Fatal("bufferItems can't be 0") + } + if c, err := NewCache(&Config{ + NumCounters: 100, + MaxCost: 10, + BufferItems: 64, + Metrics: true, + }); c == nil || err != nil { + t.Fatal("config should be good") + } } -// capacity is the cache capacity to be used across all tests and benchmarks. -const capacity = 1000 - -// newCache should be used for all Ristretto instances in local tests. -func newCache(metrics bool) *Cache { - cache, err := NewCache(&Config{ - NumCounters: capacity * 10, - MaxCost: capacity, +func TestCacheProcessItems(t *testing.T) { + m := &sync.Mutex{} + evicted := make(map[uint64]struct{}) + c, err := NewCache(&Config{ + NumCounters: 100, + MaxCost: 10, BufferItems: 64, - Metrics: metrics, + Cost: func(value interface{}) int64 { + return int64(value.(int)) + }, + OnEvict: func(key uint64, value interface{}, cost int64) { + m.Lock() + defer m.Unlock() + evicted[key] = struct{}{} + }, }) if err != nil { panic(err) } - return cache -} - -// newBenchmark should be used for all local benchmarks to ensure consistency -// across comparisons. -func newBenchmark(bencher func(uint64)) func(b *testing.B) { - return func(b *testing.B) { - b.SetParallelism(1) - b.SetBytes(1) - b.ResetTimer() - b.RunParallel(func(pb *testing.PB) { - for i := uint64(0); pb.Next(); i++ { - bencher(i) - } - }) - } -} - -// BenchmarkCacheGetOne Gets the same key-value item over and over. -func BenchmarkCacheGetOne(b *testing.B) { - cache := newCache(false) - cache.Set(1, nil, 1) - newBenchmark(func(i uint64) { cache.Get(1) })(b) -} - -// BenchmarkCacheSetOne Sets the same key-value item over and over. -func BenchmarkCacheSetOne(b *testing.B) { - cache := newCache(false) - newBenchmark(func(i uint64) { cache.Set(1, nil, 1) })(b) -} - -// BenchmarkCacheSetUni Sets keys incrementing by 1. -func BenchmarkCacheSetUni(b *testing.B) { - cache := newCache(false) - newBenchmark(func(i uint64) { cache.Set(i, nil, 1) })(b) -} - -// newRatioTest simulates a workload for a TestCache so you can just run the -// returned test and call cache.metrics() to get a basic idea of performance. 
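TestCacheProcessItems above exercises the internal worker by writing items straight into `setBuf`. For orientation, here is a hedged reconstruction of the loop those assertions imply; the real `processItems` isn't shown in this diff, and the cost-function field name (`c.cost`) is an assumption:

```go
// Reconstruction of Cache.processItems, inferred from the test's assertions
// (not the actual source).
func (c *Cache) processItems() {
	for {
		select {
		case i := <-c.setBuf:
			// the test pushes cost 0 and later observes costs 1 and 2, so a
			// zero cost must be recomputed from the value via Config.Cost
			if i.cost == 0 && c.cost != nil {
				i.cost = c.cost(i.value)
			}
			switch i.flag {
			case itemNew:
				victims, added := c.policy.Add(i.key, i.cost)
				if added {
					c.store.Set(i.key, i.value)
				}
				for _, v := range victims {
					c.store.Del(v.key)
					if c.onEvict != nil {
						c.onEvict(v.key, v.value, v.cost)
					}
				}
			case itemUpdate:
				c.policy.Update(i.key, i.cost)
			case itemDelete:
				c.policy.Del(i.key)
				c.store.Del(i.key)
			}
		case <-c.stop:
			return
		}
	}
}
```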
-func newRatioTest(cache TestCache) func(t *testing.T) { - return func(t *testing.T) { - keys := sim.NewZipfian(1.0001, 1, capacity*100) - for i := 0; i < capacity*1000; i++ { - key, err := keys() - if err != nil { - t.Fatal(err) - } - if _, ok := cache.Get(key); !ok { - cache.Set(key, nil, 1) - } + c.setBuf <- &item{flag: itemNew, key: 1, value: 1, cost: 0} + time.Sleep(time.Millisecond) + if !c.policy.Has(1) || c.policy.Cost(1) != 1 { + t.Fatal("cache processItems didn't add new item") + } + c.setBuf <- &item{flag: itemUpdate, key: 1, value: 2, cost: 0} + time.Sleep(time.Millisecond) + if c.policy.Cost(1) != 2 { + t.Fatal("cache processItems didn't update item cost") + } + c.setBuf <- &item{flag: itemDelete, key: 1} + time.Sleep(time.Millisecond) + if val, ok := c.store.Get(1); val != nil || ok { + t.Fatal("cache processItems didn't delete item") + } + if c.policy.Has(1) { + t.Fatal("cache processItems didn't delete item") + } + c.setBuf <- &item{flag: itemNew, key: 2, value: 2, cost: 3} + c.setBuf <- &item{flag: itemNew, key: 3, value: 3, cost: 3} + c.setBuf <- &item{flag: itemNew, key: 4, value: 3, cost: 3} + c.setBuf <- &item{flag: itemNew, key: 5, value: 3, cost: 5} + time.Sleep(time.Millisecond) + m.Lock() + if len(evicted) == 0 { + m.Unlock() + t.Fatal("cache processItems not evicting or calling OnEvict") + } + m.Unlock() + defer func() { + if r := recover(); r == nil { + t.Fatal("cache processItems didn't stop") } - } + }() + c.Close() + c.setBuf <- &item{flag: itemNew} } -func TestCacheClear(t *testing.T) { - cache := newCache(true) - for i := 0; i < capacity; i++ { - cache.Set(i, i, 1) - } - cache.Clear() - if len(cache.setBuf) != 0 { - t.Fatal("setBuf not cleared") - } - for i := 0; i < capacity; i++ { - if _, ok := cache.Get(i); ok { - t.Fatal("clear operation failed") - } +func TestCacheGet(t *testing.T) { + c, err := NewCache(&Config{ + NumCounters: 100, + MaxCost: 10, + BufferItems: 64, + Metrics: true, + }) + if err != nil { + panic(err) } - // verify that we can still Set/Get items with the new buffers - for i := 0; i < capacity; i++ { - cache.Set(i, i, 1) + c.store.Set(1, 1) + if val, ok := c.Get(1); val == nil || !ok { + t.Fatal("get should be successful") } - time.Sleep(time.Second / 100) - for i := 0; i < capacity; i++ { - if _, found := cache.Get(i); !found { - t.Fatal("value should exist") - } + if val, ok := c.Get(2); val != nil || ok { + t.Fatal("get should not be successful") } - // 0.5 and not 1.0 because we tried Getting each item twice - if cache.Metrics().Ratio() != 0.5 { - t.Fatal("incorrect hit ratio") + if c.stats.Ratio() != 0.5 { + t.Fatal("get should record metrics") } -} - -func TestCacheSetDel(t *testing.T) { - cache := newCache(true) - cache.Set(1, 1, 1) - cache.Del(1) - time.Sleep(time.Second / 100) - if _, found := cache.Get(1); found { - t.Fatal("value shouldn't exist") + c = nil + if val, ok := c.Get(0); val != nil || ok { + t.Fatal("get should not be successful with nil cache") } } -func TestCacheCoster(t *testing.T) { - costRuns := uint64(0) - cache, err := NewCache(&Config{ - NumCounters: 1000, - MaxCost: 500, +func TestCacheSet(t *testing.T) { + c, err := NewCache(&Config{ + NumCounters: 100, + MaxCost: 10, BufferItems: 64, - Cost: func(value interface{}) int64 { - atomic.AddUint64(&costRuns, 1) - return 5 - }, + Metrics: true, }) if err != nil { panic(err) } - for i := 0; i < 100; i++ { - cache.Set(i, i, 0) - } - time.Sleep(time.Second / 100) - for i := 0; i < 100; i++ { - if cache.policy.Cost(z.KeyToHash(i)) != 5 { - t.Fatal("coster not 
being ran") + if c.Set(1, 1, 1) { + time.Sleep(time.Millisecond) + if val, ok := c.Get(1); val == nil || val.(int) != 1 || !ok { + t.Fatal("set/get returned wrong value") + } + } else { + if val, ok := c.Get(1); val != nil || ok { + t.Fatal("set was dropped but value still added") } } - if costRuns != 100 { - t.Fatal("coster not being ran") + c.Set(1, 2, 2) + if val, ok := c.store.Get(1); val == nil || val.(int) != 2 || !ok { + t.Fatal("set/update was unsuccessful") } -} - -// TestCacheUpdate verifies that a Set call on an existing key immediately -// updates the value and cost for that key without using/polluting the Set -// buffer(s). -func TestCacheUpdate(t *testing.T) { - cache := newCache(true) - cache.Set(1, 1, 1) - // wait for new-item Set to go through - time.Sleep(time.Second / 100) - // do 100 updates - for i := 0; i < 100; i++ { - // update the same key (1) with incrementing value and cost, so we can - // verify that they are immediately updated and not going through - // channels - cache.Set(1, i, int64(i)) - if val, ok := cache.Get(1); !ok || val.(int) != i { - t.Fatal("keyUpdate value inconsistent") - } + c.stop <- struct{}{} + for i := 0; i < setBufSize; i++ { + c.setBuf <- &item{itemUpdate, 1, 1, 1} } - // wait for keyUpdates to go through - time.Sleep(time.Second / 100) - if cache.Metrics().Get(keyUpdate) == 0 { - t.Fatal("keyUpdates not being processed") + if c.Set(2, 2, 1) { + t.Fatal("set should be dropped with full setBuf") + } + if c.stats.Get(dropSets) != 1 { + t.Fatal("set should track dropSets") + } + close(c.setBuf) + close(c.stop) + c = nil + if c.Set(1, 1, 1) { + t.Fatal("set shouldn't be successful with nil cache") } } -func TestCacheOnEvict(t *testing.T) { - mu := &sync.Mutex{} - evictions := make(map[uint64]int) - cache, err := NewCache(&Config{ - NumCounters: 1000, - MaxCost: 100, - BufferItems: 1, - OnEvict: func(key uint64, value interface{}, cost int64) { - mu.Lock() - defer mu.Unlock() - evictions[key] = value.(int) - }, +func TestCacheDel(t *testing.T) { + c, err := NewCache(&Config{ + NumCounters: 100, + MaxCost: 10, + BufferItems: 64, }) if err != nil { panic(err) } - for i := 0; i < 256; i++ { - cache.Set(i, i, 1) - } - time.Sleep(time.Second / 100) - mu.Lock() - defer mu.Unlock() - if len(evictions) != 156 { - t.Fatal("onEvict not being called") + c.Set(1, 1, 1) + c.Del(1) + time.Sleep(time.Millisecond) + if val, ok := c.Get(1); val != nil || ok { + t.Fatal("del didn't delete") } - for k, v := range evictions { - if k != uint64(v) { - t.Fatal("onEvict key-val mismatch") + c = nil + defer func() { + if r := recover(); r != nil { + t.Fatal("del panic with nil cache") } - } + }() + c.Del(1) } -func TestCacheKeyToHash(t *testing.T) { - cache, err := NewCache(&Config{ - NumCounters: 1000, - MaxCost: 100, - BufferItems: 1, - KeyToHash: func(key interface{}) uint64 { - i, ok := key.(int) - if !ok { - panic("failed to type assert") - } - return uint64(i + 2) - }, +func TestCacheClear(t *testing.T) { + c, err := NewCache(&Config{ + NumCounters: 100, + MaxCost: 10, + BufferItems: 64, + Metrics: true, }) if err != nil { panic(err) } for i := 0; i < 10; i++ { - if uint64(i+2) != cache.keyToHash(i) { - t.Fatal("keyToHash hash mismatch") - } - } -} - -// TestCacheRatios gives us a rough idea of the hit ratio relative to the -// theoretical optimum. Useful for quickly seeing the effects of changes. 
-func TestCacheRatios(t *testing.T) { - cache := newCache(true) - optimal := NewClairvoyant(capacity) - newRatioTest(cache)(t) - newRatioTest(optimal)(t) - t.Logf("ristretto: %.2f\n", cache.Metrics().Ratio()) - t.Logf("- optimal: %.2f\n", optimal.Metrics().Ratio()) -} - -var newCacheInvalidConfigTests = []struct { - conf Config - desc string -}{ - { - conf: Config{ - NumCounters: 0, - MaxCost: 1, - BufferItems: 1, - }, - desc: "NumCounters is 0", - }, - { - conf: Config{ - NumCounters: 1, - MaxCost: 0, - BufferItems: 1, - }, - desc: "MaxCost is 0", - }, - { - conf: Config{ - NumCounters: 1, - MaxCost: 1, - BufferItems: 0, - }, - desc: "BufferItems is 0", - }, -} - -func TestNewCacheInvalidConfig(t *testing.T) { - for _, tc := range newCacheInvalidConfigTests { - _, err := NewCache(&tc.conf) - - if err == nil { - t.Fatalf("%s: NewCache should return an error", tc.desc) - } + c.Set(i, i, 1) } - -} - -func TestCacheNil(t *testing.T) { - var cache *Cache - - r := cache.Set("key", "value", 1) - if r != false { - t.Fatal("Calling Set on nil Cache should return false") + time.Sleep(time.Millisecond) + if c.stats.Get(keyAdd) != 10 { + t.Fatal("range of sets not being processed") } - - _, r = cache.Get("key") - if r != false { - t.Fatal("Calling Get on nil Cache should return false") + c.Clear() + if c.stats.Get(keyAdd) != 0 { + t.Fatal("clear didn't reset metrics") } -} - -func TestCacheDel(t *testing.T) { - cache := newCache(true) - // fill the cache with data - for key := 0; key < capacity; key++ { - cache.Set(key, key, 1) - } - // wait for the Sets to be processed so that all values are in the cache - // before we begin Gets, otherwise the hit ratio would be bad - time.Sleep(time.Second / 100) - - wg := &sync.WaitGroup{} - // launch goroutines to concurrently Del keys - for b := 0; b < capacity/100; b++ { - wg.Add(1) - go func(b int) { - for i := 100 * b; i < 100*b+100; i++ { - cache.Del(i) - } - wg.Done() - }(b) - } - wg.Wait() - - // wait for Dels to be processed (they pass through the same buffer as Set) - time.Sleep(time.Second / 100) - - for key := 0; key < capacity; key++ { - if _, ok := cache.Get(key); ok { - t.Fatalf("cache key %d should not be exist\n", key) + for i := 0; i < 10; i++ { + if val, ok := c.Get(i); val != nil || ok { + t.Fatal("clear didn't delete values") } } - - if ratio := cache.Metrics().Ratio(); ratio != 0.0 { - t.Fatalf("expected 0.00 but got %.2f\n", ratio) - } } -func TestCacheSetGet(t *testing.T) { - cache := newCache(true) - // fill the cache with data - for key := 0; key < capacity; key++ { - cache.Set(key, key, 1) - } - // wait for the Sets to be processed so that all values are in the cache - // before we begin Gets, otherwise the hit ratio would be bad - time.Sleep(time.Second / 100) - wg := &sync.WaitGroup{} - // launch goroutines to concurrently Get random keys - - var err error - for r := 0; r < 8; r++ { - wg.Add(1) - go func() { - r := rand.New(rand.NewSource(time.Now().UnixNano())) - // it's not too important that we iterate through the whole capacity - // here, but we want all the goroutines to be Getting in parallel, - // so it should iterate long enough that it will continue until the - // other goroutines are done spinning up - for i := 0; i < capacity; i++ { - key := r.Int() % capacity - if val, ok := cache.Get(key); ok { - if val.(int) != key { - err = fmt.Errorf("expected %d but got %d", key, val.(int)) - break - } - } - } - wg.Done() - }() - } - wg.Wait() +func TestCacheMetrics(t *testing.T) { + c, err := NewCache(&Config{ + NumCounters: 100, + 
MaxCost: 10, + BufferItems: 64, + Metrics: true, + }) if err != nil { - t.Fatal(err) + panic(err) } - - if ratio := cache.Metrics().Ratio(); ratio != 1.0 { - t.Fatalf("expected 1.00 but got %.2f\n", ratio) + for i := 0; i < 10; i++ { + c.Set(i, i, 1) } -} - -// TestCacheSetNil makes sure nil values are working properly. -func TestCacheSetNil(t *testing.T) { - cache := newCache(false) - cache.Set(1, nil, 1) - // must wait for the set buffer - time.Sleep(time.Second / 1000) - if value, ok := cache.Get(1); !ok || value != nil { - t.Fatal("cache value should exist and be nil") + time.Sleep(time.Millisecond) + m := c.Metrics() + if m.KeysAdded != 10 { + t.Fatal("metrics exporting incorrect fields") } -} - -// TestCacheSetDrops simulates a period of high contention and reports the -// percentage of Sets that are dropped. For most use cases, it would be rare to -// have more than 4 goroutines calling Set in parallel. Nevertheless, this is a -// useful stress test. -func TestCacheSetDrops(t *testing.T) { - for goroutines := 1; goroutines <= 16; goroutines++ { - n, size := goroutines, capacity*10 - sample := uint64(n * size) - cache := newCache(true) - keys := sim.Collection(sim.NewUniform(sample), sample) - start, finish := &sync.WaitGroup{}, &sync.WaitGroup{} - start.Add(n) - finish.Add(n) - for i := 0; i < n; i++ { - go func(i int) { - start.Done() - // wait for all goroutines to be ready - start.Wait() - for j := i * size; j < (i*size)+size; j++ { - cache.Set(keys[j], 0, 1) - } - finish.Done() - }(i) - } - finish.Wait() - dropped := cache.Metrics().Get(dropSets) - t.Logf("%d goroutines: %.2f%% dropped \n", - goroutines, float64(float64(dropped)/float64(sample))*100) - runtime.GC() + c = nil + if c.Metrics() != nil { + t.Fatal("metrics exporting non-nil with nil cache") } } -// Clairvoyant is a mock cache providing us with optimal hit ratios to compare -// with Ristretto's. It looks ahead and evicts the absolute least valuable item, -// which we try to approximate in a real cache. -type Clairvoyant struct { - capacity uint64 - hits map[uint64]uint64 - access []uint64 +func TestMetrics(t *testing.T) { + newMetrics() } -func NewClairvoyant(capacity uint64) *Clairvoyant { - return &Clairvoyant{ - capacity: capacity, - hits: make(map[uint64]uint64), - access: make([]uint64, 0), +func TestMetricsAddGet(t *testing.T) { + m := newMetrics() + m.Add(hit, 1, 1) + m.Add(hit, 2, 2) + m.Add(hit, 3, 3) + if m.Get(hit) != 6 { + t.Fatal("add/get error") } -} - -// Get just records the cache access so that we can later take this event into -// consideration when calculating the absolute least valuable item to evict. -func (c *Clairvoyant) Get(key interface{}) (interface{}, bool) { - c.hits[key.(uint64)]++ - c.access = append(c.access, key.(uint64)) - return nil, false -} - -// Set isn't important because it is only called after a Get (in the case of our -// hit ratio benchmarks, at least). 
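The Clairvoyant helper above (moved into stress_test.go later in this patch) replays the full access log with perfect knowledge of future hits, which is essentially Belady's optimal policy. A tiny worked example of what its replay produces, using hypothetical values and capacity 2:

```go
o := NewClairvoyant(2)
for _, k := range []uint64{1, 2, 1, 3, 1} {
	if _, ok := o.Get(k); !ok { // Get always misses; it only records the access
		o.Set(k, k, 1)
	}
}
// Replay: keys 1, 2, 3 each cost one cold miss; with room for two items the
// optimal cache keeps the thrice-accessed key 1, so its later touches are
// hits: ratio = 2/(2+3) = 0.40.
fmt.Printf("%.2f\n", o.Metrics().Ratio())
```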
-func (c *Clairvoyant) Set(key, value interface{}, cost int64) bool { - return false -} - -func (c *Clairvoyant) Metrics() *metrics { - stat := newMetrics() - look := make(map[uint64]struct{}, c.capacity) - data := &clairvoyantHeap{} - heap.Init(data) - for _, key := range c.access { - if _, has := look[key]; has { - stat.Add(hit, 0, 1) - continue - } - if uint64(data.Len()) >= c.capacity { - victim := heap.Pop(data) - delete(look, victim.(*clairvoyantItem).key) - } - stat.Add(miss, 0, 1) - look[key] = struct{}{} - heap.Push(data, &clairvoyantItem{key, c.hits[key]}) + m = nil + m.Add(hit, 1, 1) + if m.Get(hit) != 0 { + t.Fatal("get with nil struct should return 0") } - return stat -} - -type clairvoyantItem struct { - key uint64 - hits uint64 } -type clairvoyantHeap []*clairvoyantItem - -func (h clairvoyantHeap) Len() int { return len(h) } -func (h clairvoyantHeap) Less(i, j int) bool { return h[i].hits < h[j].hits } -func (h clairvoyantHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } - -func (h *clairvoyantHeap) Push(x interface{}) { - *h = append(*h, x.(*clairvoyantItem)) +func TestMetricsRatio(t *testing.T) { + m := newMetrics() + if m.Ratio() != 0 { + t.Fatal("ratio with no hits or misses should be 0") + } + m.Add(hit, 1, 1) + m.Add(hit, 2, 2) + m.Add(miss, 1, 1) + m.Add(miss, 2, 2) + if m.Ratio() != 0.5 { + t.Fatal("ratio incorrect") + } + m = nil + if m.Ratio() != 0.0 { + t.Fatal("ratio with a nil struct should return 0") + } } -func (h *clairvoyantHeap) Pop() interface{} { - old := *h - n := len(old) - x := old[n-1] - *h = old[0 : n-1] - return x +func TestMetricsExport(t *testing.T) { + m := newMetrics() + m.Add(hit, 1, 1) + m.Add(miss, 1, 1) + m.Add(keyAdd, 1, 1) + m.Add(keyUpdate, 1, 1) + m.Add(keyEvict, 1, 1) + m.Add(costAdd, 1, 1) + m.Add(costEvict, 1, 1) + m.Add(dropSets, 1, 1) + m.Add(rejectSets, 1, 1) + m.Add(dropGets, 1, 1) + m.Add(keepGets, 1, 1) + M := exportMetrics(m) + if M.Hits != 1 || M.Misses != 1 || M.Ratio != 0.5 || M.KeysAdded != 1 || + M.KeysUpdated != 1 || M.KeysEvicted != 1 || M.CostAdded != 1 || + M.CostEvicted != 1 || M.SetsDropped != 1 || M.SetsRejected != 1 || + M.GetsDropped != 1 || M.GetsKept != 1 { + t.Fatal("exportMetrics wrong value(s)") + } } diff --git a/policy.go b/policy.go index 1188ab61..4845174d 100644 --- a/policy.go +++ b/policy.go @@ -17,7 +17,6 @@ package ristretto import ( - "container/list" "math" "sync" @@ -31,6 +30,9 @@ const ( ) // policy is the interface encapsulating eviction/admission behavior. +// +// TODO: remove this interface and just rename defaultPolicy to policy, as we +// are probably only going to use/implement/maintain one policy. type policy interface { ringConsumer // Add attempts to Add the key-cost pair to the Policy. It returns a slice @@ -56,18 +58,9 @@ type policy interface { } func newPolicy(numCounters, maxCost int64) policy { - p := &defaultPolicy{ - admit: newTinyLFU(numCounters), - evict: newSampledLFU(maxCost), - itemsCh: make(chan []uint64, 3), - stop: make(chan struct{}), - } - go p.processItems() - return p + return newDefaultPolicy(numCounters, maxCost) } -// defaultPolicy is the default defaultPolicy, which is currently TinyLFU -// admission with sampledLFU eviction. 
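The defaultPolicy that follows pairs TinyLFU admission with sampled-LFU eviction. Its observable contract, restated compactly from TestPolicyAdd further down (a hedged paraphrase, using the capacities from that test):

```go
p := newDefaultPolicy(1000, 100)

// a single item costing more than MaxCost is rejected outright, no victims
p.Add(1, 101) // -> (nil, false)

// while free room remains, items are admitted without evictions
p.Add(2, 20) // -> (nil, true)

// once full, the incoming key's TinyLFU estimate must beat the minimum
// estimate within a sample of resident keys: winners evict that minimum
// (returned as victims for the cache to delete), losers are rejected and
// counted under rejectSets
```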
type defaultPolicy struct { sync.Mutex admit *tinyLFU @@ -77,6 +70,17 @@ type defaultPolicy struct { stats *metrics } +func newDefaultPolicy(numCounters, maxCost int64) *defaultPolicy { + p := &defaultPolicy{ + admit: newTinyLFU(numCounters), + evict: newSampledLFU(maxCost), + itemsCh: make(chan []uint64, 3), + stop: make(chan struct{}), + } + go p.processItems() + return p +} + func (p *defaultPolicy) CollectMetrics(stats *metrics) { p.stats = stats p.evict.stats = stats @@ -145,7 +149,7 @@ func (p *defaultPolicy) Add(key uint64, cost int64) ([]*item, bool) { sample := make([]*policyPair, 0, lfuSample) // as items are evicted they will be appended to victims victims := make([]*item, 0) - // Delete victims until there's enough space or a minKey is found that has + // delete victims until there's enough space or a minKey is found that has // more hits than incoming item. for ; room < 0; room = p.evict.roomLeft(cost) { // fill up empty slots in sample @@ -158,7 +162,7 @@ func (p *defaultPolicy) Add(key uint64, cost int64) ([]*item, bool) { minKey, minHits, minId, minCost = pair.key, hits, i, pair.cost } } - // If the incoming item isn't worth keeping in the policy, reject. + // if the incoming item isn't worth keeping in the policy, reject. if incHits < minHits { p.stats.Add(rejectSets, key, 1) return victims, false @@ -352,127 +356,3 @@ func (p *tinyLFU) clear() { p.door.Clear() p.freq.Clear() } - -// lruPolicy is different than the default policy in that it uses exact LRU -// eviction rather than Sampled LFU eviction, which may be useful for certain -// workloads (ARC-OLTP for example; LRU heavy workloads). -// -// TODO: - cost based eviction (multiple evictions for one new item, etc.) -// - sampled LRU -type lruPolicy struct { - sync.Mutex - admit *tinyLFU - ptrs map[uint64]*lruItem - vals *list.List - maxCost int64 - room int64 -} - -type lruItem struct { - ptr *list.Element - key uint64 - cost int64 -} - -func newLRUPolicy(numCounters, maxCost int64) policy { - return &lruPolicy{ - admit: newTinyLFU(numCounters), - ptrs: make(map[uint64]*lruItem, maxCost), - vals: list.New(), - room: maxCost, - maxCost: maxCost, - } -} - -func (p *lruPolicy) Push(keys []uint64) bool { - if len(keys) == 0 { - return true - } - p.Lock() - defer p.Unlock() - for _, key := range keys { - // increment tinylfu counter - p.admit.Increment(key) - // move list item to front - if val, ok := p.ptrs[key]; ok { - // move accessed val to MRU position - p.vals.MoveToFront(val.ptr) - } - } - return true -} - -func (p *lruPolicy) Add(key uint64, cost int64) ([]*item, bool) { - p.Lock() - defer p.Unlock() - if cost > p.maxCost { - return nil, false - } - if val, has := p.ptrs[key]; has { - p.vals.MoveToFront(val.ptr) - return nil, true - } - victims := make([]*item, 0) - incHits := p.admit.Estimate(key) - for p.room < 0 { - lru := p.vals.Back() - victim := lru.Value.(*lruItem) - if incHits < p.admit.Estimate(victim.key) { - return victims, false - } - // delete victim from metadata - p.vals.Remove(victim.ptr) - delete(p.ptrs, victim.key) - victims = append(victims, &item{ - key: victim.key, - cost: victim.cost, - }) - // adjust room - p.room += victim.cost - } - newItem := &lruItem{key: key, cost: cost} - newItem.ptr = p.vals.PushFront(newItem) - p.ptrs[key] = newItem - p.room -= cost - return victims, true -} - -func (p *lruPolicy) Has(key uint64) bool { - p.Lock() - defer p.Unlock() - _, has := p.ptrs[key] - return has -} - -func (p *lruPolicy) Del(key uint64) { - p.Lock() - defer p.Unlock() - if val, ok := 
p.ptrs[key]; ok { - p.vals.Remove(val.ptr) - delete(p.ptrs, key) - } -} - -func (p *lruPolicy) Cap() int64 { - p.Lock() - defer p.Unlock() - return int64(p.vals.Len()) -} - -func (p *lruPolicy) Close() {} - -// TODO -func (p *lruPolicy) Update(key uint64, cost int64) { -} - -// TODO -func (p *lruPolicy) Cost(key uint64) int64 { - return -1 -} - -// TODO -func (p *lruPolicy) CollectMetrics(stats *metrics) { -} - -// TODO -func (p *lruPolicy) Clear() {} diff --git a/policy_test.go b/policy_test.go index 3036c9f4..5a2a77c1 100644 --- a/policy_test.go +++ b/policy_test.go @@ -1,76 +1,300 @@ -/* - * Copyright 2019 Dgraph Labs, Inc. and Contributors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - package ristretto import ( "testing" + "time" ) -func GeneratePolicyTest(p func(int64, int64) policy) func(*testing.T) { - return func(t *testing.T) { - t.Run("uniform-push", func(t *testing.T) { - policy := p(1024, 1024) - values := make([]uint64, 1024) - for i := range values { - values[i] = uint64(i) - } - policy.Add(0, 1) - policy.Push(values) - if !policy.Has(0) || policy.Has(999999) { - t.Fatal("add/push error") - } - }) - t.Run("uniform-add", func(t *testing.T) { - policy := p(1024, 1024) - for i := int64(0); i < 1024; i++ { - policy.Add(uint64(i), 1) - } - if vics, added := policy.Add(999999, 1); vics == nil || !added { - t.Fatal("add/eviction error") - } - }) - t.Run("variable-push", func(t *testing.T) { - policy := p(1024, 1024*4) - values := make([]uint64, 1024) - for i := range values { - values[i] = uint64(i) - } - policy.Add(0, 1) - policy.Push(values) - if !policy.Has(0) || policy.Has(999999) { - t.Fatal("add/push error") - } - }) - t.Run("variable-add", func(t *testing.T) { - policy := p(1024, 1024*4) - for i := int64(0); i < 1024; i++ { - policy.Add(uint64(i), 4) - } - if vics, added := policy.Add(999999, 1); vics == nil || !added { - t.Fatal("add/eviction error") - } - }) - } -} - -func TestLFUPolicy(t *testing.T) { - GeneratePolicyTest(newPolicy)(t) -} - -func TestLRUPolicy(t *testing.T) { - GeneratePolicyTest(newLRUPolicy)(t) +func TestPolicy(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Fatal("newPolicy failed") + } + }() + newPolicy(100, 10) +} + +func TestPolicyMetrics(t *testing.T) { + p := newDefaultPolicy(100, 10) + p.CollectMetrics(newMetrics()) + if p.stats == nil || p.evict.stats == nil { + t.Fatal("policy metrics initialization error") + } +} + +func TestPolicyProcessItems(t *testing.T) { + p := newDefaultPolicy(100, 10) + p.itemsCh <- []uint64{1, 2, 2} + time.Sleep(time.Millisecond) + p.Lock() + if p.admit.Estimate(2) != 2 || p.admit.Estimate(1) != 1 { + p.Unlock() + t.Fatal("policy processItems not pushing to tinylfu counters") + } + p.Unlock() + p.stop <- struct{}{} + p.itemsCh <- []uint64{3, 3, 3} + time.Sleep(time.Millisecond) + p.Lock() + if p.admit.Estimate(3) != 0 { + p.Unlock() + t.Fatal("policy processItems not stopping") + } + p.Unlock() +} + +func TestPolicyPush(t *testing.T) { + p := newDefaultPolicy(100, 10) + if 
!p.Push([]uint64{}) { + t.Fatal("push empty slice should be good") + } + keepCount := 0 + for i := 0; i < 10; i++ { + if p.Push([]uint64{1, 2, 3, 4, 5}) { + keepCount++ + } + } + if keepCount == 0 { + t.Fatal("push dropped everything") + } +} + +func TestPolicyAdd(t *testing.T) { + p := newDefaultPolicy(1000, 100) + if victims, added := p.Add(1, 101); victims != nil || added { + t.Fatal("can't add an item bigger than entire cache") + } + p.Lock() + p.evict.add(1, 1) + p.admit.Increment(1) + p.admit.Increment(2) + p.admit.Increment(3) + p.Unlock() + if victims, added := p.Add(1, 1); victims != nil || !added { + t.Fatal("item should already exist") + } + if victims, added := p.Add(2, 20); victims != nil || !added { + t.Fatal("item should be added with no eviction") + } + if victims, added := p.Add(3, 90); victims == nil || !added { + t.Fatal("item should be added with eviction") + } + if victims, added := p.Add(4, 20); victims == nil || added { + t.Fatal("item should not be added") + } +} + +func TestPolicyHas(t *testing.T) { + p := newDefaultPolicy(100, 10) + p.Add(1, 1) + if !p.Has(1) { + t.Fatal("policy should have key") + } + if p.Has(2) { + t.Fatal("policy shouldn't have key") + } +} + +func TestPolicyDel(t *testing.T) { + p := newDefaultPolicy(100, 10) + p.Add(1, 1) + p.Del(1) + p.Del(2) + if p.Has(1) { + t.Fatal("del didn't delete") + } + if p.Has(2) { + t.Fatal("policy shouldn't have key") + } +} + +func TestPolicyCap(t *testing.T) { + p := newDefaultPolicy(100, 10) + p.Add(1, 1) + if p.Cap() != 9 { + t.Fatal("cap returned wrong value") + } +} + +func TestPolicyUpdate(t *testing.T) { + p := newDefaultPolicy(100, 10) + p.Add(1, 1) + p.Update(1, 2) + p.Lock() + if p.evict.keyCosts[1] != 2 { + p.Unlock() + t.Fatal("update failed") + } + p.Unlock() +} + +func TestPolicyCost(t *testing.T) { + p := newDefaultPolicy(100, 10) + p.Add(1, 2) + if p.Cost(1) != 2 { + t.Fatal("cost for existing key returned wrong value") + } + if p.Cost(2) != -1 { + t.Fatal("cost for missing key returned wrong value") + } +} + +func TestPolicyClear(t *testing.T) { + p := newDefaultPolicy(100, 10) + p.Add(1, 1) + p.Add(2, 2) + p.Add(3, 3) + p.Clear() + if p.Cap() != 10 || p.Has(1) || p.Has(2) || p.Has(3) { + t.Fatal("clear didn't clear properly") + } +} + +func TestPolicyClose(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatal("close didn't close channels") + } + }() + p := newDefaultPolicy(100, 10) + p.Add(1, 1) + p.Close() + p.itemsCh <- []uint64{1} +} + +func TestSampledLFUAdd(t *testing.T) { + e := newSampledLFU(4) + e.add(1, 1) + e.add(2, 2) + e.add(3, 1) + if e.used != 4 { + t.Fatal("used not being incremented") + } + if e.keyCosts[2] != 2 { + t.Fatal("keyCosts not being updated") + } +} + +func TestSampledLFUDel(t *testing.T) { + e := newSampledLFU(4) + e.add(1, 1) + e.add(2, 2) + e.del(2) + if e.used != 1 { + t.Fatal("del not updating used field") + } + if _, ok := e.keyCosts[2]; ok { + t.Fatal("del not deleting value from keyCosts") + } + e.del(4) +} + +func TestSampledLFUUpdate(t *testing.T) { + e := newSampledLFU(4) + e.add(1, 1) + if !e.updateIfHas(1, 2) { + t.Fatal("update should be possible") + } + if e.used != 2 { + t.Fatal("update not changing used field") + } + if e.updateIfHas(2, 2) { + t.Fatal("update shouldn't be possible") + } +} + +func TestSampledLFUClear(t *testing.T) { + e := newSampledLFU(4) + e.add(1, 1) + e.add(2, 2) + e.add(3, 1) + e.clear() + if len(e.keyCosts) != 0 || e.used != 0 { + t.Fatal("clear not deleting keyCosts or zeroing used field") + } +} + +func 
TestSampledLFURoom(t *testing.T) { + e := newSampledLFU(16) + e.add(1, 1) + e.add(2, 2) + e.add(3, 3) + if e.roomLeft(4) != 6 { + t.Fatal("roomLeft returning wrong value") + } +} + +func TestSampledLFUSample(t *testing.T) { + e := newSampledLFU(16) + e.add(4, 4) + e.add(5, 5) + sample := e.fillSample([]*policyPair{ + {1, 1}, + {2, 2}, + {3, 3}, + }) + k := sample[len(sample)-1].key + if len(sample) != 5 || k == 1 || k == 2 || k == 3 { + t.Fatal("fillSample not filling properly") + } + if len(sample) != len(e.fillSample(sample)) { + t.Fatal("fillSample mutating full sample") + } + e.del(5) + if sample = e.fillSample(sample[:len(sample)-2]); len(sample) != 4 { + t.Fatal("fillSample not returning sample properly") + } +} + +func TestTinyLFUIncrement(t *testing.T) { + a := newTinyLFU(4) + a.Increment(1) + a.Increment(1) + a.Increment(1) + if !a.door.Has(1) { + t.Fatal("doorkeeper bit not set") + } + if a.freq.Estimate(1) != 2 { + t.Fatal("incorrect counter value") + } + a.Increment(1) + if a.door.Has(1) { + t.Fatal("doorkeeper bit set after reset") + } + if a.freq.Estimate(1) != 1 { + t.Fatal("counter value not halved after reset") + } +} + +func TestTinyLFUEstimate(t *testing.T) { + a := newTinyLFU(8) + a.Increment(1) + a.Increment(1) + a.Increment(1) + if a.Estimate(1) != 3 { + t.Fatal("estimate value incorrect") + } + if a.Estimate(2) != 0 { + t.Fatal("estimate value should be 0") + } +} + +func TestTinyLFUPush(t *testing.T) { + a := newTinyLFU(16) + a.Push([]uint64{1, 2, 2, 3, 3, 3}) + if a.Estimate(1) != 1 || a.Estimate(2) != 2 || a.Estimate(3) != 3 { + t.Fatal("push didn't increment counters properly") + } + if a.incrs != 6 { + t.Fatal("incrs not being incremented") + } +} + +func TestTinyLFUClear(t *testing.T) { + a := newTinyLFU(16) + a.Push([]uint64{1, 3, 3, 3}) + a.clear() + if a.incrs != 0 || a.Estimate(3) != 0 { + t.Fatal("clear not clearing") + } } diff --git a/ring.go b/ring.go index 57d3fddf..aa76c899 100644 --- a/ring.go +++ b/ring.go @@ -18,13 +18,6 @@ package ristretto import ( "sync" - "sync/atomic" - "time" -) - -const ( - ringLossy byte = iota - ringLossless ) // ringConsumer is the user-defined object responsible for receiving and @@ -35,17 +28,16 @@ type ringConsumer interface { // ringStripe is a singular ring buffer that is not concurrent safe. type ringStripe struct { - consumer ringConsumer - data []uint64 - capacity int - busy int32 + cons ringConsumer + data []uint64 + capa int } -func newRingStripe(config *ringConfig) *ringStripe { +func newRingStripe(cons ringConsumer, capa int64) *ringStripe { return &ringStripe{ - consumer: config.Consumer, - data: make([]uint64, 0, config.Capacity), - capacity: int(config.Capacity), + cons: cons, + data: make([]uint64, 0, capa), + capa: int(capa), } } @@ -54,23 +46,16 @@ func newRingStripe(config *ringConfig) *ringStripe { func (s *ringStripe) Push(item uint64) { s.data = append(s.data, item) // if we should drain - if len(s.data) >= s.capacity { + if len(s.data) >= s.capa { // Send elements to consumer. Create a new one. - if s.consumer.Push(s.data) { - s.data = make([]uint64, 0, s.capacity) + if s.cons.Push(s.data) { + s.data = make([]uint64, 0, s.capa) } else { s.data = s.data[:0] } } } -// ringConfig is passed to newRingBuffer with parameters. -type ringConfig struct { - Consumer ringConsumer - Stripes int64 - Capacity int64 -} - // ringBuffer stores multiple buffers (stripes) and distributes Pushed items // between them to lower contention. 
 //
@@ -79,67 +64,29 @@ type ringConfig struct {
 type ringBuffer struct {
 	stripes []*ringStripe
 	pool    *sync.Pool
-	push    func(*ringBuffer, uint64)
-	rand    int
-	mask    int
 }
 
-// newRingBuffer returns a striped ring buffer. The Type can be either LOSSY or
-// LOSSLESS. LOSSY should provide better performance. The Consumer in ringConfig
-// will be called when individual stripes are full and need to drain their
-// elements.
-func newRingBuffer(ringType byte, config *ringConfig) *ringBuffer {
-	if ringType == ringLossy {
-		// LOSSY buffers use a very simple sync.Pool for concurrently reusing
-		// stripes. We do lose some stripes due to GC (unheld items in sync.Pool
-		// are cleared), but the performance gains generally outweigh the small
-		// percentage of elements lost. The performance primarily comes from
-		// low-level runtime functions used in the standard library that aren't
-		// available to us (such as runtime_procPin()).
-		return &ringBuffer{
-			pool: &sync.Pool{
-				New: func() interface{} { return newRingStripe(config) },
-			},
-			push: pushLossy,
-		}
-	}
-	// begin LOSSLESS buffer handling
-	//
-	// unlike lossy, lossless manually handles all stripes
-	stripes := make([]*ringStripe, config.Stripes)
-	for i := range stripes {
-		stripes[i] = newRingStripe(config)
-	}
+// newRingBuffer returns a striped ring buffer. The given ringConsumer will be
+// called when individual stripes are full and need to drain their elements.
+func newRingBuffer(cons ringConsumer, capa int64) *ringBuffer {
+	// LOSSY buffers use a very simple sync.Pool for concurrently reusing
+	// stripes. We do lose some stripes due to GC (unheld items in sync.Pool
+	// are cleared), but the performance gains generally outweigh the small
+	// percentage of elements lost. The performance primarily comes from
+	// low-level runtime functions used in the standard library that aren't
+	// available to us (such as runtime_procPin()).
 	return &ringBuffer{
-		stripes: stripes,
-		mask:    int(config.Stripes - 1),
-		rand:    int(time.Now().UnixNano()), // random seed for picking stripes
-		push:    pushLossless,
+		pool: &sync.Pool{
+			New: func() interface{} { return newRingStripe(cons, capa) },
+		},
 	}
 }
 
 // Push adds an element to one of the internal stripes and possibly drains if
 // the stripe becomes full.
 func (b *ringBuffer) Push(item uint64) {
-	b.push(b, item)
-}
-
-func pushLossy(b *ringBuffer, item uint64) {
 	// reuse or create a new stripe
 	stripe := b.pool.Get().(*ringStripe)
 	stripe.Push(item)
 	b.pool.Put(stripe)
 }
-
-func pushLossless(b *ringBuffer, item uint64) {
-	// try to find an available stripe
-	for i := 0; ; i = (i + 1) & b.mask {
-		if atomic.CompareAndSwapInt32(&b.stripes[i].busy, 0, 1) {
-			// try to get exclusive lock on the stripe
-			b.stripes[i].Push(item)
-			// unlock
-			atomic.StoreInt32(&b.stripes[i].busy, 0)
-			return
-		}
-	}
-}
diff --git a/ring_test.go b/ring_test.go
index 2b78efb5..7cbab671 100644
--- a/ring_test.go
+++ b/ring_test.go
@@ -1,135 +1,73 @@
-/*
- * Copyright 2019 Dgraph Labs, Inc. and Contributors
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - package ristretto import ( + "sync" "testing" - "time" -) - -const ( - // LOSSLESS number of stripes to test with - RING_STRIPES = 16 - // LOSSY/LOSSLESS size of individual stripes - RING_CAPACITY = 128 ) -type BaseConsumer struct{} - -func (c *BaseConsumer) Push(items []uint64) bool { return true } - -type TestConsumer struct { +type testConsumer struct { push func([]uint64) + save bool } -func (c *TestConsumer) Push(items []uint64) bool { c.push(items); return true } +func (c *testConsumer) Push(items []uint64) bool { + if c.save { + c.push(items) + return true + } + return false +} -func TestRingLossy(t *testing.T) { - drainCount := 0 - buffer := newRingBuffer(ringLossy, &ringConfig{ - Consumer: &TestConsumer{ - push: func(items []uint64) { - drainCount++ - }, +func TestRingDrain(t *testing.T) { + drains := 0 + r := newRingBuffer(&testConsumer{ + push: func(items []uint64) { + drains++ }, - Capacity: 4, - }) + save: true, + }, 1) for i := 0; i < 100; i++ { - buffer.Push(uint64(i)) + r.Push(uint64(i)) } - // ideally we'd be able to check for a certain "drop percentage" here, but - // that may vary per platform and testing configuration. for example: if - // drainCount == 20 then we have 100% accuracy, but it's most likely around - // 13-20 due to dropping and unfilled rings. - if drainCount == 0 { - t.Fatal("drain error") + if drains != 100 { + t.Fatal("buffers shouldn't be dropped with BufferItems == 1") } } -func TestRingLossless(t *testing.T) { - drainCount := 0 - found := make(map[uint64]struct{}) - buffer := newRingBuffer(ringLossless, &ringConfig{ - Consumer: &TestConsumer{ - push: func(items []uint64) { - drainCount++ - for _, item := range items { - found[item] = struct{}{} - } - }, +func TestRingReset(t *testing.T) { + drains := 0 + r := newRingBuffer(&testConsumer{ + push: func(items []uint64) { + drains++ }, - Capacity: 4, - Stripes: 2, - }) - buffer.Push(1) - buffer.Push(2) - buffer.Push(3) - buffer.Push(4) - buffer.Push(5) - buffer.Push(6) - buffer.Push(7) - buffer.Push(8) - time.Sleep(5 * time.Millisecond) - if drainCount != 2 || len(found) != 8 { - t.Fatal("drain error") + save: false, + }, 4) + for i := 0; i < 100; i++ { + r.Push(uint64(i)) + } + if drains != 0 { + t.Fatal("testConsumer shouldn't be draining") } } -func BenchmarkRingLossy(b *testing.B) { - buffer := newRingBuffer(ringLossy, &ringConfig{ - Consumer: &BaseConsumer{}, - Capacity: RING_CAPACITY, - }) - item := uint64(1) - b.Run("single", func(b *testing.B) { - b.SetBytes(1) - for n := 0; n < b.N; n++ { - buffer.Push(item) - } - }) - b.Run("multiple", func(b *testing.B) { - b.SetBytes(1) - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - buffer.Push(item) - } - }) - }) -} - -func BenchmarkRingLossless(b *testing.B) { - buffer := newRingBuffer(ringLossless, &ringConfig{ - Consumer: &BaseConsumer{}, - Stripes: RING_STRIPES, - Capacity: RING_CAPACITY, - }) - item := uint64(1) - b.Run("single", func(b *testing.B) { - b.SetBytes(1) - for n := 0; n < b.N; n++ { - buffer.Push(item) - } - }) - b.Run("multiple", func(b *testing.B) { - b.SetBytes(1) - b.RunParallel(func(pb *testing.PB) { - for pb.Next() { - buffer.Push(item) +func TestRingConsumer(t *testing.T) { + mu := &sync.Mutex{} + drainItems := make(map[uint64]struct{}) + r := newRingBuffer(&testConsumer{ + push: func(items []uint64) { + mu.Lock() + defer mu.Unlock() + for i := range items { + drainItems[items[i]] = struct{}{} } - }) - }) + }, 
+ save: true, + }, 4) + for i := 0; i < 100; i++ { + r.Push(uint64(i)) + } + l := len(drainItems) + if l == 0 || l > 100 { + t.Fatal("drains not being processed correctly") + } } diff --git a/sim/sim_test.go b/sim/sim_test.go index 61e64067..8d061878 100644 --- a/sim/sim_test.go +++ b/sim/sim_test.go @@ -19,7 +19,6 @@ package sim import ( "bytes" "compress/gzip" - "math" "os" "testing" ) @@ -34,16 +33,8 @@ func TestZipfian(t *testing.T) { } m[k]++ } - maxVal, minVal := uint64(0), uint64(math.MaxUint64) - for _, v := range m { - if v < minVal { - minVal = v - } else if v > maxVal { - maxVal = v - } - } - if maxVal-minVal < 10 { - t.Fatal("zipf not skewed enough") + if len(m) == 0 || len(m) == 100 { + t.Fatal("zipfian not skewed") } } diff --git a/sketch.go b/sketch.go index c3aa9024..c2b8c0f4 100644 --- a/sketch.go +++ b/sketch.go @@ -26,6 +26,8 @@ package ristretto import ( "fmt" + "math/rand" + "time" ) // cmSketch is a Count-Min sketch implementation with 4-bit counters, heavily @@ -34,7 +36,8 @@ import ( // [1]: https://github.com/dgryski/go-tinylfu/blob/master/cm4.go type cmSketch struct { rows [cmDepth]cmRow - mask uint32 + seed [cmDepth]uint64 + mask uint64 } const ( @@ -48,12 +51,11 @@ func newCmSketch(numCounters int64) *cmSketch { } // get the next power of 2 for better cache performance numCounters = next2Power(numCounters) - // sketch with FNV-64a hashing algorithm - sketch := &cmSketch{ - mask: uint32(numCounters - 1), - } - // initialize rows of counters + sketch := &cmSketch{mask: uint64(numCounters - 1)} + // initialize rows of counters and seeds + source := rand.New(rand.NewSource(time.Now().UnixNano())) for i := 0; i < cmDepth; i++ { + sketch.seed[i] = source.Uint64() sketch.rows[i] = newCmRow(numCounters) } return sketch @@ -61,21 +63,18 @@ func newCmSketch(numCounters int64) *cmSketch { // Increment increments the count(ers) for the specified key. func (s *cmSketch) Increment(hashed uint64) { - l, r := uint32(hashed), uint32(hashed>>32) for i := range s.rows { - // increment the counter on each row - s.rows[i].increment((l + uint32(i)*r) & s.mask) + s.rows[i].increment((hashed ^ s.seed[i]) & s.mask) } } // Estimate returns the value of the specified key. 
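Two details of the sketch rewrite above are worth spelling out. Each row now gets its own 64-bit seed, so the per-row index is `(hashed ^ seed[i]) & mask` rather than deriving every row from the same two 32-bit halves of the hash; TestSketchIncrement later checks that the rows really do differ. And the counters are 4-bit, packed two per byte. A small worked example of that packing, assuming the cmRow helpers shown in this file:

```go
// Two 4-bit counters share each byte: even slots use the low nibble,
// odd slots the high nibble.
row := newCmRow(16)     // 16 counters packed into 8 bytes
row.increment(5)        // odd slot: high nibble of byte 5/2 = 2
fmt.Println(row.get(5)) // 1
fmt.Println(row.get(4)) // 0 — the low nibble of the same byte is untouched
```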
func (s *cmSketch) Estimate(hashed uint64) int64 { - l, r := uint32(hashed), uint32(hashed>>32) min := byte(255) for i := range s.rows { - // find the smallest counter value from all the rows - if v := s.rows[i].get((l + uint32(i)*r) & s.mask); v < min { - min = v + val := s.rows[i].get((hashed ^ s.seed[i]) & s.mask) + if val < min { + min = val } } return int64(min) @@ -95,16 +94,6 @@ func (s *cmSketch) Clear() { } } -func (s *cmSketch) string() string { - var state string - for i := range s.rows { - state += " [ " - state += s.rows[i].string() - state += " ]\n" - } - return state -} - // cmRow is a row of bytes, with each byte holding two counters type cmRow []byte @@ -112,11 +101,11 @@ func newCmRow(numCounters int64) cmRow { return make(cmRow, numCounters/2) } -func (r cmRow) get(n uint32) byte { +func (r cmRow) get(n uint64) byte { return byte(r[n/2]>>((n&1)*4)) & 0x0f } -func (r cmRow) increment(n uint32) { +func (r cmRow) increment(n uint64) { // index of the counter i := n / 2 // shift distance (even 0, odd 4) @@ -144,12 +133,12 @@ func (r cmRow) clear() { } func (r cmRow) string() string { - var state string + s := "" for i := uint64(0); i < uint64(len(r)*2); i++ { - state += fmt.Sprintf("%02d ", (r[(i/2)]>>((i&1)*4))&0x0f) + s += fmt.Sprintf("%02d ", (r[(i/2)]>>((i&1)*4))&0x0f) } - state = state[:len(state)-1] - return state + s = s[:len(s)-1] + return s } // next2Power rounds x up to the next power of 2, if it's not already one. @@ -164,40 +153,3 @@ func next2Power(x int64) int64 { x++ return x } - -/* -// TODO -// -// Fingerprint Counting Bloom Filter (FP-CBF): lower false positive rates than -// basic CBF with little added complexity. -// -// https://doi.org/10.1016/j.ipl.2015.11.002 -type FPCBF struct { -} - -func (c *FPCBF) Push(keys []ring.Element) {} -func (c *FPCBF) Estimate(hashed int64) int64 { return 0 } - -// TODO -// -// d-left Counting Bloom Filter: based on d-left hashing which allows for much -// better space efficiency (usually saving a factor of 2 or more). -// -// https://link.springer.com/chapter/10.1007/11841036_61 -type DLCBF struct { -} - -func (c *DLCBF) Push(keys []ring.Element) {} -func (c *DLCBF) Estimate(hashed int64) int64 { return 0 } - -// TODO -// -// Bloom Clock: this might be a good route for keeping track of LRU information -// in a space efficient, probabilistic manner. -// -// https://arxiv.org/abs/1905.13064 -type BC struct{} - -func (c *BC) Push(keys []ring.Element) {} -func (c *BC) Estimate(hashed int64) int64 { return 0 } -*/ diff --git a/sketch_test.go b/sketch_test.go index b9936fde..ff3edd75 100644 --- a/sketch_test.go +++ b/sketch_test.go @@ -1,110 +1,87 @@ -/* - * Copyright 2019 Dgraph Labs, Inc. and Contributors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - package ristretto import ( "testing" - - "github.com/dgraph-io/ristretto/z" ) -// sketch is a collection of approximate frequency counters. -type sketch interface { - // Increment increments the count(ers) for the specified key. 
- Increment(uint64) - // Estimate returns the value of the specified key. - Estimate(uint64) int64 - // Reset halves all counter values. - Reset() -} - -type TestSketch interface { - sketch - string() string +func TestSketch(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatal("no panic with bad param numCounters") + } + }() + s := newCmSketch(5) + if s.mask != 7 { + t.Fatal("not rounding up to next power of 2") + } + newCmSketch(0) } -func GenerateSketchTest(create func() TestSketch) func(t *testing.T) { - return func(t *testing.T) { - s := create() - s.Increment(0) - s.Increment(0) - s.Increment(0) - s.Increment(0) - if s.Estimate(0) != 4 { - t.Fatal("increment/estimate error") - } - if s.Estimate(1) != 0 { - t.Fatal("neighbor corruption") +func TestSketchIncrement(t *testing.T) { + s := newCmSketch(16) + s.Increment(1) + s.Increment(5) + s.Increment(9) + for i := 0; i < cmDepth; i++ { + if s.rows[i].string() != s.rows[0].string() { + break } - s.Reset() - if s.Estimate(0) != 2 { - t.Fatal("reset error") - } - if s.Estimate(9) != 0 { - t.Fatal("neighbor corruption") + if i == cmDepth-1 { + t.Fatal("identical rows, bad seeding") } } } -func TestCM(t *testing.T) { - GenerateSketchTest(func() TestSketch { return newCmSketch(16) })(t) -} - -func GenerateSketchBenchmark(create func() TestSketch) func(b *testing.B) { - return func(b *testing.B) { - s := create() - b.Run("increment", func(b *testing.B) { - b.SetBytes(1) - b.ResetTimer() - for n := 0; n < b.N; n++ { - s.Increment(1) - } - }) - b.Run("estimate", func(b *testing.B) { - b.SetBytes(1) - b.ResetTimer() - for n := 0; n < b.N; n++ { - s.Estimate(1) - } - }) +func TestSketchEstimate(t *testing.T) { + s := newCmSketch(16) + s.Increment(1) + s.Increment(1) + if s.Estimate(1) != 2 { + t.Fatal("estimate should be 2") + } + if s.Estimate(0) != 0 { + t.Fatal("estimate should be 0") } } -func BenchmarkCM(b *testing.B) { - GenerateSketchBenchmark(func() TestSketch { return newCmSketch(16) })(b) +func TestSketchReset(t *testing.T) { + s := newCmSketch(16) + s.Increment(1) + s.Increment(1) + s.Increment(1) + s.Increment(1) + s.Reset() + if s.Estimate(1) != 2 { + t.Fatal("reset failed, estimate should be 2") + } } -func TestDoorkeeper(t *testing.T) { - d := z.NewBloomFilter(float64(1374), 0.01) - hash := z.MemHashString("*") - if d.Has(hash) { - t.Fatal("item exists but was never added") +func TestSketchClear(t *testing.T) { + s := newCmSketch(16) + for i := 0; i < 16; i++ { + s.Increment(uint64(i)) } - if d.AddIfNotHas(hash) != true { - t.Fatal("item didn't exist so Set() should return true") - } - if d.AddIfNotHas(hash) != false { - t.Fatal("item did exist so Set() should return false") + s.Clear() + for i := 0; i < 16; i++ { + if s.Estimate(uint64(i)) != 0 { + t.Fatal("clear failed") + } } - if !d.Has(hash) { - t.Fatal("item was added but Has() is false") +} + +func BenchmarkSketchIncrement(b *testing.B) { + s := newCmSketch(16) + b.SetBytes(1) + for n := 0; n < b.N; n++ { + s.Increment(1) } - d.Clear() - if d.Has(hash) { - t.Fatal("doorkeeper was reset but Has() returns true") +} + +func BenchmarkSketchEstimate(b *testing.B) { + s := newCmSketch(16) + s.Increment(1) + b.SetBytes(1) + for n := 0; n < b.N; n++ { + s.Estimate(1) } } diff --git a/store_test.go b/store_test.go index 8a32759a..8b038f5b 100644 --- a/store_test.go +++ b/store_test.go @@ -1,110 +1,89 @@ -/* - * Copyright 2019 Dgraph Labs, Inc. 
and Contributors
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
 package ristretto
 
-import (
-	"testing"
-)
+import "testing"
 
-func BenchmarkStoreLockedMap(b *testing.B) {
-	GenerateBench(func() store { return newLockedMap() })(b)
+func TestStoreSetGet(t *testing.T) {
+	s := newStore()
+	s.Set(1, 2)
+	if val, ok := s.Get(1); !ok || val == nil || val.(int) != 2 {
+		t.Fatal("set/get error")
+	}
+	s.Set(1, 3)
+	if val, ok := s.Get(1); !ok || val == nil || val.(int) != 3 {
+		t.Fatal("set/get overwrite error")
+	}
 }
 
-func GenerateBench(create func() store) func(*testing.B) {
-	return func(b *testing.B) {
-		b.Run("get ", func(b *testing.B) {
-			m := create()
-			m.Set(1, 1)
-			b.SetBytes(1)
-			b.ResetTimer()
-			b.RunParallel(func(pb *testing.PB) {
-				for pb.Next() {
-					m.Get(1)
-				}
-			})
-		})
+func TestStoreDel(t *testing.T) {
+	s := newStore()
+	s.Set(1, 1)
+	s.Del(1)
+	if val, ok := s.Get(1); val != nil || ok {
+		t.Fatal("del error")
 	}
 }
 
-func TestStore(t *testing.T) {
-	GenerateTest(newStore)(t)
+func TestStoreClear(t *testing.T) {
+	s := newStore()
+	for i := uint64(0); i < 1000; i++ {
+		s.Set(i, i)
+	}
+	s.Clear()
+	for i := uint64(0); i < 1000; i++ {
+		if val, ok := s.Get(i); val != nil || ok {
+			t.Fatal("clear operation failed")
+		}
+	}
 }
 
-func TestStoreLockedMap(t *testing.T) {
-	GenerateTest(func() store { return newLockedMap() })(t)
+func TestStoreUpdate(t *testing.T) {
+	s := newStore()
+	s.Set(1, 1)
+	if updated := s.Update(1, 2); !updated {
+		t.Fatal("value should have been updated")
+	}
+	if val, ok := s.Get(1); !ok || val == nil {
+		t.Fatal("value was deleted")
+	}
+	if val, ok := s.Get(1); !ok || val.(int) != 2 {
+		t.Fatal("value wasn't updated")
+	}
+	if updated := s.Update(2, 2); updated {
+		t.Fatal("value should not have been updated")
+	}
+	if val, ok := s.Get(2); val != nil || ok {
+		t.Fatal("value should not have been updated")
+	}
 }
 
-func GenerateTest(create func() store) func(*testing.T) {
-	return func(t *testing.T) {
-		t.Run("set/get", func(t *testing.T) {
-			m := create()
-			m.Set(1, 1)
-			if val, _ := m.Get(1); val != nil && val.(int) != 1 {
-				t.Fatal("set-get error")
-			}
-		})
-		t.Run("set", func(t *testing.T) {
-			m := create()
-			m.Set(1, 1)
-			// overwrite
-			m.Set(1, 2)
-			if val, _ := m.Get(1); val != nil && val.(int) != 2 {
-				t.Fatal("set update error")
-			}
-		})
-		t.Run("del", func(t *testing.T) {
-			m := create()
-			m.Set(1, 1)
-			// delete item
-			m.Del(1)
-			if val, found := m.Get(1); val != nil || found {
-				t.Fatal("del error")
-			}
-		})
-		t.Run("clear", func(t *testing.T) {
-			m := create()
-			// set a lot of values
-			for i := uint64(0); i < 1000; i++ {
-				m.Set(i, i)
-			}
-			// clear
-			m.Clear()
-			// check if any of the values exist
-			for i := uint64(0); i < 1000; i++ {
-				if _, ok := m.Get(i); ok {
-					t.Fatal("clear operation failed")
-				}
-			}
-		})
-		t.Run("update", func(t *testing.T) {
-			m := create()
-			m.Set(1, 1)
-			if updated := m.Update(1, 2); !updated {
-				t.Fatal("value should have been updated")
-			}
-			if val, _ := m.Get(1); val.(int) != 2 {
-				t.Fatal("value wasn't updated")
-			}
-			if updated := m.Update(2, 2); updated {
-				t.Fatal("value should not have been updated")
-			}
-			if val, found := m.Get(2); val != nil || found {
-				t.Fatal("value should not have been updated")
-			}
-		})
-	}
+func BenchmarkStoreGet(b *testing.B) {
+	s := newStore()
+	s.Set(1, 1)
+	b.SetBytes(1)
+	b.RunParallel(func(pb *testing.PB) {
+		for pb.Next() {
+			s.Get(1)
+		}
+	})
+}
+
+func BenchmarkStoreSet(b *testing.B) {
+	s := newStore()
+	b.SetBytes(1)
+	b.RunParallel(func(pb *testing.PB) {
+		for pb.Next() {
+			s.Set(1, 1)
+		}
+	})
+}
+
+func BenchmarkStoreUpdate(b *testing.B) {
+	s := newStore()
+	s.Set(1, 1)
+	b.SetBytes(1)
+	b.RunParallel(func(pb *testing.PB) {
+		for pb.Next() {
+			s.Update(1, 2)
+		}
+	})
 }
diff --git a/stress_test.go b/stress_test.go
new file mode 100644
index 00000000..33222cf5
--- /dev/null
+++ b/stress_test.go
@@ -0,0 +1,156 @@
+package ristretto
+
+import (
+	"container/heap"
+	"math/rand"
+	"runtime"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/dgraph-io/ristretto/sim"
+)
+
+func TestStressSetGet(t *testing.T) {
+	c, err := NewCache(&Config{
+		NumCounters: 1000,
+		MaxCost:     100,
+		BufferItems: 64,
+		Metrics:     true,
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+	for i := 0; i < 100; i++ {
+		c.Set(i, i, 1)
+	}
+	time.Sleep(time.Millisecond)
+	wg := &sync.WaitGroup{}
+	for i := 0; i < runtime.GOMAXPROCS(0); i++ {
+		wg.Add(1)
+		go func() {
+			r := rand.New(rand.NewSource(time.Now().UnixNano()))
+			for a := 0; a < 1000; a++ {
+				// read only the first 10 keys, all of which were set above
+				k := r.Int() % 10
+				if val, ok := c.Get(k); val == nil || !ok {
+					t.Errorf("expected %d but got nil", k)
+					break
+				} else if val.(int) != k {
+					t.Errorf("expected %d but got %d", k, val.(int))
+					break
+				}
+			}
+			wg.Done()
+		}()
+	}
+	wg.Wait()
+	if t.Failed() {
+		t.FailNow()
+	}
+	if r := c.stats.Ratio(); r != 1.0 {
+		t.Fatalf("hit ratio should be 1.0 but got %.2f", r)
+	}
+}
+
+func TestStressHitRatio(t *testing.T) {
+	key := sim.NewZipfian(1.0001, 1, 1000)
+	c, err := NewCache(&Config{
+		NumCounters: 1000,
+		MaxCost:     100,
+		BufferItems: 64,
+		Metrics:     true,
+	})
+	if err != nil {
+		t.Fatal(err)
+	}
+	o := NewClairvoyant(100)
+	for i := 0; i < 10000; i++ {
+		k, err := key()
+		if err != nil {
+			t.Fatal(err)
+		}
+		if _, ok := o.Get(k); !ok {
+			o.Set(k, k, 1)
+		}
+		if _, ok := c.Get(k); !ok {
+			c.Set(k, k, 1)
+		}
+	}
+	t.Logf("actual: %.2f, optimal: %.2f", c.stats.Ratio(), o.Metrics().Ratio())
+}
+
+// Clairvoyant is a mock cache that computes the optimal hit ratio for a given
+// access trace: it looks ahead and evicts the least valuable item, which is
+// the behavior a real cache can only approximate.
+type Clairvoyant struct {
+	capacity uint64
+	hits     map[uint64]uint64
+	access   []uint64
+}
+
+func NewClairvoyant(capacity uint64) *Clairvoyant {
+	return &Clairvoyant{
+		capacity: capacity,
+		hits:     make(map[uint64]uint64),
+		access:   make([]uint64, 0),
+	}
+}
+
+// Get records the cache access so that we can later replay the trace and
+// determine which items were the least valuable to keep.
+func (c *Clairvoyant) Get(key interface{}) (interface{}, bool) {
+	c.hits[key.(uint64)]++
+	c.access = append(c.access, key.(uint64))
+	return nil, false
+}
+
+// Set is a no-op: in our hit ratio benchmarks it is only called after a Get
+// miss, and the replay in Metrics needs nothing beyond the recorded accesses.
+func (c *Clairvoyant) Set(key, value interface{}, cost int64) bool { + return false +} + +func (c *Clairvoyant) Metrics() *metrics { + stat := newMetrics() + look := make(map[uint64]struct{}, c.capacity) + data := &clairvoyantHeap{} + heap.Init(data) + for _, key := range c.access { + if _, has := look[key]; has { + stat.Add(hit, 0, 1) + continue + } + if uint64(data.Len()) >= c.capacity { + victim := heap.Pop(data) + delete(look, victim.(*clairvoyantItem).key) + } + stat.Add(miss, 0, 1) + look[key] = struct{}{} + heap.Push(data, &clairvoyantItem{key, c.hits[key]}) + } + return stat +} + +type clairvoyantItem struct { + key uint64 + hits uint64 +} + +type clairvoyantHeap []*clairvoyantItem + +func (h clairvoyantHeap) Len() int { return len(h) } +func (h clairvoyantHeap) Less(i, j int) bool { return h[i].hits < h[j].hits } +func (h clairvoyantHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } + +func (h *clairvoyantHeap) Push(x interface{}) { + *h = append(*h, x.(*clairvoyantItem)) +} + +func (h *clairvoyantHeap) Pop() interface{} { + old := *h + n := len(old) + x := old[n-1] + *h = old[0 : n-1] + return x +}
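
A note on the counter layout in sketch.go above: cmRow packs two 4-bit counters into every byte, so counter n lives in byte n/2, in the low nibble when n is even and in the high nibble when n is odd. The standalone sketch below mirrors that scheme; the row type and the main driver are illustrative names rather than part of this patch, and it assumes (as the four-bit packing requires) that counters saturate at 15 instead of wrapping into their neighbor.

package main

import "fmt"

// row packs two 4-bit counters per byte, mirroring cmRow: counter n lives
// in byte n/2, in the low nibble when n is even and the high nibble when
// n is odd.
type row []byte

func (r row) get(n uint64) byte {
	return byte(r[n/2]>>((n&1)*4)) & 0x0f
}

func (r row) increment(n uint64) {
	i := n / 2       // index of the byte holding counter n
	s := (n & 1) * 4 // shift distance (even 0, odd 4)
	if v := (r[i] >> s) & 0x0f; v < 15 {
		r[i] += 1 << s // saturate at 15 rather than overflow into the neighbor
	}
}

func main() {
	r := make(row, 8) // 16 counters in 8 bytes
	for i := 0; i < 20; i++ {
		r.increment(3)
	}
	fmt.Println(r.get(3), r.get(2)) // prints: 15 0
}

Incrementing counter 3 twenty times leaves its high nibble saturated at 15, while counter 2, the low nibble of the same byte, stays 0.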
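
To make the Clairvoyant comparison concrete: Get only records the access trace, and Metrics then replays it, counting a miss whenever a key is absent and evicting through the min-heap ordered on total hits. A hand-driven example follows; the test name TestClairvoyantTrace and the six-access trace are illustrative and not part of this patch.

func TestClairvoyantTrace(t *testing.T) {
	// With capacity 2 and the trace 1 2 1 3 1 2, the replay misses on
	// 1, 2, 3, and the final 2, and hits on the two repeat accesses of
	// key 1, for an optimal ratio of 2/6.
	o := NewClairvoyant(2)
	for _, k := range []uint64{1, 2, 1, 3, 1, 2} {
		if _, ok := o.Get(k); !ok {
			o.Set(k, k, 1)
		}
	}
	if r := o.Metrics().Ratio(); r < 0.32 || r > 0.34 {
		t.Fatalf("expected an optimal ratio of ~0.33, got %.2f", r)
	}
}

This matches what Belady's optimal policy would do on the same trace: with room for two items, keeping key 1 resident throughout is the best any cache could manage.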