Skip to content

Commit

Permalink
feat: new timeout writer implementation (supabase#1584)
Browse files Browse the repository at this point in the history
supabase#1529 introduced timeout middleware, but it appears from working in the
wild it has some race conditions that are not particularly helpful.

This PR rewrites the implementation to get rid of race conditions, at
the expense of slightly higher RAM usage. It follows the implementation
of `http.TimeoutHandler` closely.

---------

Co-authored-by: Kang Ming <kang.ming1996@gmail.com>
  • Loading branch information
2 people authored and LashaJini committed Nov 15, 2024
1 parent 17f8b16 commit 259c77c
Show file tree
Hide file tree
Showing 5 changed files with 224 additions and 170 deletions.
2 changes: 1 addition & 1 deletion internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
r.UseBypass(recoverer)

if globalConfig.API.MaxRequestDuration > 0 {
r.UseBypass(api.timeoutMiddleware(globalConfig.API.MaxRequestDuration))
r.UseBypass(timeoutMiddleware(globalConfig.API.MaxRequestDuration))
}

// request tracing should be added only when tracing or metrics is enabled
Expand Down
138 changes: 83 additions & 55 deletions internal/api/middleware.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
package api

import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"net/url"
"strings"
"sync"
"sync/atomic"
"time"

"github.com/sirupsen/logrus"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/observability"
"github.com/supabase/auth/internal/security"
Expand Down Expand Up @@ -263,95 +264,122 @@ func (a *API) databaseCleanup(cleanup *models.Cleanup) func(http.Handler) http.H
}
}

// timeoutResponseWriter is a http.ResponseWriter that prevents subsequent
// writes after the context contained in it has exceeded the deadline. If a
// partial write occurs before the deadline is exceeded, but the writing is not
// complete it will allow further writes.
// timeoutResponseWriter is a http.ResponseWriter that queues up a response
// body to be sent if the serving completes before the context has exceeded its
// deadline.
type timeoutResponseWriter struct {
ctx context.Context
w http.ResponseWriter
wrote int32
mu sync.Mutex
sync.Mutex

header http.Header
wroteHeader bool
snapHeader http.Header // snapshot of the header at the time WriteHeader was called
statusCode int
buf bytes.Buffer
}

func (t *timeoutResponseWriter) Header() http.Header {
t.mu.Lock()
defer t.mu.Unlock()
return t.w.Header()
t.Lock()
defer t.Unlock()

return t.header
}

