Skip to content

Commit

Permalink
Merge pull request #1 from micanzhang/feature/exclude_key
Browse files Browse the repository at this point in the history
drivers/middleware/gin: add ExcludedKey Option
  • Loading branch information
micanzhang authored Mar 27, 2020
2 parents 1366201 + 50850ec commit a7b8d8b
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 0 deletions.
6 changes: 6 additions & 0 deletions drivers/middleware/gin/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ type Middleware struct {
OnError ErrorHandler
OnLimitReached LimitReachedHandler
KeyGetter KeyGetter
ExcludedKey ExcludedKey
}

// NewMiddleware return a new instance of a gin middleware.
Expand All @@ -37,6 +38,11 @@ func NewMiddleware(limiter *limiter.Limiter, options ...Option) gin.HandlerFunc
// Handle gin request.
func (middleware *Middleware) Handle(c *gin.Context) {
key := middleware.KeyGetter(c)
if middleware.ExcludedKey != nil && middleware.ExcludedKey(key) {
c.Next()
return
}

context, err := middleware.Limiter.Get(c, key)
if err != nil {
middleware.OnError(c, err)
Expand Down
30 changes: 30 additions & 0 deletions drivers/middleware/gin/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,4 +125,34 @@ func TestHTTPMiddleware(t *testing.T) {
is.Equal(http.StatusOK, resp.Code, strconv.FormatInt(i, 10))
}

//
// Test ExcludedKey
//
store = memory.NewStore()
is.NotZero(store)
counter = int64(0)
middleware = gin.NewMiddleware(limiter.New(store, rate),
gin.WithKeyGetter(func(c *libgin.Context) string {
v := atomic.AddInt64(&counter, 1)
return strconv.FormatInt(v%2, 10)
}),
gin.WithExcludedKey(gin.DefaultExcludedKey([]string{"1"})),
)
is.NotZero(middleware)

router = libgin.New()
router.Use(middleware)
router.GET("/", func(c *libgin.Context) {
c.String(http.StatusOK, "hello")
})
success = 20
for i := int64(1); i < clients; i++ {
resp := httptest.NewRecorder()
router.ServeHTTP(resp, request)
if i <= success || i%2 == 1 {
is.Equal(http.StatusOK, resp.Code, strconv.FormatInt(i, 10))
} else {
is.Equal(resp.Code, http.StatusTooManyRequests)
}
}
}
23 changes: 23 additions & 0 deletions drivers/middleware/gin/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,26 @@ func WithKeyGetter(KeyGetter KeyGetter) Option {
func DefaultKeyGetter(c *gin.Context) string {
return c.ClientIP()
}

// ExcludedKey is function type used to check whether the key should be excluded or not
type ExcludedKey func(key string) bool


// DefaultExcludedKey is the default function returns ExcludedKey
func DefaultExcludedKey(keys []string) ExcludedKey {
m := make(map[string]struct{}, len(keys))
for _, key := range keys {
m[key] = struct{}{}
}
return func(key string) bool {
_, ok := m[key]
return ok
}
}

// WithExcludedKey will configure the Middleware to use the given ExcludedKey.
func WithExcludedKey(fn ExcludedKey) Option {
return option(func(middleware *Middleware) {
middleware.ExcludedKey = fn
})
}

0 comments on commit a7b8d8b

Please sign in to comment.