From 51e97ad872ca1bf72ca66936e6bffaca3d051fc6 Mon Sep 17 00:00:00 2001 From: Martin Martinez Rivera Date: Tue, 28 Jan 2020 13:54:00 -0800 Subject: [PATCH] Sets with TTL (#122) Add a new method SetWithTTL that supports adding key-value pairs to ristretto that expire after the given duration. --- cache.go | 77 +++++++++++++++++++-------- cache_test.go | 97 +++++++++++++++++++++++++++++++++- store.go | 125 +++++++++++++++++++++++++++++++------------- store_test.go | 87 +++++++++++++++++++++++++------ ttl.go | 141 ++++++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 453 insertions(+), 74 deletions(-) create mode 100644 ttl.go diff --git a/cache.go b/cache.go index 07012514..603647ab 100644 --- a/cache.go +++ b/cache.go @@ -24,6 +24,7 @@ import ( "errors" "fmt" "sync/atomic" + "time" "github.com/dgraph-io/ristretto/z" ) @@ -33,6 +34,8 @@ const ( setBufSize = 32 * 1024 ) +type onEvictFunc func(uint64, uint64, interface{}, int64) + // Cache is a thread-safe implementation of a hashmap with a TinyLFU admission // policy and a Sampled LFU eviction policy. You can use the same Cache instance // from as many goroutines as you want. @@ -48,7 +51,7 @@ type Cache struct { // contention. setBuf chan *item // onEvict is called for item evictions. - onEvict func(uint64, uint64, interface{}, int64) + onEvict onEvictFunc // KeyToHash function is used to customize the key hashing algorithm. // Each key will be hashed using the provided function. If keyToHash value // is not set, the default keyToHash function is used. @@ -57,6 +60,8 @@ type Cache struct { stop chan struct{} // cost calculates cost from a value. cost func(value interface{}) int64 + // cleanupTicker is used to periodically check for entries whose TTL has passed. + cleanupTicker *time.Ticker // Metrics contains a running log of important statistics like hits, misses, // and dropped items. 
Metrics *Metrics @@ -115,11 +120,12 @@ const ( // item is passed to setBuf so items can eventually be added to the cache. type item struct { - flag itemFlag - key uint64 - conflict uint64 - value interface{} - cost int64 + flag itemFlag + key uint64 + conflict uint64 + value interface{} + cost int64 + expiration time.Time } // NewCache returns a new Cache instance and any configuration errors, if any. @@ -134,14 +140,15 @@ func NewCache(config *Config) (*Cache, error) { } policy := newPolicy(config.NumCounters, config.MaxCost) cache := &Cache{ - store: newStore(), - policy: policy, - getBuf: newRingBuffer(policy, config.BufferItems), - setBuf: make(chan *item, setBufSize), - onEvict: config.OnEvict, - keyToHash: config.KeyToHash, - stop: make(chan struct{}), - cost: config.Cost, + store: newStore(), + policy: policy, + getBuf: newRingBuffer(policy, config.BufferItems), + setBuf: make(chan *item, setBufSize), + onEvict: config.OnEvict, + keyToHash: config.KeyToHash, + stop: make(chan struct{}), + cost: config.Cost, + cleanupTicker: time.NewTicker(time.Duration(bucketDurationSecs) * time.Second / 2), } if cache.keyToHash == nil { cache.keyToHash = z.KeyToHash @@ -184,20 +191,42 @@ func (c *Cache) Get(key interface{}) (interface{}, bool) { // the cost parameter to 0 and Coster will be ran when needed in order to find // the items true cost. func (c *Cache) Set(key, value interface{}, cost int64) bool { + return c.SetWithTTL(key, value, cost, 0*time.Second) +} + +// SetWithTTL works like Set but adds a key-value pair to the cache that will expire +// after the specified TTL (time to live) has passed. A zero value means the value never +// expires, which is identical to calling Set. A negative value is a no-op and the value +// is discarded. +func (c *Cache) SetWithTTL(key, value interface{}, cost int64, ttl time.Duration) bool { if c == nil || key == nil { return false } + + var expiration time.Time + switch { + case ttl == 0: + // No expiration. 
+ break + case ttl < 0: + // Treat this as a no-op. + return false + default: + expiration = time.Now().Add(ttl) + } + keyHash, conflictHash := c.keyToHash(key) i := &item{ - flag: itemNew, - key: keyHash, - conflict: conflictHash, - value: value, - cost: cost, + flag: itemNew, + key: keyHash, + conflict: conflictHash, + value: value, + cost: cost, + expiration: expiration, } - // Attempt to immediately update hashmap value and set flag to update so the - // cost is eventually updated. - if c.store.Update(keyHash, conflictHash, i.value) { + // Attempt to immediately update the hashmap value and expiration (so the item is not + // prematurely removed from the map) and set flag to update so the cost is eventually updated. + if c.store.Update(i) { i.flag = itemUpdate } // Attempt to send item to policy. @@ -277,7 +306,7 @@ func (c *Cache) processItems() { case itemNew: victims, added := c.policy.Add(i.key, i.cost) if added { - c.store.Set(i.key, i.conflict, i.value) + c.store.Set(i) c.Metrics.add(keyAdd, i.key, 1) } for _, victim := range victims { @@ -294,6 +323,8 @@ func (c *Cache) processItems() { c.policy.Del(i.key) // Deals with metrics updates. c.store.Del(i.key, i.conflict) } + case <-c.cleanupTicker.C: + c.store.Cleanup(c.policy, c.onEvict) case <-c.stop: return } diff --git a/cache_test.go b/cache_test.go index d5422e4a..908219de 100644 --- a/cache_test.go +++ b/cache_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/dgraph-io/ristretto/z" + "github.com/stretchr/testify/require" ) var wait = time.Millisecond * 10 @@ -273,7 +274,12 @@ func TestCacheGet(t *testing.T) { panic(err) } key, conflict := z.KeyToHash(1) - c.store.Set(key, conflict, 1) + i := item{ + key: key, + conflict: conflict, + value: 1, + } + c.store.Set(&i) if val, ok := c.Get(1); val == nil || !ok { t.Fatal("get should be successful") } @@ -338,6 +344,73 @@ func TestCacheSet(t *testing.T) { } } +// retrySet calls SetWithTTL until the item is accepted by the cache. 
+func retrySet(t *testing.T, c *Cache, key, value int, cost int64, ttl time.Duration) { + for { + if set := c.SetWithTTL(key, value, cost, ttl); !set { + time.Sleep(wait) + continue + } + + time.Sleep(wait) + val, ok := c.Get(key) + require.True(t, ok) + require.NotNil(t, val) + require.Equal(t, value, val.(int)) + return + } +} + +func TestCacheSetWithTTL(t *testing.T) { + m := &sync.Mutex{} + evicted := make(map[uint64]struct{}) + c, err := NewCache(&Config{ + NumCounters: 100, + MaxCost: 10, + BufferItems: 64, + Metrics: true, + OnEvict: func(key, conflict uint64, value interface{}, cost int64) { + m.Lock() + defer m.Unlock() + evicted[key] = struct{}{} + }, + }) + require.NoError(t, err) + + retrySet(t, c, 1, 1, 1, time.Second) + + // Sleep to make sure the item has expired after execution resumes. + time.Sleep(2 * time.Second) + val, ok := c.Get(1) + require.False(t, ok) + require.Nil(t, val) + + // Sleep to ensure that the bucket where the item was stored has been cleared + // from the expiration map. + time.Sleep(5 * time.Second) + m.Lock() + require.Equal(t, 1, len(evicted)) + _, ok = evicted[1] + require.True(t, ok) + m.Unlock() + + // Verify that expiration times are overwritten. + retrySet(t, c, 2, 1, 1, time.Second) + retrySet(t, c, 2, 2, 1, 100*time.Second) + time.Sleep(3 * time.Second) + val, ok = c.Get(2) + require.True(t, ok) + require.Equal(t, 2, val.(int)) + + // Verify that entries with no expiration are overwritten. 
+ retrySet(t, c, 3, 1, 1, 0) + retrySet(t, c, 3, 2, 1, time.Second) + time.Sleep(3 * time.Second) + val, ok = c.Get(3) + require.False(t, ok) + require.Nil(t, val) +} + func TestCacheDel(t *testing.T) { c, err := NewCache(&Config{ NumCounters: 100, @@ -361,6 +434,23 @@ func TestCacheDel(t *testing.T) { c.Del(1) } +func TestCacheDelWithTTL(t *testing.T) { + c, err := NewCache(&Config{ + NumCounters: 100, + MaxCost: 10, + BufferItems: 64, + }) + require.NoError(t, err) + retrySet(t, c, 3, 1, 1, 10*time.Second) + time.Sleep(1 * time.Second) + // Delete the item + c.Del(3) + // Ensure the key is deleted. + val, ok := c.Get(3) + require.False(t, ok) + require.Nil(t, val) +} + func TestCacheClear(t *testing.T) { c, err := NewCache(&Config{ NumCounters: 100, @@ -524,3 +614,8 @@ func TestCacheMetricsClear(t *testing.T) { c.Metrics = nil c.Metrics.Clear() } + +func init() { + // Set bucketDurationSecs to 1 to avoid waiting too much during the tests. + bucketDurationSecs = 1 +} diff --git a/store.go b/store.go index 3f3fbce4..f75c23ea 100644 --- a/store.go +++ b/store.go @@ -18,12 +18,14 @@ package ristretto import ( "sync" + "time" ) type storeItem struct { - key uint64 - conflict uint64 - value interface{} + key uint64 + conflict uint64 + value interface{} + expiration time.Time } // store is the interface fulfilled by all hash map implementations in this @@ -35,14 +37,19 @@ type storeItem struct { type store interface { // Get returns the value associated with the key parameter. Get(uint64, uint64) (interface{}, bool) + // Expiration returns the expiration time for this key. + Expiration(uint64) time.Time // Set adds the key-value pair to the Map or updates the value if it's - // already present. - Set(uint64, uint64, interface{}) + // already present. The key-value pair is passed as a pointer to an + // item object. + Set(*item) // Del deletes the key-value pair from the Map. 
Del(uint64, uint64) (uint64, interface{}) // Update attempts to update the key with a new value and returns true if // successful. - Update(uint64, uint64, interface{}) bool + Update(*item) bool + // Cleanup removes items that have an expired TTL. + Cleanup(policy policy, onEvict onEvictFunc) // Clear clears all contents of the store. Clear() } @@ -55,33 +62,48 @@ func newStore() store { const numShards uint64 = 256 type shardedMap struct { - shards []*lockedMap + shards []*lockedMap + expiryMap *expirationMap } func newShardedMap() *shardedMap { sm := &shardedMap{ - shards: make([]*lockedMap, int(numShards)), + shards: make([]*lockedMap, int(numShards)), + expiryMap: newExpirationMap(), } for i := range sm.shards { - sm.shards[i] = newLockedMap() + sm.shards[i] = newLockedMap(sm.expiryMap) } return sm } func (sm *shardedMap) Get(key, conflict uint64) (interface{}, bool) { - return sm.shards[key%numShards].Get(key, conflict) + return sm.shards[key%numShards].get(key, conflict) } -func (sm *shardedMap) Set(key, conflict uint64, value interface{}) { - sm.shards[key%numShards].Set(key, conflict, value) +func (sm *shardedMap) Expiration(key uint64) time.Time { + return sm.shards[key%numShards].Expiration(key) +} + +func (sm *shardedMap) Set(i *item) { + if i == nil { + // If item is nil make this Set a no-op. 
+ return + } + + sm.shards[i.key%numShards].Set(i) } func (sm *shardedMap) Del(key, conflict uint64) (uint64, interface{}) { return sm.shards[key%numShards].Del(key, conflict) } -func (sm *shardedMap) Update(key, conflict uint64, value interface{}) bool { - return sm.shards[key%numShards].Update(key, conflict, value) +func (sm *shardedMap) Update(newItem *item) bool { + return sm.shards[newItem.key%numShards].Update(newItem) +} + +func (sm *shardedMap) Cleanup(policy policy, onEvict onEvictFunc) { + sm.expiryMap.cleanup(sm, policy, onEvict) } func (sm *shardedMap) Clear() { @@ -93,15 +115,17 @@ func (sm *shardedMap) Clear() { type lockedMap struct { sync.RWMutex data map[uint64]storeItem + em *expirationMap } -func newLockedMap() *lockedMap { +func newLockedMap(em *expirationMap) *lockedMap { return &lockedMap{ data: make(map[uint64]storeItem), + em: em, } } -func (m *lockedMap) Get(key, conflict uint64) (interface{}, bool) { +func (m *lockedMap) get(key, conflict uint64) (interface{}, bool) { m.RLock() item, ok := m.data[key] m.RUnlock() @@ -111,29 +135,51 @@ func (m *lockedMap) Get(key, conflict uint64) (interface{}, bool) { if conflict != 0 && (conflict != item.conflict) { return nil, false } + + // Handle expired items. + if !item.expiration.IsZero() && time.Now().After(item.expiration) { + return nil, false + } return item.value, true } -func (m *lockedMap) Set(key, conflict uint64, value interface{}) { +func (m *lockedMap) Expiration(key uint64) time.Time { + m.RLock() + defer m.RUnlock() + return m.data[key].expiration +} + +func (m *lockedMap) Set(i *item) { + if i == nil { + // If the item is nil make this Set a no-op. 
+ return + } + m.Lock() - item, ok := m.data[key] - if !ok { - m.data[key] = storeItem{ - key: key, - conflict: conflict, - value: value, + item, ok := m.data[i.key] + + if ok { + m.em.update(i.key, i.conflict, item.expiration, i.expiration) + } else { + m.em.add(i.key, i.conflict, i.expiration) + m.data[i.key] = storeItem{ + key: i.key, + conflict: i.conflict, + value: i.value, + expiration: i.expiration, } m.Unlock() return } - if conflict != 0 && (conflict != item.conflict) { + if i.conflict != 0 && (i.conflict != item.conflict) { m.Unlock() return } - m.data[key] = storeItem{ - key: key, - conflict: conflict, - value: value, + m.data[i.key] = storeItem{ + key: i.key, + conflict: i.conflict, + value: i.value, + expiration: i.expiration, } m.Unlock() } @@ -149,27 +195,36 @@ func (m *lockedMap) Del(key, conflict uint64) (uint64, interface{}) { m.Unlock() return 0, nil } + + if !item.expiration.IsZero() { + m.em.del(key, item.expiration) + } + delete(m.data, key) m.Unlock() return item.conflict, item.value } -func (m *lockedMap) Update(key, conflict uint64, value interface{}) bool { +func (m *lockedMap) Update(newItem *item) bool { m.Lock() - item, ok := m.data[key] + item, ok := m.data[newItem.key] if !ok { m.Unlock() return false } - if conflict != 0 && (conflict != item.conflict) { + if newItem.conflict != 0 && (newItem.conflict != item.conflict) { m.Unlock() return false } - m.data[key] = storeItem{ - key: key, - conflict: conflict, - value: value, + + m.em.update(newItem.key, newItem.conflict, item.expiration, newItem.expiration) + m.data[newItem.key] = storeItem{ + key: newItem.key, + conflict: newItem.conflict, + value: newItem.value, + expiration: newItem.expiration, } + m.Unlock() return true } diff --git a/store_test.go b/store_test.go index 03a21344..53ee984d 100644 --- a/store_test.go +++ b/store_test.go @@ -9,16 +9,27 @@ import ( func TestStoreSetGet(t *testing.T) { s := newStore() key, conflict := z.KeyToHash(1) - s.Set(key, conflict, 2) + i := item{ 
+ key: key, + conflict: conflict, + value: 2, + } + s.Set(&i) if val, ok := s.Get(key, conflict); (val == nil || !ok) || val.(int) != 2 { t.Fatal("set/get error") } - s.Set(key, conflict, 3) + i.value = 3 + s.Set(&i) if val, ok := s.Get(key, conflict); (val == nil || !ok) || val.(int) != 3 { t.Fatal("set/get overwrite error") } key, conflict = z.KeyToHash(2) - s.Set(key, conflict, 2) + i = item{ + key: key, + conflict: conflict, + value: 2, + } + s.Set(&i) if val, ok := s.Get(key, conflict); !ok || val.(int) != 2 { t.Fatal("set/get nil key error") } @@ -27,7 +38,12 @@ func TestStoreSetGet(t *testing.T) { func TestStoreDel(t *testing.T) { s := newStore() key, conflict := z.KeyToHash(1) - s.Set(key, conflict, 1) + i := item{ + key: key, + conflict: conflict, + value: 1, + } + s.Set(&i) s.Del(key, conflict) if val, ok := s.Get(key, conflict); val != nil || ok { t.Fatal("del error") @@ -39,7 +55,12 @@ func TestStoreClear(t *testing.T) { s := newStore() for i := uint64(0); i < 1000; i++ { key, conflict := z.KeyToHash(i) - s.Set(key, conflict, i) + it := item{ + key: key, + conflict: conflict, + value: i, + } + s.Set(&it) } s.Clear() for i := uint64(0); i < 1000; i++ { @@ -53,8 +74,14 @@ func TestStoreClear(t *testing.T) { func TestStoreUpdate(t *testing.T) { s := newStore() key, conflict := z.KeyToHash(1) - s.Set(key, conflict, 1) - if updated := s.Update(key, conflict, 2); !updated { + i := item{ + key: key, + conflict: conflict, + value: 1, + } + s.Set(&i) + i.value = 2 + if updated := s.Update(&i); !updated { t.Fatal("value should have been updated") } if val, ok := s.Get(key, conflict); val == nil || !ok { @@ -63,14 +90,20 @@ func TestStoreUpdate(t *testing.T) { if val, ok := s.Get(key, conflict); val.(int) != 2 || !ok { t.Fatal("value wasn't updated") } - if !s.Update(key, conflict, 3) { + i.value = 3 + if !s.Update(&i) { t.Fatal("value should have been updated") } if val, ok := s.Get(key, conflict); val.(int) != 3 || !ok { t.Fatal("value wasn't updated") } key, 
conflict = z.KeyToHash(2) - if updated := s.Update(key, conflict, 2); updated { + i = item{ + key: key, + conflict: conflict, + value: 2, + } + if updated := s.Update(&i); updated { t.Fatal("value should not have been updated") } if val, ok := s.Get(key, conflict); val != nil || ok { @@ -90,11 +123,16 @@ func TestStoreCollision(t *testing.T) { if val, ok := s.Get(1, 1); val != nil || ok { t.Fatal("collision should return nil") } - s.Set(1, 1, 2) + i := item{ + key: 1, + conflict: 1, + value: 2, + } + s.Set(&i) if val, ok := s.Get(1, 0); !ok || val == nil || val.(int) == 2 { t.Fatal("collision should prevent Set update") } - if s.Update(1, 1, 2) { + if s.Update(&i) { t.Fatal("collision should prevent Update") } if val, ok := s.Get(1, 0); !ok || val == nil || val.(int) == 2 { @@ -109,7 +147,12 @@ func TestStoreCollision(t *testing.T) { func BenchmarkStoreGet(b *testing.B) { s := newStore() key, conflict := z.KeyToHash(1) - s.Set(key, conflict, 1) + i := item{ + key: key, + conflict: conflict, + value: 1, + } + s.Set(&i) b.SetBytes(1) b.RunParallel(func(pb *testing.PB) { for pb.Next() { @@ -124,7 +167,12 @@ func BenchmarkStoreSet(b *testing.B) { b.SetBytes(1) b.RunParallel(func(pb *testing.PB) { for pb.Next() { - s.Set(key, conflict, 1) + i := item{ + key: key, + conflict: conflict, + value: 1, + } + s.Set(&i) } }) } @@ -132,11 +180,20 @@ func BenchmarkStoreSet(b *testing.B) { func BenchmarkStoreUpdate(b *testing.B) { s := newStore() key, conflict := z.KeyToHash(1) - s.Set(key, conflict, 1) + i := item{ + key: key, + conflict: conflict, + value: 1, + } + s.Set(&i) b.SetBytes(1) b.RunParallel(func(pb *testing.PB) { for pb.Next() { - s.Update(key, conflict, 2) + s.Update(&item{ + key: key, + conflict: conflict, + value: 2, + }) } }) } diff --git a/ttl.go b/ttl.go new file mode 100644 index 00000000..02864f7e --- /dev/null +++ b/ttl.go @@ -0,0 +1,141 @@ +/* + * Copyright 2020 Dgraph Labs, Inc. 
and Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package ristretto + +import ( + "sync" + "time" +) + +var ( + // TODO: find the optimal value or make it configurable. + bucketDurationSecs = int64(5) +) + +func storageBucket(t time.Time) int64 { + return (t.Unix() / bucketDurationSecs) + 1 +} + +func cleanupBucket(t time.Time) int64 { + // The bucket to cleanup is always behind the storage bucket by one so that + // no elements in that bucket (which might not have expired yet) are deleted. + return storageBucket(t) - 1 +} + +// bucket type is a map of key to conflict. +type bucket map[uint64]uint64 + +// expirationMap is a map of bucket number to the corresponding bucket. +type expirationMap struct { + sync.RWMutex + buckets map[int64]bucket +} + +func newExpirationMap() *expirationMap { + return &expirationMap{ + buckets: make(map[int64]bucket), + } +} + +func (m *expirationMap) add(key, conflict uint64, expiration time.Time) { + if m == nil { + return + } + + // Items that don't expire don't need to be in the expiration map. 
+ if expiration.IsZero() { + return + } + + bucketNum := storageBucket(expiration) + m.Lock() + defer m.Unlock() + + b, ok := m.buckets[bucketNum] + if !ok { + b = make(bucket) + m.buckets[bucketNum] = b + } + b[key] = conflict +} + +func (m *expirationMap) update(key, conflict uint64, oldExpTime, newExpTime time.Time) { + if m == nil { + return + } + + m.Lock() + defer m.Unlock() + + oldBucketNum := storageBucket(oldExpTime) + oldBucket, ok := m.buckets[oldBucketNum] + if ok { + delete(oldBucket, key) + } + + newBucketNum := storageBucket(newExpTime) + newBucket, ok := m.buckets[newBucketNum] + if !ok { + newBucket = make(bucket) + m.buckets[newBucketNum] = newBucket + } + newBucket[key] = conflict +} + +func (m *expirationMap) del(key uint64, expiration time.Time) { + if m == nil { + return + } + + bucketNum := storageBucket(expiration) + m.Lock() + defer m.Unlock() + _, ok := m.buckets[bucketNum] + if !ok { + return + } + delete(m.buckets[bucketNum], key) +} + +// cleanup removes all the items in the bucket that was just completed. It deletes +// those items from the store, and calls the onEvict function on those items. +// This function is meant to be called periodically. +func (m *expirationMap) cleanup(store store, policy policy, onEvict onEvictFunc) { + if m == nil { + return + } + + m.Lock() + now := time.Now() + bucketNum := cleanupBucket(now) + keys := m.buckets[bucketNum] + delete(m.buckets, bucketNum) + m.Unlock() + + for key, conflict := range keys { + // Sanity check. Verify that the store agrees that this key is expired. + if store.Expiration(key).After(now) { + continue + } + + _, value := store.Del(key, conflict) + cost := policy.Cost(key) + if onEvict != nil { + onEvict(key, conflict, value, cost) + } + } +}