func (t *timeoutResponseWriter) Write(bytes []byte) (int, error) {
t.mu.Lock()
defer t.mu.Unlock()
if t.ctx.Err() == context.DeadlineExceeded {
if atomic.LoadInt32(&t.wrote) == 0 {
return 0, context.DeadlineExceeded
}
t.Lock()
defer t.Unlock()

// writing started before the deadline exceeded, but the
// deadline came in the middle, so letting the writes go
// through
if !t.wroteHeader {
t.WriteHeader(http.StatusOK)
}

t.wrote = 1

return t.w.Write(bytes)
return t.buf.Write(bytes)
}

func (t *timeoutResponseWriter) WriteHeader(statusCode int) {
t.mu.Lock()
defer t.mu.Unlock()
if t.ctx.Err() == context.DeadlineExceeded {
if atomic.LoadInt32(&t.wrote) == 0 {
return
}
t.Lock()
defer t.Unlock()

if t.wroteHeader {
// ignore multiple calls to WriteHeader
// once WriteHeader has been called once, a snapshot of the header map is taken
// and saved in snapHeader to be used in finallyWrite
return
}
t.statusCode = statusCode
t.wroteHeader = true
t.snapHeader = t.header.Clone()
}

func (t *timeoutResponseWriter) finallyWrite(w http.ResponseWriter) {
t.Lock()
defer t.Unlock()

// writing started before the deadline exceeded, but the
// deadline came in the middle, so letting the writes go
// through
dst := w.Header()
for k, vv := range t.snapHeader {
dst[k] = vv
}

t.wrote = 1
if !t.wroteHeader {
t.statusCode = http.StatusOK
}

t.w.WriteHeader(statusCode)
w.WriteHeader(t.statusCode)
if _, err := w.Write(t.buf.Bytes()); err != nil {
logrus.WithError(err).Warn("Write failed")
}
}

func (a *API) timeoutMiddleware(timeout time.Duration) func(http.Handler) http.Handler {
func timeoutMiddleware(timeout time.Duration) func(http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, cancel := context.WithTimeout(r.Context(), timeout)
defer cancel()

timeoutWriter := &timeoutResponseWriter{
w: w,
ctx: ctx,
header: make(http.Header),
}

panicChan := make(chan any, 1)
serverDone := make(chan struct{})
go func() {
<-ctx.Done()
defer func() {
if p := recover(); p != nil {
panicChan <- p
}
}()

next.ServeHTTP(timeoutWriter, r.WithContext(ctx))
close(serverDone)
}()

select {
case p := <-panicChan:
panic(p)

case <-serverDone:
timeoutWriter.finallyWrite(w)

case <-ctx.Done():
err := ctx.Err()

if err == context.DeadlineExceeded {
timeoutWriter.mu.Lock()
defer timeoutWriter.mu.Unlock()
if timeoutWriter.wrote == 0 {
// writer wasn't written to, so we're sending the error payload

httpError := &HTTPError{
HTTPStatus: http.StatusGatewayTimeout,
ErrorCode: ErrorCodeRequestTimeout,
Message: "Processing this request timed out, please retry after a moment.",
}
httpError := &HTTPError{
HTTPStatus: http.StatusGatewayTimeout,
ErrorCode: ErrorCodeRequestTimeout,
Message: "Processing this request timed out, please retry after a moment.",
}

httpError = httpError.WithInternalError(err)
httpError = httpError.WithInternalError(err)

HandleResponseError(httpError, w, r)
}
}
}()
HandleResponseError(httpError, w, r)
} else {
// unrecognized context error, so we should wait for the server to finish
// and write out the response
<-serverDone

next.ServeHTTP(timeoutWriter, r.WithContext(ctx))
timeoutWriter.finallyWrite(w)
}
}
})
}
}
23 changes: 22 additions & 1 deletion internal/api/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ func (ts *MiddlewareTestSuite) TestTimeoutMiddleware() {
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
w := httptest.NewRecorder()

timeoutHandler := ts.API.timeoutMiddleware(ts.Config.API.MaxRequestDuration)
timeoutHandler := timeoutMiddleware(ts.Config.API.MaxRequestDuration)

slowHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Sleep for 1 second to simulate a slow handler which should trigger the timeout
Expand All @@ -335,3 +335,24 @@ func (ts *MiddlewareTestSuite) TestTimeoutMiddleware() {
require.Equal(ts.T(), float64(504), data["code"])
require.NotNil(ts.T(), data["msg"])
}

func TestTimeoutResponseWriter(t *testing.T) {
// timeoutResponseWriter should exhitbit a similar behavior as http.ResponseWriter
req := httptest.NewRequest(http.MethodGet, "http://localhost", nil)
w1 := httptest.NewRecorder()
w2 := httptest.NewRecorder()

timeoutHandler := timeoutMiddleware(time.Second * 10)

redirectHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// tries to redirect twice
http.Redirect(w, r, "http://localhost:3001/#message=first_message", http.StatusSeeOther)

// overwrites the first
http.Redirect(w, r, "http://localhost:3001/second", http.StatusSeeOther)
})
timeoutHandler(redirectHandler).ServeHTTP(w1, req)
redirectHandler.ServeHTTP(w2, req)

require.Equal(t, w1.Result(), w2.Result())
}
21 changes: 12 additions & 9 deletions internal/api/verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyPa
err error
token *AccessTokenResponse
authCode string
rurl string
)

grantParams.FillGrantParams(r)
Expand All @@ -138,6 +139,7 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyPa
return err
}
}

err = db.Transaction(func(tx *storage.Connection) error {
var terr error
user, terr = a.verifyTokenHash(tx, params)
Expand All @@ -152,12 +154,11 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyPa
case mail.EmailChangeVerification:
user, terr = a.emailChangeVerify(r, tx, params, user)
if user == nil && terr == nil {
// when double confirmation is required
rurl, err := a.prepRedirectURL(singleConfirmationAccepted, params.RedirectTo, flowType)
if err != nil {
return err
// only one OTP is confirmed at this point, so we return early and ask the user to confirm the second OTP
rurl, terr = a.prepRedirectURL(singleConfirmationAccepted, params.RedirectTo, flowType)
if terr != nil {
return terr
}
http.Redirect(w, r, rurl, http.StatusSeeOther)
return nil
}
default:
Expand Down Expand Up @@ -198,15 +199,17 @@ func (a *API) verifyGet(w http.ResponseWriter, r *http.Request, params *VerifyPa
if err != nil {
var herr *HTTPError
if errors.As(err, &herr) {
rurl, err := a.prepErrorRedirectURL(herr, r, params.RedirectTo, flowType)
rurl, err = a.prepErrorRedirectURL(herr, r, params.RedirectTo, flowType)
if err != nil {
return err
}
http.Redirect(w, r, rurl, http.StatusSeeOther)
return nil
}
}
rurl := params.RedirectTo
if rurl != "" {
http.Redirect(w, r, rurl, http.StatusSeeOther)
return nil
}
rurl = params.RedirectTo
if isImplicitFlow(flowType) && token != nil {
q := url.Values{}
q.Set("type", params.Type)
Expand Down
Loading

0 comments on commit 259c77c

Please sign in to comment.