diff --git a/limiter_atomic.go b/limiter_atomic.go index 241b368..afb75f8 100644 --- a/limiter_atomic.go +++ b/limiter_atomic.go @@ -63,15 +63,21 @@ func newAtomicBased(rate int, opts ...Option) *atomicLimiter { // Take blocks to ensure that the time spent between multiple // Take calls is on average time.Second/rate. func (t *atomicLimiter) Take() time.Time { - newState := state{} - taken := false + var ( + newState state + taken bool + interval time.Duration + ) for !taken { now := t.clock.Now() previousStatePointer := atomic.LoadPointer(&t.state) oldState := (*state)(previousStatePointer) - newState = state{} + newState = state{ + last: now, + sleepFor: oldState.sleepFor, + } newState.last = now // If this is our first request, then we allow it. @@ -93,9 +99,10 @@ func (t *atomicLimiter) Take() time.Time { } if newState.sleepFor > 0 { newState.last = newState.last.Add(newState.sleepFor) + interval, newState.sleepFor = newState.sleepFor, 0 } taken = atomic.CompareAndSwapPointer(&t.state, previousStatePointer, unsafe.Pointer(&newState)) } - t.clock.Sleep(newState.sleepFor) + t.clock.Sleep(interval) return newState.last } diff --git a/ratelimit_test.go b/ratelimit_test.go index 1a2e078..431244c 100644 --- a/ratelimit_test.go +++ b/ratelimit_test.go @@ -164,18 +164,20 @@ func TestDelayedRateLimiter(t *testing.T) { slow := r.createLimiter(10, ratelimit.WithoutSlack) fast := r.createLimiter(100, ratelimit.WithoutSlack) - // Run a slow startTaking r.startTaking(slow, fast) - // Accumulate slack for 10 seconds, r.afterFunc(20*time.Second, func() { - // Then start working. r.startTaking(fast) r.startTaking(fast) r.startTaking(fast) r.startTaking(fast) }) + // Slow limiter allows 10 per second, so 100. + r.assertCountAt(10*time.Second, 100) + // Another 10 seconds, so we're at 200. + r.assertCountAt(20*time.Second, 200) + // Now the fast limiter goes at 100/sec, so another 1000. r.assertCountAt(30*time.Second, 1200) }) } @@ -192,3 +194,24 @@ func TestPer(t *testing.T) { r.assertCountAt(2*time.Minute, 15) }) } + +func TestSlack(t *testing.T) { + runTest(t, func(r runner) { + slow := r.createLimiter(10, ratelimit.WithoutSlack) + // Defaults to 10 slack. + fast := r.createLimiter(100) + + r.startTaking(slow, fast) + + r.afterFunc(1*time.Second, func() { + r.startTaking(fast) + r.startTaking(fast) + }) + + // limiter with 10hz dominates here - we're just at 10. + r.assertCountAt(1*time.Second, 10) + // limiter with 100hz dominates, so we're at 110, + // but we get extra 10 from accumulated slack + r.assertCountAt(2*time.Second, 120) + }) +}