diff --git a/internal/api/api.go b/internal/api/api.go index 51d443331..8613a05dc 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -102,7 +102,7 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne r.UseBypass(recoverer) if globalConfig.API.MaxRequestDuration > 0 { - r.UseBypass(api.timeoutMiddleware(globalConfig.API.MaxRequestDuration)) + r.UseBypass(timeoutMiddleware(globalConfig.API.MaxRequestDuration)) } // request tracing should be added only when tracing or metrics is enabled diff --git a/internal/api/middleware.go b/internal/api/middleware.go index 8360af0a0..5db550bf8 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -1,6 +1,7 @@ package api import ( + "bytes" "context" "encoding/json" "fmt" @@ -8,9 +9,9 @@ import ( "net/url" "strings" "sync" - "sync/atomic" "time" + "github.com/sirupsen/logrus" "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/observability" "github.com/supabase/auth/internal/security" @@ -263,95 +264,122 @@ func (a *API) databaseCleanup(cleanup *models.Cleanup) func(http.Handler) http.H } } -// timeoutResponseWriter is a http.ResponseWriter that prevents subsequent -// writes after the context contained in it has exceeded the deadline. If a -// partial write occurs before the deadline is exceeded, but the writing is not -// complete it will allow further writes. +// timeoutResponseWriter is a http.ResponseWriter that queues up a response +// body to be sent if the serving completes before the context has exceeded its +// deadline. type timeoutResponseWriter struct { - ctx context.Context - w http.ResponseWriter - wrote int32 - mu sync.Mutex + sync.Mutex + + header http.Header + wroteHeader bool + snapHeader http.Header // snapshot of the header at the time WriteHeader was called + statusCode int + buf bytes.Buffer } func (t *timeoutResponseWriter) Header() http.Header { - t.mu.Lock() - defer t.mu.Unlock() - return t.w.Header() + t.Lock() + defer t.Unlock() + + return t.header } func (t *timeoutResponseWriter) Write(bytes []byte) (int, error) { - t.mu.Lock() - defer t.mu.Unlock() - if t.ctx.Err() == context.DeadlineExceeded { - if atomic.LoadInt32(&t.wrote) == 0 { - return 0, context.DeadlineExceeded - } + t.Lock() + defer t.Unlock() - // writing started before the deadline exceeded, but the - // deadline came in the middle, so letting the writes go - // through + if !t.wroteHeader { + t.WriteHeader(http.StatusOK) } - t.wrote = 1 - - return t.w.Write(bytes) + return t.buf.Write(bytes) } func (t *timeoutResponseWriter) WriteHeader(statusCode int) { - t.mu.Lock() - defer t.mu.Unlock() - if t.ctx.Err() == context.DeadlineExceeded { - if atomic.LoadInt32(&t.wrote) == 0 { - return - } + t.Lock() + defer t.Unlock() + + if t.wroteHeader { + // ignore multiple calls to WriteHeader + // once WriteHeader has been called once, a snapshot of the header map is taken + // and saved in snapHeader to be used in finallyWrite + return + } + t.statusCode = statusCode + t.wroteHeader = true + t.snapHeader = t.header.Clone() +} + +func (t *timeoutResponseWriter) finallyWrite(w http.ResponseWriter) { + t.Lock() + defer t.Unlock() - // writing started before the deadline exceeded, but the - // deadline came in the middle, so letting the writes go - // through + dst := w.Header() + for k, vv := range t.snapHeader { + dst[k] = vv } - t.wrote = 1 + if !t.wroteHeader { + t.statusCode = http.StatusOK + } - t.w.WriteHeader(statusCode) + w.WriteHeader(t.statusCode) + if _, err := w.Write(t.buf.Bytes()); err != nil { + logrus.WithError(err).Warn("Write failed") + } } -func (a *API) timeoutMiddleware(timeout time.Duration) func(http.Handler) http.Handler { +func timeoutMiddleware(timeout time.Duration) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx, cancel := context.WithTimeout(r.Context(), timeout) defer cancel() timeoutWriter := &timeoutResponseWriter{ - w: w, - ctx: ctx, + header: make(http.Header), } + panicChan := make(chan any, 1) + serverDone := make(chan struct{}) go func() { - <-ctx.Done() + defer func() { + if p := recover(); p != nil { + panicChan <- p + } + }() + next.ServeHTTP(timeoutWriter, r.WithContext(ctx)) + close(serverDone) + }() + + select { + case p := <-panicChan: + panic(p) + + case <-serverDone: + timeoutWriter.finallyWrite(w) + + case <-ctx.Done(): err := ctx.Err() if err == context.DeadlineExceeded { - timeoutWriter.mu.Lock() - defer timeoutWriter.mu.Unlock() - if timeoutWriter.wrote == 0 { - // writer wasn't written to, so we're sending the error payload - - httpError := &HTTPError{ - HTTPStatus: http.StatusGatewayTimeout, - ErrorCode: ErrorCodeRequestTimeout, - Message: "Processing this request timed out, please retry after a moment.", - } + httpError := &HTTPError{ + HTTPStatus: http.StatusGatewayTimeout, + ErrorCode: ErrorCodeRequestTimeout, + Message: "Processing this request timed out, please retry after a moment.", + } - httpError = httpError.WithInternalError(err) + httpError = httpError.WithInternalError(err) - HandleResponseError(httpError, w, r) - } - } - }() + HandleResponseError(httpError, w, r) + } else { + // unrecognized context error, so we should wait for the server to finish + // and write out the response + <-serverDone - next.ServeHTTP(timeoutWriter, r.WithContext(ctx)) + timeoutWriter.finallyWrite(w) + } + } }) } } diff --git a/internal/api/middleware_test.go b/internal/api/middleware_test.go index 8c91f0258..a9d908c32 100644 --- a/internal/api/middleware_test.go +++ b/internal/api/middleware_test.go @@ -319,7 +319,7 @@ func (ts *MiddlewareTestSuite) TestTimeoutMiddleware() { req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) w := httptest.NewRecorder() - timeoutHandler := ts.API.timeoutMiddleware(ts.Config.API.MaxRequestDuration) + timeoutHandler := timeoutMiddleware(ts.Config.API.MaxRequestDuration) slowHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Sleep for 1 second to simulate a slow handler which should trigger the timeout @@ -335,3 +335,24 @@ func (ts *MiddlewareTestSuite) TestTimeoutMiddleware() { require.Equal(ts.T(), float64(504), data["code"]) require.NotNil(ts.T(), data["msg"]) } + +func TestTimeoutResponseWriter(t *testing.T) { + // timeoutResponseWriter should exhitbit a similar behavior as http.ResponseWriter + req := httptest.NewRequest(http.MethodGet, "http://localhost", nil) + w1 := httptest.NewRecorder() + w2 := httptest.NewRecorder() + + timeoutHandler := timeoutMiddleware(time.Second * 10) + + redirectHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // tries to redirect twice + http.Redirect(w, r, "http://localhost:3001/#message=first_message", http.StatusSeeOther) + + // overwrites the first + http.Redirect(w, r, "http://localhost:3001/second", http.StatusSeeOther) + }) + timeoutHandler(redirectHandler).ServeHTTP(w1, req) + redirectHandler.ServeHTTP(w2, req) + + require.Equal(t, w1.Result(), w2.Result()) +} diff --git a/internal/api/verify.go b/internal/api/verify.go index 8a857f685..5badfc77e 100644 --- a/internal/api/verify.go +++ b/internal/api/verify.go @@ -125,6 +125,7 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyPa err error token *AccessTokenResponse authCode string + rurl string ) grantParams.FillGrantParams(r) @@ -138,6 +139,7 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyPa return err } } + err = db.Transaction(func(tx *storage.Connection) error { var terr error user, terr = a.verifyTokenHash(tx, params) @@ -152,12 +154,11 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyPa case mail.EmailChangeVerification: user, terr = a.emailChangeVerify(r, tx, params, user) if user == nil && terr == nil { - // when double confirmation is required - rurl, err := a.prepRedirectURL(singleConfirmationAccepted, params.RedirectTo, flowType) - if err != nil { - return err + // only one OTP is confirmed at this point, so we return early and ask the user to confirm the second OTP + rurl, terr = a.prepRedirectURL(singleConfirmationAccepted, params.RedirectTo, flowType) + if terr != nil { + return terr } - http.Redirect(w, r, rurl, http.StatusSeeOther) return nil } default: @@ -198,15 +199,17 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyPa if err != nil { var herr *HTTPError if errors.As(err, &herr) { - rurl, err := a.prepErrorRedirectURL(herr, r, params.RedirectTo, flowType) + rurl, err = a.prepErrorRedirectURL(herr, r, params.RedirectTo, flowType) if err != nil { return err } - http.Redirect(w, r, rurl, http.StatusSeeOther) - return nil } } - rurl := params.RedirectTo + if rurl != "" { + http.Redirect(w, r, rurl, http.StatusSeeOther) + return nil + } + rurl = params.RedirectTo if isImplicitFlow(flowType) && token != nil { q := url.Values{} q.Set("type", params.Type) diff --git a/internal/api/verify_test.go b/internal/api/verify_test.go index 73386bebe..8d818b43a 100644 --- a/internal/api/verify_test.go +++ b/internal/api/verify_test.go @@ -168,110 +168,112 @@ func (ts *VerifyTestSuite) TestVerifySecureEmailChange() { } for _, c := range cases { - u, err := models.FindUserByEmailAndAudience(ts.API.db, c.currentEmail, ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - - // reset user - u.EmailChangeSentAt = nil - u.EmailChangeTokenCurrent = "" - u.EmailChangeTokenNew = "" - require.NoError(ts.T(), ts.API.db.Update(u)) - require.NoError(ts.T(), models.ClearAllOneTimeTokensForUser(ts.API.db, u.ID)) - - // Request body - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) - - // Setup request - req := httptest.NewRequest(http.MethodPut, "http://localhost/user", &buffer) - req.Header.Set("Content-Type", "application/json") - - // Generate access token for request and a mock session - var token string - session, err := models.NewSession(u.ID, nil) - require.NoError(ts.T(), err) - require.NoError(ts.T(), ts.API.db.Create(session)) - - token, _, err = ts.API.generateAccessToken(req, ts.API.db, u, &session.ID, models.MagicLink) - require.NoError(ts.T(), err) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) - - // Setup response recorder - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), http.StatusOK, w.Code) - - u, err = models.FindUserByEmailAndAudience(ts.API.db, c.currentEmail, ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - - currentTokenHash := u.EmailChangeTokenCurrent - newTokenHash := u.EmailChangeTokenNew - - u, err = models.FindUserByEmailAndAudience(ts.API.db, c.currentEmail, ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - - assert.WithinDuration(ts.T(), time.Now(), *u.EmailChangeSentAt, 1*time.Second) - assert.False(ts.T(), u.IsConfirmed()) - - // Verify new email - reqURL := fmt.Sprintf("http://localhost/verify?type=%s&token=%s", mail.EmailChangeVerification, newTokenHash) - req = httptest.NewRequest(http.MethodGet, reqURL, nil) - - w = httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - - require.Equal(ts.T(), http.StatusSeeOther, w.Code) - urlVal, err := url.Parse(w.Result().Header.Get("Location")) - ts.Require().NoError(err, "redirect url parse failed") - var v url.Values - if !c.isPKCE { - v, err = url.ParseQuery(urlVal.Fragment) - ts.Require().NoError(err) - ts.Require().NotEmpty(v.Get("message")) - } else if c.isPKCE { - v, err = url.ParseQuery(urlVal.RawQuery) - ts.Require().NoError(err) - ts.Require().NotEmpty(v.Get("message")) - - v, err = url.ParseQuery(urlVal.Fragment) - ts.Require().NoError(err) - ts.Require().NotEmpty(v.Get("message")) - } - - u, err = models.FindUserByEmailAndAudience(ts.API.db, c.currentEmail, ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - assert.Equal(ts.T(), singleConfirmation, u.EmailChangeConfirmStatus) - - // Verify old email - reqURL = fmt.Sprintf("http://localhost/verify?type=%s&token=%s", mail.EmailChangeVerification, currentTokenHash) - req = httptest.NewRequest(http.MethodGet, reqURL, nil) - - w = httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - require.Equal(ts.T(), http.StatusSeeOther, w.Code) - - urlVal, err = url.Parse(w.Header().Get("Location")) - ts.Require().NoError(err, "redirect url parse failed") - if !c.isPKCE { - v, err = url.ParseQuery(urlVal.Fragment) - ts.Require().NoError(err) - ts.Require().NotEmpty(v.Get("access_token")) - ts.Require().NotEmpty(v.Get("expires_in")) - ts.Require().NotEmpty(v.Get("refresh_token")) - } else if c.isPKCE { - v, err = url.ParseQuery(urlVal.RawQuery) - ts.Require().NoError(err) - ts.Require().NotEmpty(v.Get("code")) - } - - // user's email should've been updated to newEmail - u, err = models.FindUserByEmailAndAudience(ts.API.db, c.newEmail, ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - require.Equal(ts.T(), zeroConfirmation, u.EmailChangeConfirmStatus) - - // Reset confirmation status after each test - u.EmailConfirmedAt = nil - require.NoError(ts.T(), ts.API.db.Update(u)) + ts.Run(c.desc, func() { + u, err := models.FindUserByEmailAndAudience(ts.API.db, c.currentEmail, ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + // reset user + u.EmailChangeSentAt = nil + u.EmailChangeTokenCurrent = "" + u.EmailChangeTokenNew = "" + require.NoError(ts.T(), ts.API.db.Update(u)) + require.NoError(ts.T(), models.ClearAllOneTimeTokensForUser(ts.API.db, u.ID)) + + // Request body + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) + + // Setup request + req := httptest.NewRequest(http.MethodPut, "http://localhost/user", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Generate access token for request and a mock session + var token string + session, err := models.NewSession(u.ID, nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(session)) + + token, _, err = ts.API.generateAccessToken(req, ts.API.db, u, &session.ID, models.MagicLink) + require.NoError(ts.T(), err) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + // Setup response recorder + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusOK, w.Code) + + u, err = models.FindUserByEmailAndAudience(ts.API.db, c.currentEmail, ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + currentTokenHash := u.EmailChangeTokenCurrent + newTokenHash := u.EmailChangeTokenNew + + u, err = models.FindUserByEmailAndAudience(ts.API.db, c.currentEmail, ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + assert.WithinDuration(ts.T(), time.Now(), *u.EmailChangeSentAt, 1*time.Second) + assert.False(ts.T(), u.IsConfirmed()) + + // Verify new email + reqURL := fmt.Sprintf("http://localhost/verify?type=%s&token=%s", mail.EmailChangeVerification, newTokenHash) + req = httptest.NewRequest(http.MethodGet, reqURL, nil) + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusSeeOther, w.Code) + urlVal, err := url.Parse(w.Result().Header.Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + var v url.Values + if !c.isPKCE { + v, err = url.ParseQuery(urlVal.Fragment) + ts.Require().NoError(err) + ts.Require().NotEmpty(v.Get("message")) + } else if c.isPKCE { + v, err = url.ParseQuery(urlVal.RawQuery) + ts.Require().NoError(err) + ts.Require().NotEmpty(v.Get("message")) + + v, err = url.ParseQuery(urlVal.Fragment) + ts.Require().NoError(err) + ts.Require().NotEmpty(v.Get("message")) + } + + u, err = models.FindUserByEmailAndAudience(ts.API.db, c.currentEmail, ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + assert.Equal(ts.T(), singleConfirmation, u.EmailChangeConfirmStatus) + + // Verify old email + reqURL = fmt.Sprintf("http://localhost/verify?type=%s&token=%s", mail.EmailChangeVerification, currentTokenHash) + req = httptest.NewRequest(http.MethodGet, reqURL, nil) + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusSeeOther, w.Code) + + urlVal, err = url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + if !c.isPKCE { + v, err = url.ParseQuery(urlVal.Fragment) + ts.Require().NoError(err) + ts.Require().NotEmpty(v.Get("access_token")) + ts.Require().NotEmpty(v.Get("expires_in")) + ts.Require().NotEmpty(v.Get("refresh_token")) + } else if c.isPKCE { + v, err = url.ParseQuery(urlVal.RawQuery) + ts.Require().NoError(err) + ts.Require().NotEmpty(v.Get("code")) + } + + // user's email should've been updated to newEmail + u, err = models.FindUserByEmailAndAudience(ts.API.db, c.newEmail, ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + require.Equal(ts.T(), zeroConfirmation, u.EmailChangeConfirmStatus) + + // Reset confirmation status after each test + u.EmailConfirmedAt = nil + require.NoError(ts.T(), ts.API.db.Update(u)) + }) } }