Improve the performance of the Memory store #127

Merged 7 commits on Dec 4, 2020
169 changes: 118 additions & 51 deletions drivers/store/memory/cache.go
@@ -37,6 +37,7 @@ func (cleaner *cleaner) Run(cache *Cache) {
// stopCleaner is a callback from GC used to stop cleaner goroutine.
func stopCleaner(wrapper *CacheWrapper) {
wrapper.cleaner.stop <- true
wrapper.cleaner = nil
}

// startCleaner will start a cleaner goroutine for given cache.
@@ -50,34 +51,61 @@ func startCleaner(cache *Cache, interval time.Duration) {
go cleaner.Run(cache)
}

- // Counter is a simple counter with an optional expiration.
+ // Counter is a simple counter with an expiration.
type Counter struct {
- Value int64
- Expiration int64
+ mutex sync.RWMutex
+ value int64
+ expiration int64
}

// Expired returns true if the counter has expired.
- func (counter Counter) Expired() bool {
- if counter.Expiration == 0 {
- return false
- }
- return time.Now().UnixNano() > counter.Expiration
- }
+ func (counter *Counter) Expired() bool {
+ counter.mutex.RLock()
+ defer counter.mutex.RUnlock()
+ return counter.expiration == 0 || time.Now().UnixNano() > counter.expiration
+ }

+ // Load returns the value and the expiration of this counter.
+ // If the counter is expired, it will use the given expiration.
+ func (counter *Counter) Load(expiration int64) (int64, int64) {
+ counter.mutex.RLock()
+ defer counter.mutex.RUnlock()
+ if counter.expiration == 0 || time.Now().UnixNano() > counter.expiration {
+ return 0, expiration
+ }
+ return counter.value, counter.expiration
+ }

+ // Increment increments given value on this counter.
+ // If the counter is expired, it will use the given expiration.
+ // It returns its current value and expiration.
+ func (counter *Counter) Increment(value int64, expiration int64) (int64, int64) {
+ counter.mutex.Lock()
+ defer counter.mutex.Unlock()
+ if counter.expiration == 0 || time.Now().UnixNano() > counter.expiration {
+ counter.value = value
+ counter.expiration = expiration
+ return counter.value, counter.expiration
+ }
+ counter.value += value
+ return counter.value, counter.expiration
+ }

// Cache contains a collection of counters.
type Cache struct {
- mutex sync.RWMutex
- counters map[string]Counter
+ counters sync.Map
cleaner *cleaner
}

// NewCache returns a new cache.
func NewCache(cleanInterval time.Duration) *CacheWrapper {

- cache := &Cache{
- counters: map[string]Counter{},
- }
+ cache := &Cache{}
wrapper := &CacheWrapper{Cache: cache}

if cleanInterval > 0 {
@@ -88,71 +116,110 @@ func NewCache(cleanInterval time.Duration) *CacheWrapper {
return wrapper
}

// LoadOrStore returns the existing counter for the key if present.
// Otherwise, it stores and returns the given counter.
// The loaded result is true if the counter was loaded, false if stored.
func (cache *Cache) LoadOrStore(key string, counter *Counter) (*Counter, bool) {
val, loaded := cache.counters.LoadOrStore(key, counter)
if val == nil {
return counter, false
}

actual := val.(*Counter)
return actual, loaded
}

// Load returns the counter stored in the map for a key, or nil if no counter is present.
// The ok result indicates whether counter was found in the map.
func (cache *Cache) Load(key string) (*Counter, bool) {
val, ok := cache.counters.Load(key)
if val == nil || !ok {
return nil, false
}
actual := val.(*Counter)
return actual, true
}

// Store sets the counter for a key.
func (cache *Cache) Store(key string, counter *Counter) {
cache.counters.Store(key, counter)
}

// Delete deletes the value for a key.
func (cache *Cache) Delete(key string) {
cache.counters.Delete(key)
}

// Range calls handler sequentially for each key and value present in the cache.
// If handler returns false, range stops the iteration.
func (cache *Cache) Range(handler func(key string, counter *Counter)) {
cache.counters.Range(func(k interface{}, v interface{}) bool {
if v == nil {
return true
}

key := k.(string)
counter := v.(*Counter)

handler(key, counter)

return true
})
}

// Increment increments given value on key.
// If key is undefined or expired, it will create it.
func (cache *Cache) Increment(key string, value int64, duration time.Duration) (int64, time.Time) {
- cache.mutex.Lock()
-
- counter, ok := cache.counters[key]
- if !ok || counter.Expired() {
- expiration := time.Now().Add(duration).UnixNano()
- counter = Counter{
- Value: value,
- Expiration: expiration,
- }
-
- cache.counters[key] = counter
- cache.mutex.Unlock()
- return value, time.Unix(0, expiration)
- }
-
- value = counter.Value + value
- counter.Value = value
- expiration := counter.Expiration
-
- cache.counters[key] = counter
- cache.mutex.Unlock()
- return value, time.Unix(0, expiration)
+ expiration := time.Now().Add(duration).UnixNano()
+
+ // If the counter is in the cache, try to load it first.
+ counter, loaded := cache.Load(key)
+ if loaded {
+ value, expiration = counter.Increment(value, expiration)
+ return value, time.Unix(0, expiration)
+ }
+
+ // If it is not in the cache, try to atomically create it.
+ // We do that in two steps to reduce memory allocation.
+ counter, loaded = cache.LoadOrStore(key, &Counter{
+ mutex: sync.RWMutex{},
+ value: value,
+ expiration: expiration,
+ })
+ if loaded {
+ value, expiration = counter.Increment(value, expiration)
+ return value, time.Unix(0, expiration)
+ }
+
+ // Otherwise, it has been created, return the given value.
+ return value, time.Unix(0, expiration)
}

// Get returns key's value and expiration.
func (cache *Cache) Get(key string, duration time.Duration) (int64, time.Time) {
- cache.mutex.RLock()
-
- counter, ok := cache.counters[key]
- if !ok || counter.Expired() {
- expiration := time.Now().Add(duration).UnixNano()
- cache.mutex.RUnlock()
- return 0, time.Unix(0, expiration)
- }
-
- value := counter.Value
- expiration := counter.Expiration
- cache.mutex.RUnlock()
+ expiration := time.Now().Add(duration).UnixNano()
+
+ counter, ok := cache.Load(key)
+ if !ok {
+ return 0, time.Unix(0, expiration)
+ }
+
+ value, expiration := counter.Load(expiration)
return value, time.Unix(0, expiration)
}

// Clean deletes any expired keys.
func (cache *Cache) Clean() {
- now := time.Now().UnixNano()
-
- cache.mutex.Lock()
- for key, counter := range cache.counters {
- if now > counter.Expiration {
- delete(cache.counters, key)
- }
- }
- cache.mutex.Unlock()
+ cache.Range(func(key string, counter *Counter) {
+ if counter.Expired() {
+ cache.Delete(key)
+ }
+ })
}

// Reset changes the key's value and resets the expiration.
func (cache *Cache) Reset(key string, duration time.Duration) (int64, time.Time) {
- cache.mutex.Lock()
- delete(cache.counters, key)
- cache.mutex.Unlock()
+ cache.Delete(key)

expiration := time.Now().Add(duration).UnixNano()
return 0, time.Unix(0, expiration)
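The core of the Cache.Increment rewrite above is a common sync.Map idiom: read with Load first, and fall back to LoadOrStore (which has to allocate a candidate value) only when the key is missing, so the hot "key already exists" path allocates nothing. A minimal, self-contained sketch of that idiom, independent of this repository (the package, names, and the plain atomic counter are illustrative, not the PR's actual types):

package main

import (
    "fmt"
    "sync"
    "sync/atomic"
)

// counters maps string keys to *int64 counters.
var counters sync.Map

// increment adds delta to the counter stored under key, creating it on first use.
// The Load-before-LoadOrStore order mirrors the PR: the common "key already
// exists" path performs no allocation at all.
func increment(key string, delta int64) int64 {
    // Fast path: the counter already exists.
    if val, ok := counters.Load(key); ok {
        return atomic.AddInt64(val.(*int64), delta)
    }

    // Slow path: allocate a candidate counter and try to publish it atomically.
    // If another goroutine won the race, LoadOrStore returns its counter and we
    // add to that one instead.
    candidate := delta
    val, loaded := counters.LoadOrStore(key, &candidate)
    if loaded {
        return atomic.AddInt64(val.(*int64), delta)
    }
    return candidate
}

func main() {
    fmt.Println(increment("foo", 1)) // 1
    fmt.Println(increment("foo", 2)) // 3
}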
2 changes: 1 addition & 1 deletion drivers/store/memory/cache_test.go
@@ -38,7 +38,7 @@ func TestCacheIncrementSequential(t *testing.T) {
func TestCacheIncrementConcurrent(t *testing.T) {
is := require.New(t)

- goroutines := 300
+ goroutines := 200
ops := 500

expected := int64(0)
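For reference, the shape of a concurrent increment test like the one tweaked above, written as a minimal sketch against the NewCache, Increment, and Get signatures visible in this diff (the test name and constants are illustrative, not the repository's actual test):

package memory_test

import (
    "sync"
    "testing"
    "time"

    "github.com/ulule/limiter/v3/drivers/store/memory"
)

func TestIncrementFromManyGoroutines(t *testing.T) {
    cache := memory.NewCache(30 * time.Second)

    goroutines := 200
    ops := 500

    var wg sync.WaitGroup
    wg.Add(goroutines)
    for i := 0; i < goroutines; i++ {
        go func() {
            defer wg.Done()
            for j := 0; j < ops; j++ {
                // Every goroutine hammers the same key.
                cache.Increment("foo", 1, time.Minute)
            }
        }()
    }
    wg.Wait()

    // All increments land on one counter, so the value must be goroutines*ops.
    value, _ := cache.Get("foo", time.Minute)
    if value != int64(goroutines*ops) {
        t.Fatalf("expected %d, got %d", goroutines*ops, value)
    }
}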
29 changes: 16 additions & 13 deletions drivers/store/memory/store.go
@@ -2,11 +2,11 @@ package memory

import (
"context"
"fmt"
"time"

"github.com/ulule/limiter/v3"
"github.com/ulule/limiter/v3/drivers/store/common"
"github.com/ulule/limiter/v3/internal/bytebuffer"
)

// Store is the in-memory store.
@@ -35,33 +35,36 @@ func NewStoreWithOptions(options limiter.StoreOptions) limiter.Store {

// Get returns the limit for given identifier.
func (store *Store) Get(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) {
- key = fmt.Sprintf("%s:%s", store.Prefix, key)
- now := time.Now()
+ buffer := bytebuffer.New()
+ defer buffer.Close()
+ buffer.Concat(store.Prefix, ":", key)

- count, expiration := store.cache.Increment(key, 1, rate.Period)
+ count, expiration := store.cache.Increment(buffer.String(), 1, rate.Period)

- lctx := common.GetContextFromState(now, rate, expiration, count)
+ lctx := common.GetContextFromState(time.Now(), rate, expiration, count)
return lctx, nil
}

// Peek returns the limit for given identifier, without modification on current values.
func (store *Store) Peek(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) {
- key = fmt.Sprintf("%s:%s", store.Prefix, key)
- now := time.Now()
+ buffer := bytebuffer.New()
+ defer buffer.Close()
+ buffer.Concat(store.Prefix, ":", key)

- count, expiration := store.cache.Get(key, rate.Period)
+ count, expiration := store.cache.Get(buffer.String(), rate.Period)

- lctx := common.GetContextFromState(now, rate, expiration, count)
+ lctx := common.GetContextFromState(time.Now(), rate, expiration, count)
return lctx, nil
}

// Reset returns the limit for given identifier.
func (store *Store) Reset(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) {
- key = fmt.Sprintf("%s:%s", store.Prefix, key)
- now := time.Now()
+ buffer := bytebuffer.New()
+ defer buffer.Close()
+ buffer.Concat(store.Prefix, ":", key)

- count, expiration := store.cache.Reset(key, rate.Period)
+ count, expiration := store.cache.Reset(buffer.String(), rate.Period)

- lctx := common.GetContextFromState(now, rate, expiration, count)
+ lctx := common.GetContextFromState(time.Now(), rate, expiration, count)
return lctx, nil
}
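The fmt.Sprintf("%s:%s", ...) calls above were replaced with the repository's internal bytebuffer package. Its implementation is not shown in this diff; the sketch below is only an assumption of how such a pooled buffer typically works (a sync.Pool of bytes.Buffer), included to illustrate why Concat plus Close avoids the per-request allocations of Sprintf:

package bytebuffersketch

import (
    "bytes"
    "sync"
)

// pool recycles buffers between requests so key formatting does not allocate
// a fresh buffer (and intermediate strings) on every call.
var pool = sync.Pool{
    New: func() interface{} { return new(bytes.Buffer) },
}

// Buffer is a tiny wrapper around a pooled bytes.Buffer.
type Buffer struct {
    buf *bytes.Buffer
}

// New takes a buffer from the pool.
func New() *Buffer {
    return &Buffer{buf: pool.Get().(*bytes.Buffer)}
}

// Concat appends every value directly to the buffer, without the format
// parsing and temporary strings that fmt.Sprintf("%s:%s", ...) would produce.
func (b *Buffer) Concat(values ...string) {
    for _, value := range values {
        b.buf.WriteString(value)
    }
}

// String returns the accumulated key (bytes.Buffer copies it, so the result
// stays valid after Close).
func (b *Buffer) String() string {
    return b.buf.String()
}

// Close resets the buffer and puts it back into the pool for reuse.
func (b *Buffer) Close() {
    b.buf.Reset()
    pool.Put(b.buf)
}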
8 changes: 4 additions & 4 deletions drivers/store/memory/store_test.go
@@ -23,16 +23,16 @@ func TestMemoryStoreConcurrentAccess(t *testing.T) {
}))
}

- func BenchmarkRedisStoreSequentialAccess(b *testing.B) {
+ func BenchmarkMemoryStoreSequentialAccess(b *testing.B) {
tests.BenchmarkStoreSequentialAccess(b, memory.NewStoreWithOptions(limiter.StoreOptions{
Prefix: "limiter:memory:sequential-benchmark",
- CleanUpInterval: 1 * time.Second,
+ CleanUpInterval: 1 * time.Hour,
}))
}

- func BenchmarkRedisStoreConcurrentAccess(b *testing.B) {
+ func BenchmarkMemoryStoreConcurrentAccess(b *testing.B) {
tests.BenchmarkStoreConcurrentAccess(b, memory.NewStoreWithOptions(limiter.StoreOptions{
Prefix: "limiter:memory:concurrent-benchmark",
- CleanUpInterval: 1 * time.Second,
+ CleanUpInterval: 1 * time.Hour,
}))
}
14 changes: 4 additions & 10 deletions drivers/store/tests/tests.go
@@ -145,38 +145,32 @@ func TestStoreConcurrentAccess(t *testing.T, store limiter.Store) {

// BenchmarkStoreSequentialAccess executes a benchmark against a store without parallel setting.
func BenchmarkStoreSequentialAccess(b *testing.B, store limiter.Store) {
- is := require.New(b)
ctx := context.Background()

- limiter := limiter.New(store, limiter.Rate{
+ instance := limiter.New(store, limiter.Rate{
Limit: 100000,
Period: 10 * time.Second,
})

b.ResetTimer()
for i := 0; i < b.N; i++ {
lctx, err := limiter.Get(ctx, "foo")
is.NoError(err)
is.NotZero(lctx)
_, _ = instance.Get(ctx, "foo")
}
}

// BenchmarkStoreConcurrentAccess executes a benchmark against a store with parallel setting.
func BenchmarkStoreConcurrentAccess(b *testing.B, store limiter.Store) {
- is := require.New(b)
ctx := context.Background()

- limiter := limiter.New(store, limiter.Rate{
+ instance := limiter.New(store, limiter.Rate{
Limit: 100000,
Period: 10 * time.Second,
})

b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
lctx, err := limiter.Get(ctx, "foo")
is.NoError(err)
is.NotZero(lctx)
_, _ = instance.Get(ctx, "foo")
}
})
}
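To tie the pieces together, here is a minimal usage sketch of the memory store behind this PR, built only from calls that appear in this diff (memory.NewStoreWithOptions, limiter.New, limiter.Rate, and Get); the prefix, rate values, and output handling are illustrative:

package main

import (
    "context"
    "fmt"
    "time"

    "github.com/ulule/limiter/v3"
    "github.com/ulule/limiter/v3/drivers/store/memory"
)

func main() {
    // In-memory store with a periodic cleaner for expired counters.
    store := memory.NewStoreWithOptions(limiter.StoreOptions{
        Prefix:          "limiter:memory:example",
        CleanUpInterval: 1 * time.Hour,
    })

    // 100000 requests per 10 seconds, as in the benchmarks above.
    instance := limiter.New(store, limiter.Rate{
        Limit:  100000,
        Period: 10 * time.Second,
    })

    // Each Get consumes one unit for the key and returns the current state.
    lctx, err := instance.Get(context.Background(), "foo")
    if err != nil {
        panic(err)
    }
    fmt.Printf("%+v\n", lctx)
}

The affected benchmarks can be run with the standard tooling, for example go test -bench . ./drivers/store/..., to compare results before and after this change.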