Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More generic API to work with x/time/rate #582

Merged
merged 2 commits into from
Jul 26, 2017
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 82 additions & 7 deletions ratelimit/token_bucket.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,100 @@ var ErrLimited = errors.New("rate limit exceeded")
// limiter based on a token-bucket algorithm. Requests that would exceed the
// maximum request rate are simply rejected with an error.
func NewTokenBucketLimiter(tb *ratelimit.Bucket) endpoint.Middleware {
return NewErroringLimiter(NewAllower(tb))
}

// NewTokenBucketThrottler returns an endpoint.Middleware that acts as a
// request throttler based on a token-bucket algorithm. Requests that would
// exceed the maximum request rate are delayed via the parameterized sleep
// function. By default you may pass time.Sleep.
func NewTokenBucketThrottler(tb *ratelimit.Bucket, sleep func(time.Duration)) endpoint.Middleware {
// return NewDelayingLimiter(NewWaiter(tb))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason this didn't work? Would be cool if we could express all the old functions in terms of the new ones!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can do it, but then the sleep parameter gets ignored. I'll check in that change, so you can evaluate if you like it.

return func(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
if tb.TakeAvailable(1) == 0 {
sleep(tb.Take(1))
return next(ctx, request)
}
}
}

// Allower dictates whether or not a request is acceptable to run.
// The Limiter from "golang.org/x/time/rate" already implements this interface,
// one is able to use that in NewErroringLimiter without any modifications.
type Allower interface {
Allow() bool
}

// NewErroringLimiter returns an endpoint.Middleware that acts as a rate
// limiter. Requests that would exceed the
// maximum request rate are simply rejected with an error.
func NewErroringLimiter(limit Allower) endpoint.Middleware {
return func(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
if !limit.Allow() {
return nil, ErrLimited
}
return next(ctx, request)
}
}
}

// NewTokenBucketThrottler returns an endpoint.Middleware that acts as a
// request throttler based on a token-bucket algorithm. Requests that would
// exceed the maximum request rate are delayed via the parameterized sleep
// function. By default you may pass time.Sleep.
func NewTokenBucketThrottler(tb *ratelimit.Bucket, sleep func(time.Duration)) endpoint.Middleware {
// Waiter dictates how long a request must be delayed.
// The Limiter from "golang.org/x/time/rate" already implements this interface,
// one is able to use that in NewDelayingLimiter without any modifications.
type Waiter interface {
Wait(ctx context.Context) error
}

// NewDelayingLimiter returns an endpoint.Middleware that acts as a
// request throttler. Requests that would
// exceed the maximum request rate are delayed via the Waiter function
func NewDelayingLimiter(limit Waiter) endpoint.Middleware {
return func(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (interface{}, error) {
sleep(tb.Take(1))
if err := limit.Wait(ctx); err != nil {
return nil, err
}
return next(ctx, request)
}
}
}

// AllowerFunc is an adapter that lets a function operate as if
// it implements Allower
type AllowerFunc func() bool

// Allow makes the adapter implement Allower
func (f AllowerFunc) Allow() bool {
return f()
}

// NewAllower turns an existing ratelimit.Bucket into an API-compatible form
func NewAllower(tb *ratelimit.Bucket) Allower {
return AllowerFunc(func() bool {
return (tb.TakeAvailable(1) != 0)
})
}

// WaiterFunc is an adapter that lets a function operate as if
// it implements Waiter
type WaiterFunc func(ctx context.Context) error

// Wait makes the adapter implement Waiter
func (f WaiterFunc) Wait(ctx context.Context) error {
return f(ctx)
}

// NewWaiter turns an existing ratelimit.Bucket into an API-compatible form
func NewWaiter(tb *ratelimit.Bucket) Waiter {
return WaiterFunc(func(ctx context.Context) error {
dur := tb.Take(1)
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(dur):
// happy path
}
return nil
})
}
28 changes: 28 additions & 0 deletions ratelimit/token_bucket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@ package ratelimit_test
import (
"context"
"math"
"strings"
"testing"
"time"

jujuratelimit "github.com/juju/ratelimit"

"github.com/go-kit/kit/endpoint"
"github.com/go-kit/kit/ratelimit"
"golang.org/x/time/rate"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you move this import decl up with juju, above?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will do

)

func TestTokenBucketLimiter(t *testing.T) {
Expand Down Expand Up @@ -53,3 +55,29 @@ func testLimiter(t *testing.T, e endpoint.Endpoint, rate int) {
t.Errorf("rate=%d: want %v, have %v", rate, ratelimit.ErrLimited, err)
}
}

func TestXRateErroring(t *testing.T) {
limit := rate.NewLimiter(rate.Every(time.Minute), 1)
e := func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
testLimiter(t, ratelimit.NewErroringLimiter(limit)(e), 1)
}

func TestXRateDelaying(t *testing.T) {
limit := rate.NewLimiter(rate.Every(time.Minute), 1)
e := func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }
e = ratelimit.NewDelayingLimiter(limit)(e)

_, err := e(context.Background(), struct{}{})
if err != nil {
t.Errorf("unexpected: %v\n", err)
}

dur := 500 * time.Millisecond
ctx, cxl := context.WithTimeout(context.Background(), dur)
defer cxl()

_, err = e(ctx, struct{}{})
if !strings.Contains(err.Error(), "exceed context deadline") {
t.Errorf("expected timeout: %v\n", err)
}
}