Skip to content

Commit

Permalink
fix: add ip based limiter (supabase#1622)
Browse files Browse the repository at this point in the history
## What kind of change does this PR introduce?
* Adds ip-based rate limiting on all endpoints that send OTPs either
through email or phone with the config `GOTRUE_RATE_LIMIT_OTP`
* IP-based rate limiting should always come before the shared limiter,
so as to prevent the quota of the shared limiter from being consumed too
quickly by the same ip-address
  • Loading branch information
kangmingtay authored and LashaJini committed Nov 15, 2024
1 parent 12586f6 commit ca09d69
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 7 deletions.
55 changes: 48 additions & 7 deletions internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,14 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
r.With(sharedLimiter).With(api.requireAdminCredentials).Post("/invite", api.Invite)
r.With(sharedLimiter).With(api.verifyCaptcha).Route("/signup", func(r *router) {
// rate limit per hour
limiter := tollbooth.NewLimiter(api.config.RateLimitAnonymousUsers/(60*60), &limiter.ExpirableOptions{
limitAnonymousSignIns := tollbooth.NewLimiter(api.config.RateLimitAnonymousUsers/(60*60), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(int(api.config.RateLimitAnonymousUsers)).SetMethods([]string{"POST"})

limitSignups := tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30)

r.Post("/", func(w http.ResponseWriter, r *http.Request) error {
params := &SignupParams{}
if err := retrieveRequestParams(r, params); err != nil {
Expand All @@ -148,19 +153,50 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
if !api.config.External.AnonymousUsers.Enabled {
return unprocessableEntityError(ErrorCodeAnonymousProviderDisabled, "Anonymous sign-ins are disabled")
}
if _, err := api.limitHandler(limiter)(w, r); err != nil {
if _, err := api.limitHandler(limitAnonymousSignIns)(w, r); err != nil {
return err
}
return api.SignupAnonymously(w, r)
}

// apply ip-based rate limiting on otps
if _, err := api.limitHandler(limitSignups)(w, r); err != nil {
return err
}
// apply shared rate limiting on email / phone
if _, err := sharedLimiter(w, r); err != nil {
return err
}
return api.Signup(w, r)
})
})
r.With(sharedLimiter).With(api.verifyCaptcha).With(api.requireEmailProvider).Post("/recover", api.Recover)
r.With(sharedLimiter).With(api.verifyCaptcha).Post("/resend", api.Resend)
r.With(sharedLimiter).With(api.verifyCaptcha).Post("/magiclink", api.MagicLink)
r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes
tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).With(sharedLimiter).With(api.verifyCaptcha).With(api.requireEmailProvider).Post("/recover", api.Recover)

r.With(sharedLimiter).With(api.verifyCaptcha).Post("/otp", api.Otp)
r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes
tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).With(sharedLimiter).With(api.verifyCaptcha).Post("/resend", api.Resend)

r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes
tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).With(sharedLimiter).With(api.verifyCaptcha).Post("/magiclink", api.MagicLink)

r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes
tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).With(sharedLimiter).With(api.verifyCaptcha).Post("/otp", api.Otp)

r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes.
Expand All @@ -187,7 +223,12 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne

r.With(api.requireAuthentication).Route("/user", func(r *router) {
r.Get("/", api.UserGet)
r.With(sharedLimiter).Put("/", api.UserUpdate)
r.With(api.limitHandler(
// Allow requests at the specified rate per 5 minutes
tollbooth.NewLimiter(api.config.RateLimitOtp/(60*5), &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
}).SetBurst(30),
)).With(sharedLimiter).Put("/", api.UserUpdate)

r.Route("/identities", func(r *router) {
r.Use(api.requireManualLinkingEnabled)
Expand Down
134 changes: 134 additions & 0 deletions internal/api/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"testing"
"time"

"github.com/didip/tollbooth/v5"
"github.com/didip/tollbooth/v5/limiter"
jwt "github.com/golang-jwt/jwt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -356,3 +358,135 @@ func TestTimeoutResponseWriter(t *testing.T) {

require.Equal(t, w1.Result(), w2.Result())
}

func (ts *MiddlewareTestSuite) TestLimitHandler() {
ts.Config.RateLimitHeader = "X-Rate-Limit"
lmt := tollbooth.NewLimiter(5, &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
})

okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
b, _ := json.Marshal(map[string]interface{}{"message": "ok"})
w.Write([]byte(b))
})

for i := 0; i < 5; i++ {
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
req.Header.Add(ts.Config.RateLimitHeader, "0.0.0.0")
w := httptest.NewRecorder()
ts.API.limitHandler(lmt).handler(okHandler).ServeHTTP(w, req)
require.Equal(ts.T(), http.StatusOK, w.Code)

var data map[string]interface{}
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data))
require.Equal(ts.T(), "ok", data["message"])
}

