diff --git a/bucket.go b/bucket.go index df1dc98..e9e9a77 100644 --- a/bucket.go +++ b/bucket.go @@ -98,8 +98,10 @@ func (b *bucket[T]) deletePrefix(prefix string, deletables chan *Item[T]) int { }, deletables) } +// we expect the caller to have acquired a write lock func (b *bucket[T]) clear() { - b.Lock() + for _, item := range b.lookup { + item.promotions = -2 + } b.lookup = make(map[string]*Item[T]) - b.Unlock() } diff --git a/cache.go b/cache.go index 074ad56..0124caa 100644 --- a/cache.go +++ b/cache.go @@ -206,6 +206,24 @@ func (c *Cache[T]) bucket(key string) *bucket[T] { return c.buckets[h.Sum32()&c.bucketMask] } +func (c *Cache[T]) halted(fn func()) { + c.halt() + defer c.unhalt() + fn() +} + +func (c *Cache[T]) halt() { + for _, bucket := range c.buckets { + bucket.Lock() + } +} + +func (c *Cache[T]) unhalt() { + for _, bucket := range c.buckets { + bucket.Unlock() + } +} + func (c *Cache[T]) worker() { dropped := 0 cc := c.control @@ -236,11 +254,22 @@ func (c *Cache[T]) worker() { } msg.done <- struct{}{} case controlClear: - for _, bucket := range c.buckets { - bucket.clear() - } - c.size = 0 - c.list = NewList[*Item[T]]() + c.halted(func() { + promotables := c.promotables + for len(promotables) > 0 { + <-promotables + } + deletables := c.deletables + for len(deletables) > 0 { + <-deletables + } + + for _, bucket := range c.buckets { + bucket.clear() + } + c.size = 0 + c.list = NewList[*Item[T]]() + }) msg.done <- struct{}{} case controlGetSize: msg.res <- c.size diff --git a/cache_test.go b/cache_test.go index 050a1d8..bb89015 100644 --- a/cache_test.go +++ b/cache_test.go @@ -4,6 +4,7 @@ import ( "math/rand" "sort" "strconv" + "sync" "sync/atomic" "testing" "time" @@ -361,6 +362,40 @@ func Test_ConcurrentStop(t *testing.T) { } } +func Test_ConcurrentClearAndSet(t *testing.T) { + for i := 0; i < 100; i++ { + var stop atomic.Bool + var wg sync.WaitGroup + + cache := New(Configure[string]()) + r := func() { + for !stop.Load() { + cache.Set("a", "a", time.Minute) + } + wg.Done() + } + go r() + wg.Add(1) + cache.Clear() + stop.Store(true) + wg.Wait() + time.Sleep(time.Millisecond) + cache.SyncUpdates() + + known := make(map[string]struct{}) + for node := cache.list.Head; node != nil; node = node.Next { + known[node.Value.key] = struct{}{} + } + + for _, bucket := range cache.buckets { + for key := range bucket.lookup { + _, exists := known[key] + assert.True(t, exists) + } + } + } +} + type SizedItem struct { id int s int64 diff --git a/layeredbucket.go b/layeredbucket.go index 0d8aa8e..7a1b169 100644 --- a/layeredbucket.go +++ b/layeredbucket.go @@ -111,9 +111,8 @@ func (b *layeredBucket[T]) forEachFunc(primary string, matches func(key string, } } +// we expect the caller to have acquired a write lock func (b *layeredBucket[T]) clear() { - b.Lock() - defer b.Unlock() for _, bucket := range b.buckets { bucket.clear() } diff --git a/layeredcache.go b/layeredcache.go index ada096d..3a4e0f4 100644 --- a/layeredcache.go +++ b/layeredcache.go @@ -196,6 +196,24 @@ func (c *LayeredCache[T]) bucket(key string) *layeredBucket[T] { return c.buckets[h.Sum32()&c.bucketMask] } +func (c *LayeredCache[T]) halted(fn func()) { + c.halt() + defer c.unhalt() + fn() +} + +func (c *LayeredCache[T]) halt() { + for _, bucket := range c.buckets { + bucket.Lock() + } +} + +func (c *LayeredCache[T]) unhalt() { + for _, bucket := range c.buckets { + bucket.Unlock() + } +} + func (c *LayeredCache[T]) promote(item *Item[T]) { c.promotables <- item } @@ -230,11 +248,22 @@ func (c *LayeredCache[T]) worker() { } msg.done <- struct{}{} case controlClear: - for _, bucket := range c.buckets { - bucket.clear() + promotables := c.promotables + for len(promotables) > 0 { + <-promotables } - c.size = 0 - c.list = NewList[*Item[T]]() + deletables := c.deletables + for len(deletables) > 0 { + <-deletables + } + + c.halted(func() { + for _, bucket := range c.buckets { + bucket.clear() + } + c.size = 0 + c.list = NewList[*Item[T]]() + }) msg.done <- struct{}{} case controlGetSize: msg.res <- c.size