From 864809ee3e505d54240d688b2126820e1873d125 Mon Sep 17 00:00:00 2001 From: Joel Lee Date: Thu, 28 Mar 2024 13:19:10 +0700 Subject: [PATCH] fix: refactor mfa and aal update methods (#1503) ## What kind of change does this PR introduce? Regular cleanup. - Merge the AAL and AMR update methods. Think this was suggested a while back but only getting to it now. - Move `FindUserByFactors` to tests - a user should never need to find a user by factors. They can fetch it from the models. Ideally I'd remove it entirely but it's used in a few tests which will take significant time to refactor. --- internal/api/admin.go | 6 +----- internal/api/mfa_test.go | 33 +++++++++++++++++++++++--------- internal/api/token.go | 9 +++------ internal/models/factor.go | 12 ------------ internal/models/sessions.go | 17 +++++++--------- internal/models/sessions_test.go | 2 +- 6 files changed, 36 insertions(+), 43 deletions(-) diff --git a/internal/api/admin.go b/internal/api/admin.go index f7acf3b45c..b1b8e4d609 100644 --- a/internal/api/admin.go +++ b/internal/api/admin.go @@ -545,11 +545,7 @@ func (a *API) adminUserDeleteFactor(w http.ResponseWriter, r *http.Request) erro func (a *API) adminUserGetFactors(w http.ResponseWriter, r *http.Request) error { ctx := r.Context() user := getUser(ctx) - factors, terr := models.FindFactorsByUser(a.db, user) - if terr != nil { - return terr - } - return sendJSON(w, http.StatusOK, factors) + return sendJSON(w, http.StatusOK, user.Factors) } // adminUserUpdate updates a single factor object diff --git a/internal/api/mfa_test.go b/internal/api/mfa_test.go index 39ec9f2cc8..216d740437 100644 --- a/internal/api/mfa_test.go +++ b/internal/api/mfa_test.go @@ -13,9 +13,12 @@ import ( "github.com/gofrs/uuid" + "database/sql" + "github.com/pkg/errors" "github.com/pquerna/otp" "github.com/supabase/auth/internal/conf" "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" "github.com/supabase/auth/internal/utilities" "github.com/jackc/pgx/v4" @@ -143,7 +146,7 @@ func (ts *MFATestSuite) TestEnrollFactor() { ts.Run(c.desc, func() { w := performEnrollFlow(ts, token, c.friendlyName, c.factorType, c.issuer, c.expectedCode) - factors, err := models.FindFactorsByUser(ts.API.db, ts.TestUser) + factors, err := FindFactorsByUser(ts.API.db, ts.TestUser) ts.Require().NoError(err) addedFactor := factors[len(factors)-1] require.False(ts.T(), addedFactor.IsVerified()) @@ -194,7 +197,7 @@ func (ts *MFATestSuite) TestMultipleEnrollsCleanupExpiredFactors() { } // All Factors except last factor should be expired - factors, err := models.FindFactorsByUser(ts.API.db, ts.TestUser) + factors, err := FindFactorsByUser(ts.API.db, ts.TestUser) require.NoError(ts.T(), err) // Make a challenge so last, unverified factor isn't deleted on next enroll (Factor 2) @@ -202,7 +205,7 @@ func (ts *MFATestSuite) TestMultipleEnrollsCleanupExpiredFactors() { // Enroll another Factor (Factor 3) _ = performEnrollFlow(ts, token, "", models.TOTP, "https://issuer.com", http.StatusOK) - factors, err = models.FindFactorsByUser(ts.API.db, ts.TestUser) + factors, err = FindFactorsByUser(ts.API.db, ts.TestUser) require.NoError(ts.T(), err) require.Equal(ts.T(), 3, len(factors)) } @@ -248,7 +251,7 @@ func (ts *MFATestSuite) TestMFAVerifyFactor() { require.NoError(ts.T(), err) sharedSecret := ts.TestOTPKey.Secret() - factors, err := models.FindFactorsByUser(ts.API.db, ts.TestUser) + factors, err := FindFactorsByUser(ts.API.db, ts.TestUser) f := factors[0] f.Secret = sharedSecret require.NoError(ts.T(), err) @@ -319,17 +322,17 @@ func (ts *MFATestSuite) TestUnenrollVerifiedFactor() { for _, v := range cases { ts.Run(v.desc, func() { var buffer bytes.Buffer - if v.isAAL2 { - ts.TestSession.UpdateAssociatedAAL(ts.API.db, models.AAL2.String()) - } + // Create Session to test behaviour which downgrades other sessions - factors, err := models.FindFactorsByUser(ts.API.db, ts.TestUser) + factors, err := FindFactorsByUser(ts.API.db, ts.TestUser) require.NoError(ts.T(), err, "error finding factors") f := factors[0] f.Secret = ts.TestOTPKey.Secret() require.NoError(ts.T(), f.UpdateStatus(ts.API.db, models.FactorStateVerified)) require.NoError(ts.T(), ts.API.db.Update(f), "Error updating new test factor") - + if v.isAAL2 { + ts.TestSession.UpdateAALAndAssociatedFactor(ts.API.db, models.AAL2, &f.ID) + } token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID) w := ServeAuthenticatedRequest(ts, http.MethodDelete, fmt.Sprintf("/factors/%s", f.ID), token, buffer) require.Equal(ts.T(), v.expectedHTTPCode, w.Code) @@ -659,3 +662,15 @@ func cleanupHook(ts *MFATestSuite, hookName string) { err := ts.API.db.RawQuery(cleanupHookSQL).Exec() require.NoError(ts.T(), err) } + +// FindFactorsByUser returns all factors belonging to a user ordered by timestamp. Don't use this outside of tests. +func FindFactorsByUser(tx *storage.Connection, user *models.User) ([]*models.Factor, error) { + factors := []*models.Factor{} + if err := tx.Q().Where("user_id = ?", user.ID).Order("created_at asc").All(&factors); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return factors, nil + } + return nil, errors.Wrap(err, "Database error when finding MFA factors associated to user") + } + return factors, nil +} diff --git a/internal/api/token.go b/internal/api/token.go index cee79be118..eed530f67d 100644 --- a/internal/api/token.go +++ b/internal/api/token.go @@ -293,7 +293,7 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request) func (a *API) generateAccessToken(ctx context.Context, tx *storage.Connection, user *models.User, sessionId *uuid.UUID, authenticationMethod models.AuthenticationMethod) (string, int64, error) { config := a.config - aal, amr := models.AAL1.String(), []models.AMREntry{} + aal, amr := models.AAL1, []models.AMREntry{} sid := "" if sessionId != nil { sid = sessionId.String() @@ -324,7 +324,7 @@ func (a *API) generateAccessToken(ctx context.Context, tx *storage.Connection, u UserMetaData: user.UserMetaData, Role: user.Role, SessionId: sid, - AuthenticatorAssuranceLevel: aal, + AuthenticatorAssuranceLevel: aal.String(), AuthenticationMethodReference: amr, IsAnonymous: user.IsAnonymous, } @@ -452,10 +452,7 @@ func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection, return terr } - if err := session.UpdateAssociatedFactor(tx, grantParams.FactorID); err != nil { - return err - } - if err := session.UpdateAssociatedAAL(tx, aal); err != nil { + if err := session.UpdateAALAndAssociatedFactor(tx, aal, grantParams.FactorID); err != nil { return err } diff --git a/internal/models/factor.go b/internal/models/factor.go index 24afb77415..2657335648 100644 --- a/internal/models/factor.go +++ b/internal/models/factor.go @@ -141,18 +141,6 @@ func NewFactor(user *User, friendlyName string, factorType string, state FactorS return factor } -// FindFactorsByUser returns all factors belonging to a user ordered by timestamp -func FindFactorsByUser(tx *storage.Connection, user *User) ([]*Factor, error) { - factors := []*Factor{} - if err := tx.Q().Where("user_id = ?", user.ID).Order("created_at asc").All(&factors); err != nil { - if errors.Cause(err) == sql.ErrNoRows { - return factors, nil - } - return nil, errors.Wrap(err, "Database error when finding MFA factors associated to user") - } - return factors, nil -} - func FindFactorByFactorID(conn *storage.Connection, factorID uuid.UUID) (*Factor, error) { var factor Factor err := conn.Find(&factor, factorID) diff --git a/internal/models/sessions.go b/internal/models/sessions.go index 5eae91a4b8..ca4f135da6 100644 --- a/internal/models/sessions.go +++ b/internal/models/sessions.go @@ -272,21 +272,18 @@ func LogoutAllExceptMe(tx *storage.Connection, sessionId uuid.UUID, userID uuid. return tx.RawQuery("DELETE FROM "+(&pop.Model{Value: Session{}}).TableName()+" WHERE id != ? AND user_id = ?", sessionId, userID).Exec() } -func (s *Session) UpdateAssociatedFactor(tx *storage.Connection, factorID *uuid.UUID) error { +func (s *Session) UpdateAALAndAssociatedFactor(tx *storage.Connection, aal AuthenticatorAssuranceLevel, factorID *uuid.UUID) error { s.FactorID = factorID - return tx.Update(s) + aalAsString := aal.String() + s.AAL = &aalAsString + return tx.UpdateOnly(s, "aal", "factor_id") } -func (s *Session) UpdateAssociatedAAL(tx *storage.Connection, aal string) error { - s.AAL = &aal - return tx.Update(s) -} - -func (s *Session) CalculateAALAndAMR(user *User) (aal string, amr []AMREntry, err error) { - amr, aal = []AMREntry{}, AAL1.String() +func (s *Session) CalculateAALAndAMR(user *User) (aal AuthenticatorAssuranceLevel, amr []AMREntry, err error) { + amr, aal = []AMREntry{}, AAL1 for _, claim := range s.AMRClaims { if *claim.AuthenticationMethod == TOTPSignIn.String() { - aal = AAL2.String() + aal = AAL2 } amr = append(amr, AMREntry{Method: claim.GetAuthenticationMethod(), Timestamp: claim.UpdatedAt.Unix()}) } diff --git a/internal/models/sessions_test.go b/internal/models/sessions_test.go index 158c4de792..9dce78e953 100644 --- a/internal/models/sessions_test.go +++ b/internal/models/sessions_test.go @@ -84,7 +84,7 @@ func (ts *SessionsTestSuite) TestCalculateAALAndAMR() { aal, amr, err := session.CalculateAALAndAMR(u) require.NoError(ts.T(), err) - require.Equal(ts.T(), AAL2.String(), aal) + require.Equal(ts.T(), AAL2, aal) require.Equal(ts.T(), totalDistinctClaims, len(amr)) found := false