// 6th request should fail and return a rate limit exceeded error
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
req.Header.Add(ts.Config.RateLimitHeader, "0.0.0.0")
w := httptest.NewRecorder()
ts.API.limitHandler(lmt).handler(okHandler).ServeHTTP(w, req)
require.Equal(ts.T(), http.StatusTooManyRequests, w.Code)
}

func (ts *MiddlewareTestSuite) TestLimitHandlerWithSharedLimiter() {
// setup config for shared limiter and ip-based limiter to work
ts.Config.RateLimitHeader = "X-Rate-Limit"
ts.Config.External.Email.Enabled = true
ts.Config.External.Phone.Enabled = true
ts.Config.Mailer.Autoconfirm = false
ts.Config.Sms.Autoconfirm = false

ipBasedLimiter := func(max float64) *limiter.Limiter {
return tollbooth.NewLimiter(max, &limiter.ExpirableOptions{
DefaultExpirationTTL: time.Hour,
})
}

okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})

cases := []struct {
desc string
sharedLimiterConfig *conf.GlobalConfiguration
ipBasedLimiterConfig float64
body map[string]interface{}
expectedErrorCode string
}{
{
desc: "Exceed ip-based rate limit before shared limiter",
sharedLimiterConfig: &conf.GlobalConfiguration{
RateLimitEmailSent: 10,
RateLimitSmsSent: 10,
},
ipBasedLimiterConfig: 1,
body: map[string]interface{}{
"email": "foo@example.com",
},
expectedErrorCode: ErrorCodeOverRequestRateLimit,
},
{
desc: "Exceed email shared limiter",
sharedLimiterConfig: &conf.GlobalConfiguration{
RateLimitEmailSent: 1,
RateLimitSmsSent: 1,
},
ipBasedLimiterConfig: 10,
body: map[string]interface{}{
"email": "foo@example.com",
},
expectedErrorCode: ErrorCodeOverEmailSendRateLimit,
},
{
desc: "Exceed sms shared limiter",
sharedLimiterConfig: &conf.GlobalConfiguration{
RateLimitEmailSent: 1,
RateLimitSmsSent: 1,
},
ipBasedLimiterConfig: 10,
body: map[string]interface{}{
"phone": "123456789",
},
expectedErrorCode: ErrorCodeOverSMSSendRateLimit,
},
}

for _, c := range cases {
ts.Run(c.desc, func() {
ts.Config.RateLimitEmailSent = c.sharedLimiterConfig.RateLimitEmailSent
ts.Config.RateLimitSmsSent = c.sharedLimiterConfig.RateLimitSmsSent
lmt := ts.API.limitHandler(ipBasedLimiter(c.ipBasedLimiterConfig))
sharedLimiter := ts.API.limitEmailOrPhoneSentHandler()

// get the minimum amount to reach the threshold just before the rate limit is exceeded
threshold := min(c.sharedLimiterConfig.RateLimitEmailSent, c.sharedLimiterConfig.RateLimitSmsSent, c.ipBasedLimiterConfig)
for i := 0; i < int(threshold); i++ {
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body))
req := httptest.NewRequest(http.MethodPost, "http://localhost", &buffer)
req.Header.Add(ts.Config.RateLimitHeader, "0.0.0.0")

w := httptest.NewRecorder()
lmt.handler(sharedLimiter.handler(okHandler)).ServeHTTP(w, req)
require.Equal(ts.T(), http.StatusOK, w.Code)
}

var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body))
req := httptest.NewRequest(http.MethodPost, "http://localhost", &buffer)
req.Header.Add(ts.Config.RateLimitHeader, "0.0.0.0")

// check if the rate limit is exceeded with the expected error code
w := httptest.NewRecorder()
lmt.handler(sharedLimiter.handler(okHandler)).ServeHTTP(w, req)
require.Equal(ts.T(), http.StatusTooManyRequests, w.Code)

var data map[string]interface{}
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data))
require.Equal(ts.T(), c.expectedErrorCode, data["error_code"])
})
}
}
1 change: 1 addition & 0 deletions internal/conf/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ type GlobalConfiguration struct {
RateLimitTokenRefresh float64 `split_words:"true" default:"150"`
RateLimitSso float64 `split_words:"true" default:"30"`
RateLimitAnonymousUsers float64 `split_words:"true" default:"30"`
RateLimitOtp float64 `split_words:"true" default:"30"`

SiteURL string `json:"site_url" split_words:"true" required:"true"`
URIAllowList []string `json:"uri_allow_list" split_words:"true"`
Expand Down

0 comments on commit ca09d69

Please sign in to comment.