diff --git a/internal/api/middleware.go b/internal/api/middleware.go index ff821e14a..d387936ff 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -11,6 +11,7 @@ import ( "sync" "time" + chimiddleware "github.com/go-chi/chi/v5/middleware" "github.com/sirupsen/logrus" "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/observability" @@ -245,15 +246,18 @@ func (a *API) requireManualLinkingEnabled(w http.ResponseWriter, req *http.Reque return ctx, nil } -func (a *API) databaseCleanup(cleanup *models.Cleanup) func(http.Handler) http.Handler { +func (a *API) databaseCleanup(cleanup models.Cleaner) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - next.ServeHTTP(w, r) - + wrappedResp := chimiddleware.NewWrapResponseWriter(w, r.ProtoMajor) + next.ServeHTTP(wrappedResp, r) switch r.Method { case http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete: + if (wrappedResp.Status() / 100) != 2 { + // don't do any cleanups for non-2xx responses + return + } // continue - default: return } diff --git a/internal/api/middleware_test.go b/internal/api/middleware_test.go index ac41de9c2..98dd6a8d6 100644 --- a/internal/api/middleware_test.go +++ b/internal/api/middleware_test.go @@ -14,9 +14,11 @@ import ( "github.com/didip/tollbooth/v5/limiter" jwt "github.com/golang-jwt/jwt/v5" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/storage" ) const ( @@ -443,3 +445,66 @@ func (ts *MiddlewareTestSuite) TestLimitHandler() { ts.API.limitHandler(lmt).handler(okHandler).ServeHTTP(w, req) require.Equal(ts.T(), http.StatusTooManyRequests, w.Code) } + +type MockCleanup struct { + mock.Mock +} + +func (m *MockCleanup) Clean(db *storage.Connection) (int, error) { + m.Called(db) + return 0, nil +} + +func (ts *MiddlewareTestSuite) TestDatabaseCleanup() { + testHandler := func(statusCode int) http.HandlerFunc { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(statusCode) + b, _ := json.Marshal(map[string]interface{}{"message": "ok"}) + w.Write([]byte(b)) + }) + } + + cases := []struct { + desc string + statusCode int + method string + }{ + { + desc: "Run cleanup successfully", + statusCode: http.StatusOK, + method: http.MethodPost, + }, + { + desc: "Skip cleanup if GET", + statusCode: http.StatusOK, + method: http.MethodGet, + }, + { + desc: "Skip cleanup if 3xx", + statusCode: http.StatusSeeOther, + method: http.MethodPost, + }, + { + desc: "Skip cleanup if 4xx", + statusCode: http.StatusBadRequest, + method: http.MethodPost, + }, + { + desc: "Skip cleanup if 5xx", + statusCode: http.StatusInternalServerError, + method: http.MethodPost, + }, + } + + mockCleanup := new(MockCleanup) + mockCleanup.On("Clean", mock.Anything).Return(0, nil) + for _, c := range cases { + ts.Run("DatabaseCleanup", func() { + req := httptest.NewRequest(c.method, "http://localhost", nil) + w := httptest.NewRecorder() + ts.API.databaseCleanup(mockCleanup)(testHandler(c.statusCode)).ServeHTTP(w, req) + require.Equal(ts.T(), c.statusCode, w.Code) + }) + } + mockCleanup.AssertNumberOfCalls(ts.T(), "Clean", 1) +} diff --git a/internal/models/cleanup.go b/internal/models/cleanup.go index 69cf7c7a3..9669c8d4b 100644 --- a/internal/models/cleanup.go +++ b/internal/models/cleanup.go @@ -16,6 +16,10 @@ import ( "github.com/supabase/auth/internal/storage" ) +type Cleaner interface { + Clean(*storage.Connection) (int, error) +} + type Cleanup struct { cleanupStatements []string