diff --git a/ratelimit_test.go b/ratelimit_test.go index fa87489..84ae05f 100644 --- a/ratelimit_test.go +++ b/ratelimit_test.go @@ -14,13 +14,13 @@ import ( ) type runner interface { + // createLimiter builds a limiter with given options. + createLimiter(int, ...ratelimit.Option) ratelimit.Limiter // startTaking tries to Take() on passed in limiters in a loop/goroutine. startTaking(rls ...ratelimit.Limiter) // assertCountAt asserts the limiters have Taken() a number of times at the given time. // It's a thin wrapper around afterFunc to reduce boilerplate code. assertCountAt(d time.Duration, count int) - // getClock returns the test clock. - getClock() ratelimit.Clock // afterFunc executes a func at a given time. // not using clock.AfterFunc because andres-erbsen/clock misses a nap there. afterFunc(d time.Duration, fn func()) @@ -51,6 +51,12 @@ func runTest(t *testing.T, fn func(runner)) { r.clock.Add(r.maxDuration) } +// createLimiter builds a limiter with given options. +func (r *runnerImpl) createLimiter(rate int, opts ...ratelimit.Option) ratelimit.Limiter { + opts = append(opts, ratelimit.WithClock(r.clock)) + return ratelimit.New(rate, opts...) +} + // startTaking tries to Take() on passed in limiters in a loop/goroutine. func (r *runnerImpl) startTaking(rls ...ratelimit.Limiter) { r.goWait(func() { @@ -77,11 +83,6 @@ func (r *runnerImpl) assertCountAt(d time.Duration, count int) { }) } -// getClock return the test clock. -func (r *runnerImpl) getClock() ratelimit.Clock { - return r.clock -} - // afterFunc executes a func at a given time. func (r *runnerImpl) afterFunc(d time.Duration, fn func()) { if d > r.maxDuration { @@ -144,7 +145,7 @@ func TestUnlimited(t *testing.T) { func TestRateLimiter(t *testing.T) { runTest(t, func(r runner) { - rl := ratelimit.New(100, ratelimit.WithClock(r.getClock()), ratelimit.WithoutSlack) + rl := r.createLimiter(100, ratelimit.WithoutSlack) // Create copious counts concurrently. r.startTaking(rl) @@ -160,8 +161,8 @@ func TestRateLimiter(t *testing.T) { func TestDelayedRateLimiter(t *testing.T) { runTest(t, func(r runner) { - slow := ratelimit.New(10, ratelimit.WithClock(r.getClock())) - fast := ratelimit.New(100, ratelimit.WithClock(r.getClock())) + slow := r.createLimiter(10, ratelimit.WithoutSlack) + fast := r.createLimiter(100, ratelimit.WithoutSlack) // Run a slow startTaking r.startTaking(slow, fast)