From 284b495ef22c0879d1862bf20e246e55f1c650b5 Mon Sep 17 00:00:00 2001 From: Joel Lee Date: Thu, 25 Apr 2024 16:35:11 +0800 Subject: [PATCH] feat: add timeout middleware (#1529) A new middleware is introduced that enforces a strict timeout by using `context.WithTimeout()`. When the timeout is reached, a 504 JSON error with the `request_timeout` error code is sent. Anything that depends on the context is cancelled. --------- Co-authored-by: Kang Ming --- internal/api/api.go | 5 ++ internal/api/errorcodes.go | 1 + internal/api/errors.go | 16 +++--- internal/api/middleware.go | 95 +++++++++++++++++++++++++++++++++ internal/api/middleware_test.go | 23 ++++++++ internal/conf/configuration.go | 11 ++-- 6 files changed, 138 insertions(+), 13 deletions(-) diff --git a/internal/api/api.go b/internal/api/api.go index 32a3d28fe1..9c60dbe7e1 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -98,6 +98,11 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati logger := observability.NewStructuredLogger(logrus.StandardLogger(), globalConfig) r := newRouter() + + if globalConfig.API.MaxRequestDuration > 0 { + r.UseBypass(api.timeoutMiddleware(globalConfig.API.MaxRequestDuration)) + } + r.Use(addRequestID(globalConfig)) // request tracing should be added only when tracing or metrics is enabled diff --git a/internal/api/errorcodes.go b/internal/api/errorcodes.go index f71d09356a..036d747a50 100644 --- a/internal/api/errorcodes.go +++ b/internal/api/errorcodes.go @@ -77,4 +77,5 @@ const ( ErrorCodeHookTimeoutAfterRetry ErrorCode = "hook_timeout_after_retry" ErrorCodeHookPayloadOverSizeLimit ErrorCode = "hook_payload_over_size_limit" ErrorCodeHookPayloadUnknownSize ErrorCode = "hook_payload_unknown_size" + ErrorCodeRequestTimeout ErrorCode = "request_timeout" ) diff --git a/internal/api/errors.go b/internal/api/errors.go index cc6ba877b8..c111fd377d 100644 --- a/internal/api/errors.go +++ b/internal/api/errors.go @@ -207,7 +207,7 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) { output.Message = e.Message output.Payload.Reasons = e.Reasons - if jsonErr := sendJSON(w, http.StatusUnprocessableEntity, output); jsonErr != nil { + if jsonErr := sendJSON(w, http.StatusUnprocessableEntity, output); jsonErr != nil && jsonErr != context.DeadlineExceeded { HandleResponseError(jsonErr, w, r) } @@ -224,7 +224,7 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) { output.Message = e.Message output.Payload.Reasons = e.Reasons - if jsonErr := sendJSON(w, output.HTTPStatus, output); jsonErr != nil { + if jsonErr := sendJSON(w, output.HTTPStatus, output); jsonErr != nil && jsonErr != context.DeadlineExceeded { HandleResponseError(jsonErr, w, r) } } @@ -252,7 +252,7 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) { } } - if jsonErr := sendJSON(w, e.HTTPStatus, resp); jsonErr != nil { + if jsonErr := sendJSON(w, e.HTTPStatus, resp); jsonErr != nil && jsonErr != context.DeadlineExceeded { HandleResponseError(jsonErr, w, r) } } else { @@ -266,20 +266,20 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) { // Provide better error messages for certain user-triggered Postgres errors. if pgErr := utilities.NewPostgresError(e.InternalError); pgErr != nil { - if jsonErr := sendJSON(w, pgErr.HttpStatusCode, pgErr); jsonErr != nil { + if jsonErr := sendJSON(w, pgErr.HttpStatusCode, pgErr); jsonErr != nil && jsonErr != context.DeadlineExceeded { HandleResponseError(jsonErr, w, r) } return } - if jsonErr := sendJSON(w, e.HTTPStatus, e); jsonErr != nil { + if jsonErr := sendJSON(w, e.HTTPStatus, e); jsonErr != nil && jsonErr != context.DeadlineExceeded { HandleResponseError(jsonErr, w, r) } } case *OAuthError: log.WithError(e.Cause()).Info(e.Error()) - if jsonErr := sendJSON(w, http.StatusBadRequest, e); jsonErr != nil { + if jsonErr := sendJSON(w, http.StatusBadRequest, e); jsonErr != nil && jsonErr != context.DeadlineExceeded { HandleResponseError(jsonErr, w, r) } @@ -295,7 +295,7 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) { Message: "Unexpected failure, please check server logs for more information", } - if jsonErr := sendJSON(w, http.StatusInternalServerError, resp); jsonErr != nil { + if jsonErr := sendJSON(w, http.StatusInternalServerError, resp); jsonErr != nil && jsonErr != context.DeadlineExceeded { HandleResponseError(jsonErr, w, r) } } else { @@ -305,7 +305,7 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) { Message: "Unexpected failure, please check server logs for more information", } - if jsonErr := sendJSON(w, http.StatusInternalServerError, httpError); jsonErr != nil { + if jsonErr := sendJSON(w, http.StatusInternalServerError, httpError); jsonErr != nil && jsonErr != context.DeadlineExceeded { HandleResponseError(jsonErr, w, r) } } diff --git a/internal/api/middleware.go b/internal/api/middleware.go index 6a6d68a259..94efd156b4 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -7,6 +7,8 @@ import ( "net/http" "net/url" "strings" + "sync" + "sync/atomic" "time" "github.com/supabase/auth/internal/models" @@ -260,3 +262,96 @@ func (a *API) databaseCleanup(cleanup *models.Cleanup) func(http.Handler) http.H }) } } + +// timeoutResponseWriter is a http.ResponseWriter that prevents subsequent +// writes after the context contained in it has exceeded the deadline. If a +// partial write occurs before the deadline is exceeded, but the writing is not +// complete it will allow further writes. +type timeoutResponseWriter struct { + ctx context.Context + w http.ResponseWriter + wrote int32 + mu sync.Mutex +} + +func (t *timeoutResponseWriter) Header() http.Header { + t.mu.Lock() + defer t.mu.Unlock() + return t.w.Header() +} + +func (t *timeoutResponseWriter) Write(bytes []byte) (int, error) { + t.mu.Lock() + defer t.mu.Unlock() + if t.ctx.Err() == context.DeadlineExceeded { + if atomic.LoadInt32(&t.wrote) == 0 { + return 0, context.DeadlineExceeded + } + + // writing started before the deadline exceeded, but the + // deadline came in the middle, so letting the writes go + // through + } + + t.wrote = 1 + + return t.w.Write(bytes) +} + +func (t *timeoutResponseWriter) WriteHeader(statusCode int) { + t.mu.Lock() + defer t.mu.Unlock() + if t.ctx.Err() == context.DeadlineExceeded { + if atomic.LoadInt32(&t.wrote) == 0 { + return + } + + // writing started before the deadline exceeded, but the + // deadline came in the middle, so letting the writes go + // through + } + + t.wrote = 1 + + t.w.WriteHeader(statusCode) +} + +func (a *API) timeoutMiddleware(timeout time.Duration) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), timeout) + defer cancel() + + timeoutWriter := &timeoutResponseWriter{ + w: w, + ctx: ctx, + } + + go func() { + <-ctx.Done() + + err := ctx.Err() + + if err == context.DeadlineExceeded { + timeoutWriter.mu.Lock() + defer timeoutWriter.mu.Unlock() + if timeoutWriter.wrote == 0 { + // writer wasn't written to, so we're sending the error payload + + httpError := &HTTPError{ + HTTPStatus: http.StatusGatewayTimeout, + ErrorCode: ErrorCodeRequestTimeout, + Message: "Processing this request timed out, please retry after a moment.", + } + + httpError = httpError.WithInternalError(err) + + HandleResponseError(httpError, w, r) + } + } + }() + + next.ServeHTTP(timeoutWriter, r.WithContext(ctx)) + }) + } +} diff --git a/internal/api/middleware_test.go b/internal/api/middleware_test.go index e591121f86..8c91f0258d 100644 --- a/internal/api/middleware_test.go +++ b/internal/api/middleware_test.go @@ -9,6 +9,7 @@ import ( "net/http/httptest" "net/url" "testing" + "time" jwt "github.com/golang-jwt/jwt" "github.com/stretchr/testify/assert" @@ -312,3 +313,25 @@ func TestFunctionHooksUnmarshalJSON(t *testing.T) { }) } } + +func (ts *MiddlewareTestSuite) TestTimeoutMiddleware() { + ts.Config.API.MaxRequestDuration = 5 * time.Microsecond + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + w := httptest.NewRecorder() + + timeoutHandler := ts.API.timeoutMiddleware(ts.Config.API.MaxRequestDuration) + + slowHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Sleep for 1 second to simulate a slow handler which should trigger the timeout + time.Sleep(1 * time.Second) + ts.API.handler.ServeHTTP(w, r) + }) + timeoutHandler(slowHandler).ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusGatewayTimeout, w.Code) + + var data map[string]interface{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + require.Equal(ts.T(), ErrorCodeRequestTimeout, data["error_code"]) + require.Equal(ts.T(), float64(504), data["code"]) + require.NotNil(ts.T(), data["msg"]) +} diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index 81eab797fb..ac62c77e50 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -112,11 +112,12 @@ type MFAConfiguration struct { } type APIConfiguration struct { - Host string - Port string `envconfig:"PORT" default:"8081"` - Endpoint string - RequestIDHeader string `envconfig:"REQUEST_ID_HEADER"` - ExternalURL string `json:"external_url" envconfig:"API_EXTERNAL_URL" required:"true"` + Host string + Port string `envconfig:"PORT" default:"8081"` + Endpoint string + RequestIDHeader string `envconfig:"REQUEST_ID_HEADER"` + ExternalURL string `json:"external_url" envconfig:"API_EXTERNAL_URL" required:"true"` + MaxRequestDuration time.Duration `json:"max_request_duration" split_words:"true" default:"10s"` } func (a *APIConfiguration) Validate() error {