Skip to content

Commit

Permalink
do not break Status api
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasJenicek committed Feb 27, 2024
1 parent 74791f6 commit f8c004e
Showing 1 changed file with 22 additions and 24 deletions.
46 changes: 22 additions & 24 deletions limiter.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package httprate

import (
"context"
"fmt"
"math"
"net/http"
Expand Down Expand Up @@ -68,28 +67,8 @@ func (l *rateLimiter) Counter() LimitCounter {
return l.limitCounter
}

func (l *rateLimiter) Status(ctx context.Context, key string) (bool, float64, error) {
t := time.Now().UTC()
currentWindow := t.Truncate(l.windowLength)
previousWindow := currentWindow.Add(-l.windowLength)

currCount, prevCount, err := l.limitCounter.Get(key, currentWindow, previousWindow)
if err != nil {
return false, 0, err
}

diff := t.Sub(currentWindow)
rate := float64(prevCount)*(float64(l.windowLength)-float64(diff))/float64(l.windowLength) + float64(currCount)

limit := l.requestLimit
if val := getRequestLimit(ctx); val > 0 {
limit = val
}

if rate > float64(limit) {
return false, rate, nil
}
return true, rate, nil
func (l *rateLimiter) Status(key string) (bool, float64, error) {
return l.calculateRate(key, l.requestLimit)
}

func (l *rateLimiter) Handler(next http.Handler) http.Handler {
Expand All @@ -112,7 +91,7 @@ func (l *rateLimiter) Handler(next http.Handler) http.Handler {
w.Header().Set("X-RateLimit-Reset", fmt.Sprintf("%d", currentWindow.Add(l.windowLength).Unix()))

l.mu.Lock()
_, rate, err := l.Status(ctx, key)
_, rate, err := l.calculateRate(key, limit)
if err != nil {
l.mu.Unlock()
http.Error(w, err.Error(), http.StatusPreconditionRequired)
Expand Down Expand Up @@ -143,6 +122,25 @@ func (l *rateLimiter) Handler(next http.Handler) http.Handler {
})
}

func (l *rateLimiter) calculateRate(key string, requestLimit int) (bool, float64, error) {
t := time.Now().UTC()
currentWindow := t.Truncate(l.windowLength)
previousWindow := currentWindow.Add(-l.windowLength)

currCount, prevCount, err := l.limitCounter.Get(key, currentWindow, previousWindow)
if err != nil {
return false, 0, err
}

diff := t.Sub(currentWindow)
rate := float64(prevCount)*(float64(l.windowLength)-float64(diff))/float64(l.windowLength) + float64(currCount)
if rate > float64(requestLimit) {
return false, rate, nil
}

return true, rate, nil
}

type localCounter struct {
counters map[uint64]*count
windowLength time.Duration
Expand Down

0 comments on commit f8c004e

Please sign in to comment.