From 840b30cba5f13e4dff2bd50502a56b4fb39b28a2 Mon Sep 17 00:00:00 2001 From: joel Date: Thu, 11 Apr 2024 22:40:32 +0800 Subject: [PATCH 1/4] feat: add request timeouts --- internal/api/api.go | 5 ++ internal/api/errorcodes.go | 1 + internal/api/errors.go | 16 +++---- internal/api/middleware.go | 85 ++++++++++++++++++++++++++++++++++ internal/api/saml.go | 2 + internal/conf/configuration.go | 11 +++-- 6 files changed, 107 insertions(+), 13 deletions(-) diff --git a/internal/api/api.go b/internal/api/api.go index 32a3d28fe1..9c60dbe7e1 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -98,6 +98,11 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati logger := observability.NewStructuredLogger(logrus.StandardLogger(), globalConfig) r := newRouter() + + if globalConfig.API.MaxRequestDuration > 0 { + r.UseBypass(api.timeoutMiddleware(globalConfig.API.MaxRequestDuration)) + } + r.Use(addRequestID(globalConfig)) // request tracing should be added only when tracing or metrics is enabled diff --git a/internal/api/errorcodes.go b/internal/api/errorcodes.go index f71d09356a..036d747a50 100644 --- a/internal/api/errorcodes.go +++ b/internal/api/errorcodes.go @@ -77,4 +77,5 @@ const ( ErrorCodeHookTimeoutAfterRetry ErrorCode = "hook_timeout_after_retry" ErrorCodeHookPayloadOverSizeLimit ErrorCode = "hook_payload_over_size_limit" ErrorCodeHookPayloadUnknownSize ErrorCode = "hook_payload_unknown_size" + ErrorCodeRequestTimeout ErrorCode = "request_timeout" ) diff --git a/internal/api/errors.go b/internal/api/errors.go index cc6ba877b8..c111fd377d 100644 --- a/internal/api/errors.go +++ b/internal/api/errors.go @@ -207,7 +207,7 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) { output.Message = e.Message output.Payload.Reasons = e.Reasons - if jsonErr := sendJSON(w, http.StatusUnprocessableEntity, output); jsonErr != nil { + if jsonErr := sendJSON(w, http.StatusUnprocessableEntity, output); jsonErr != nil && jsonErr != context.DeadlineExceeded { HandleResponseError(jsonErr, w, r) } @@ -224,7 +224,7 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) { output.Message = e.Message output.Payload.Reasons = e.Reasons - if jsonErr := sendJSON(w, output.HTTPStatus, output); jsonErr != nil { + if jsonErr := sendJSON(w, output.HTTPStatus, output); jsonErr != nil && jsonErr != context.DeadlineExceeded { HandleResponseError(jsonErr, w, r) } } @@ -252,7 +252,7 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) { } } - if jsonErr := sendJSON(w, e.HTTPStatus, resp); jsonErr != nil { + if jsonErr := sendJSON(w, e.HTTPStatus, resp); jsonErr != nil && jsonErr != context.DeadlineExceeded { HandleResponseError(jsonErr, w, r) } } else { @@ -266,20 +266,20 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) { // Provide better error messages for certain user-triggered Postgres errors. if pgErr := utilities.NewPostgresError(e.InternalError); pgErr != nil { - if jsonErr := sendJSON(w, pgErr.HttpStatusCode, pgErr); jsonErr != nil { + if jsonErr := sendJSON(w, pgErr.HttpStatusCode, pgErr); jsonErr != nil && jsonErr != context.DeadlineExceeded { HandleResponseError(jsonErr, w, r) } return } - if jsonErr := sendJSON(w, e.HTTPStatus, e); jsonErr != nil { + if jsonErr := sendJSON(w, e.HTTPStatus, e); jsonErr != nil && jsonErr != context.DeadlineExceeded { HandleResponseError(jsonErr, w, r) } } case *OAuthError: log.WithError(e.Cause()).Info(e.Error()) - if jsonErr := sendJSON(w, http.StatusBadRequest, e); jsonErr != nil { + if jsonErr := sendJSON(w, http.StatusBadRequest, e); jsonErr != nil && jsonErr != context.DeadlineExceeded { HandleResponseError(jsonErr, w, r) } @@ -295,7 +295,7 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) { Message: "Unexpected failure, please check server logs for more information", } - if jsonErr := sendJSON(w, http.StatusInternalServerError, resp); jsonErr != nil { + if jsonErr := sendJSON(w, http.StatusInternalServerError, resp); jsonErr != nil && jsonErr != context.DeadlineExceeded { HandleResponseError(jsonErr, w, r) } } else { @@ -305,7 +305,7 @@ func HandleResponseError(err error, w http.ResponseWriter, r *http.Request) { Message: "Unexpected failure, please check server logs for more information", } - if jsonErr := sendJSON(w, http.StatusInternalServerError, httpError); jsonErr != nil { + if jsonErr := sendJSON(w, http.StatusInternalServerError, httpError); jsonErr != nil && jsonErr != context.DeadlineExceeded { HandleResponseError(jsonErr, w, r) } } diff --git a/internal/api/middleware.go b/internal/api/middleware.go index 6a6d68a259..f0602de493 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -7,6 +7,7 @@ import ( "net/http" "net/url" "strings" + "sync/atomic" "time" "github.com/supabase/auth/internal/models" @@ -260,3 +261,87 @@ func (a *API) databaseCleanup(cleanup *models.Cleanup) func(http.Handler) http.H }) } } + +// timeoutResponseWriter is a http.ResponseWriter that prevents subsequent +// writes after the context contained in it has exceeded the deadline. If a +// partial write occurs before the deadline is exceeded, but the writing is not +// complete it will allow further writes. +type timeoutResponseWriter struct { + ctx context.Context + w http.ResponseWriter + wrote int32 +} + +func (t *timeoutResponseWriter) Header() http.Header { + return t.w.Header() +} + +func (t *timeoutResponseWriter) Write(bytes []byte) (int, error) { + if t.ctx.Err() == context.DeadlineExceeded { + if atomic.LoadInt32(&t.wrote) == 0 { + return 0, context.DeadlineExceeded + } + + // writing started before the deadline exceeded, but the + // deadline came in the middle, so letting the writes go + // through + } + + atomic.AddInt32(&t.wrote, 1) + + return t.w.Write(bytes) +} + +func (t *timeoutResponseWriter) WriteHeader(statusCode int) { + if t.ctx.Err() == context.DeadlineExceeded { + if atomic.LoadInt32(&t.wrote) == 0 { + return + } + + // writing started before the deadline exceeded, but the + // deadline came in the middle, so letting the writes go + // through + } + + atomic.AddInt32(&t.wrote, 1) + + t.w.WriteHeader(statusCode) +} + +func (a *API) timeoutMiddleware(timeout time.Duration) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), timeout) + defer cancel() + + timeoutWriter := &timeoutResponseWriter{ + w: w, + ctx: ctx, + } + + go func() { + <-ctx.Done() + + err := ctx.Err() + + if err == context.DeadlineExceeded { + if atomic.LoadInt32(&timeoutWriter.wrote) == 0 { + // writer wasn't written to, so we're sending the error payload + + httpError := &HTTPError{ + HTTPStatus: http.StatusGatewayTimeout, + ErrorCode: ErrorCodeRequestTimeout, + Message: "Processing this request timed out, please retry after a moment.", + } + + httpError = httpError.WithInternalError(err) + + HandleResponseError(httpError, w, r) + } + } + }() + + next.ServeHTTP(timeoutWriter, r.WithContext(ctx)) + }) + } +} diff --git a/internal/api/saml.go b/internal/api/saml.go index d936ff2fa0..41ee9e4f99 100644 --- a/internal/api/saml.go +++ b/internal/api/saml.go @@ -95,6 +95,8 @@ func (a *API) SAMLMetadata(w http.ResponseWriter, r *http.Request) error { w.Header().Set("Content-Disposition", "attachment; filename=\"metadata.xml\"") } + w.WriteHeader(http.StatusOK) + _, err = w.Write(metadataXML) return err diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index 81eab797fb..ac62c77e50 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -112,11 +112,12 @@ type MFAConfiguration struct { } type APIConfiguration struct { - Host string - Port string `envconfig:"PORT" default:"8081"` - Endpoint string - RequestIDHeader string `envconfig:"REQUEST_ID_HEADER"` - ExternalURL string `json:"external_url" envconfig:"API_EXTERNAL_URL" required:"true"` + Host string + Port string `envconfig:"PORT" default:"8081"` + Endpoint string + RequestIDHeader string `envconfig:"REQUEST_ID_HEADER"` + ExternalURL string `json:"external_url" envconfig:"API_EXTERNAL_URL" required:"true"` + MaxRequestDuration time.Duration `json:"max_request_duration" split_words:"true" default:"10s"` } func (a *APIConfiguration) Validate() error { From 0a18776d75325e531ebf9bb9286778fcdb2aa983 Mon Sep 17 00:00:00 2001 From: Kang Ming Date: Thu, 25 Apr 2024 14:59:34 +0800 Subject: [PATCH 2/4] fix: test timeout handler --- internal/api/middleware_test.go | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/internal/api/middleware_test.go b/internal/api/middleware_test.go index e591121f86..8c91f0258d 100644 --- a/internal/api/middleware_test.go +++ b/internal/api/middleware_test.go @@ -9,6 +9,7 @@ import ( "net/http/httptest" "net/url" "testing" + "time" jwt "github.com/golang-jwt/jwt" "github.com/stretchr/testify/assert" @@ -312,3 +313,25 @@ func TestFunctionHooksUnmarshalJSON(t *testing.T) { }) } } + +func (ts *MiddlewareTestSuite) TestTimeoutMiddleware() { + ts.Config.API.MaxRequestDuration = 5 * time.Microsecond + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + w := httptest.NewRecorder() + + timeoutHandler := ts.API.timeoutMiddleware(ts.Config.API.MaxRequestDuration) + + slowHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Sleep for 1 second to simulate a slow handler which should trigger the timeout + time.Sleep(1 * time.Second) + ts.API.handler.ServeHTTP(w, r) + }) + timeoutHandler(slowHandler).ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusGatewayTimeout, w.Code) + + var data map[string]interface{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data)) + require.Equal(ts.T(), ErrorCodeRequestTimeout, data["error_code"]) + require.Equal(ts.T(), float64(504), data["code"]) + require.NotNil(ts.T(), data["msg"]) +} From 1ec8152278755645ab65aff3fe486b3ef2821d18 Mon Sep 17 00:00:00 2001 From: Kang Ming Date: Thu, 25 Apr 2024 16:13:19 +0800 Subject: [PATCH 3/4] fix: use a mutex to sync concurrent writes --- internal/api/middleware.go | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/internal/api/middleware.go b/internal/api/middleware.go index f0602de493..94efd156b4 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -7,6 +7,7 @@ import ( "net/http" "net/url" "strings" + "sync" "sync/atomic" "time" @@ -270,13 +271,18 @@ type timeoutResponseWriter struct { ctx context.Context w http.ResponseWriter wrote int32 + mu sync.Mutex } func (t *timeoutResponseWriter) Header() http.Header { + t.mu.Lock() + defer t.mu.Unlock() return t.w.Header() } func (t *timeoutResponseWriter) Write(bytes []byte) (int, error) { + t.mu.Lock() + defer t.mu.Unlock() if t.ctx.Err() == context.DeadlineExceeded { if atomic.LoadInt32(&t.wrote) == 0 { return 0, context.DeadlineExceeded @@ -287,12 +293,14 @@ func (t *timeoutResponseWriter) Write(bytes []byte) (int, error) { // through } - atomic.AddInt32(&t.wrote, 1) + t.wrote = 1 return t.w.Write(bytes) } func (t *timeoutResponseWriter) WriteHeader(statusCode int) { + t.mu.Lock() + defer t.mu.Unlock() if t.ctx.Err() == context.DeadlineExceeded { if atomic.LoadInt32(&t.wrote) == 0 { return @@ -303,7 +311,7 @@ func (t *timeoutResponseWriter) WriteHeader(statusCode int) { // through } - atomic.AddInt32(&t.wrote, 1) + t.wrote = 1 t.w.WriteHeader(statusCode) } @@ -325,7 +333,9 @@ func (a *API) timeoutMiddleware(timeout time.Duration) func(http.Handler) http.H err := ctx.Err() if err == context.DeadlineExceeded { - if atomic.LoadInt32(&timeoutWriter.wrote) == 0 { + timeoutWriter.mu.Lock() + defer timeoutWriter.mu.Unlock() + if timeoutWriter.wrote == 0 { // writer wasn't written to, so we're sending the error payload httpError := &HTTPError{ From 9622991f2c8aa3007781aac01dec77e9500050be Mon Sep 17 00:00:00 2001 From: Kang Ming Date: Thu, 25 Apr 2024 16:19:28 +0800 Subject: [PATCH 4/4] chore: remove unnecessary change --- internal/api/saml.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/internal/api/saml.go b/internal/api/saml.go index 41ee9e4f99..d936ff2fa0 100644 --- a/internal/api/saml.go +++ b/internal/api/saml.go @@ -95,8 +95,6 @@ func (a *API) SAMLMetadata(w http.ResponseWriter, r *http.Request) error { w.Header().Set("Content-Disposition", "attachment; filename=\"metadata.xml\"") } - w.WriteHeader(http.StatusOK) - _, err = w.Write(metadataXML) return err