diff --git a/drivers/middleware/gin/middleware.go b/drivers/middleware/gin/middleware.go index 620b375..3a7b872 100644 --- a/drivers/middleware/gin/middleware.go +++ b/drivers/middleware/gin/middleware.go @@ -14,6 +14,7 @@ type Middleware struct { OnError ErrorHandler OnLimitReached LimitReachedHandler KeyGetter KeyGetter + ExcludedKey ExcludedKey } // NewMiddleware return a new instance of a gin middleware. @@ -37,6 +38,11 @@ func NewMiddleware(limiter *limiter.Limiter, options ...Option) gin.HandlerFunc // Handle gin request. func (middleware *Middleware) Handle(c *gin.Context) { key := middleware.KeyGetter(c) + if middleware.ExcludedKey != nil && middleware.ExcludedKey(key) { + c.Next() + return + } + context, err := middleware.Limiter.Get(c, key) if err != nil { middleware.OnError(c, err) diff --git a/drivers/middleware/gin/middleware_test.go b/drivers/middleware/gin/middleware_test.go index 029478d..ed187a4 100644 --- a/drivers/middleware/gin/middleware_test.go +++ b/drivers/middleware/gin/middleware_test.go @@ -125,4 +125,34 @@ func TestHTTPMiddleware(t *testing.T) { is.Equal(http.StatusOK, resp.Code, strconv.FormatInt(i, 10)) } + // + // Test ExcludedKey + // + store = memory.NewStore() + is.NotZero(store) + counter = int64(0) + middleware = gin.NewMiddleware(limiter.New(store, rate), + gin.WithKeyGetter(func(c *libgin.Context) string { + v := atomic.AddInt64(&counter, 1) + return strconv.FormatInt(v%2, 10) + }), + gin.WithExcludedKey(gin.DefaultExcludedKey([]string{"1"})), + ) + is.NotZero(middleware) + + router = libgin.New() + router.Use(middleware) + router.GET("/", func(c *libgin.Context) { + c.String(http.StatusOK, "hello") + }) + success = 20 + for i := int64(1); i < clients; i++ { + resp := httptest.NewRecorder() + router.ServeHTTP(resp, request) + if i <= success || i%2 == 1 { + is.Equal(http.StatusOK, resp.Code, strconv.FormatInt(i, 10)) + } else { + is.Equal(resp.Code, http.StatusTooManyRequests) + } + } } diff --git a/drivers/middleware/gin/options.go b/drivers/middleware/gin/options.go index 86066f1..f803045 100644 --- a/drivers/middleware/gin/options.go +++ b/drivers/middleware/gin/options.go @@ -62,3 +62,26 @@ func WithKeyGetter(KeyGetter KeyGetter) Option { func DefaultKeyGetter(c *gin.Context) string { return c.ClientIP() } + +// ExcludedKey is function type used to check whether the key should be excluded or not +type ExcludedKey func(key string) bool + + +// DefaultExcludedKey is the default function returns ExcludedKey +func DefaultExcludedKey(keys []string) ExcludedKey { + m := make(map[string]struct{}, len(keys)) + for _, key := range keys { + m[key] = struct{}{} + } + return func(key string) bool { + _, ok := m[key] + return ok + } +} + +// WithExcludedKey will configure the Middleware to use the given ExcludedKey. +func WithExcludedKey(fn ExcludedKey) Option { + return option(func(middleware *Middleware) { + middleware.ExcludedKey = fn + }) +}