Skip to content

Commit

Permalink
fix: refactor mfa and aal update methods (supabase#1503)
Browse files Browse the repository at this point in the history
## What kind of change does this PR introduce?

Regular cleanup.

- Merge the AAL and AMR update methods. Think this was suggested a while
back but only getting to it now.
- Move `FindUserByFactors` to tests - a user should never need to find a
user by factors. They can fetch it from the models.


Ideally I'd remove it entirely but it's used in a few tests which will
take significant time to refactor.
  • Loading branch information
J0 authored and LashaJini committed Nov 13, 2024
1 parent b27c5ed commit 864809e
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 43 deletions.
6 changes: 1 addition & 5 deletions internal/api/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -545,11 +545,7 @@ func (a *API) adminUserDeleteFactor(w http.ResponseWriter, r *http.Request) erro
func (a *API) adminUserGetFactors(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
user := getUser(ctx)
factors, terr := models.FindFactorsByUser(a.db, user)
if terr != nil {
return terr
}
return sendJSON(w, http.StatusOK, factors)
return sendJSON(w, http.StatusOK, user.Factors)
}

// adminUserUpdate updates a single factor object
Expand Down
33 changes: 24 additions & 9 deletions internal/api/mfa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@ import (

"github.com/gofrs/uuid"

"database/sql"
"github.com/pkg/errors"
"github.com/pquerna/otp"
"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/storage"
"github.com/supabase/auth/internal/utilities"

"github.com/jackc/pgx/v4"
Expand Down Expand Up @@ -143,7 +146,7 @@ func (ts *MFATestSuite) TestEnrollFactor() {
ts.Run(c.desc, func() {
w := performEnrollFlow(ts, token, c.friendlyName, c.factorType, c.issuer, c.expectedCode)

factors, err := models.FindFactorsByUser(ts.API.db, ts.TestUser)
factors, err := FindFactorsByUser(ts.API.db, ts.TestUser)
ts.Require().NoError(err)
addedFactor := factors[len(factors)-1]
require.False(ts.T(), addedFactor.IsVerified())
Expand Down Expand Up @@ -194,15 +197,15 @@ func (ts *MFATestSuite) TestMultipleEnrollsCleanupExpiredFactors() {
}

// All Factors except last factor should be expired
factors, err := models.FindFactorsByUser(ts.API.db, ts.TestUser)
factors, err := FindFactorsByUser(ts.API.db, ts.TestUser)
require.NoError(ts.T(), err)

// Make a challenge so last, unverified factor isn't deleted on next enroll (Factor 2)
_ = performChallengeFlow(ts, factors[len(factors)-1].ID, token)

// Enroll another Factor (Factor 3)
_ = performEnrollFlow(ts, token, "", models.TOTP, "https://issuer.com", http.StatusOK)
factors, err = models.FindFactorsByUser(ts.API.db, ts.TestUser)
factors, err = FindFactorsByUser(ts.API.db, ts.TestUser)
require.NoError(ts.T(), err)
require.Equal(ts.T(), 3, len(factors))
}
Expand Down Expand Up @@ -248,7 +251,7 @@ func (ts *MFATestSuite) TestMFAVerifyFactor() {
require.NoError(ts.T(), err)

sharedSecret := ts.TestOTPKey.Secret()
factors, err := models.FindFactorsByUser(ts.API.db, ts.TestUser)
factors, err := FindFactorsByUser(ts.API.db, ts.TestUser)
f := factors[0]
f.Secret = sharedSecret
require.NoError(ts.T(), err)
Expand Down Expand Up @@ -319,17 +322,17 @@ func (ts *MFATestSuite) TestUnenrollVerifiedFactor() {
for _, v := range cases {
ts.Run(v.desc, func() {
var buffer bytes.Buffer
if v.isAAL2 {
ts.TestSession.UpdateAssociatedAAL(ts.API.db, models.AAL2.String())
}

// Create Session to test behaviour which downgrades other sessions
factors, err := models.FindFactorsByUser(ts.API.db, ts.TestUser)
factors, err := FindFactorsByUser(ts.API.db, ts.TestUser)
require.NoError(ts.T(), err, "error finding factors")
f := factors[0]
f.Secret = ts.TestOTPKey.Secret()
require.NoError(ts.T(), f.UpdateStatus(ts.API.db, models.FactorStateVerified))
require.NoError(ts.T(), ts.API.db.Update(f), "Error updating new test factor")

if v.isAAL2 {
ts.TestSession.UpdateAALAndAssociatedFactor(ts.API.db, models.AAL2, &f.ID)
}
token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID)
w := ServeAuthenticatedRequest(ts, http.MethodDelete, fmt.Sprintf("/factors/%s", f.ID), token, buffer)
require.Equal(ts.T(), v.expectedHTTPCode, w.Code)
Expand Down Expand Up @@ -659,3 +662,15 @@ func cleanupHook(ts *MFATestSuite, hookName string) {
err := ts.API.db.RawQuery(cleanupHookSQL).Exec()
require.NoError(ts.T(), err)
}

// FindFactorsByUser returns all factors belonging to a user ordered by timestamp. Don't use this outside of tests.
func FindFactorsByUser(tx *storage.Connection, user *models.User) ([]*models.Factor, error) {
factors := []*models.Factor{}
if err := tx.Q().Where("user_id = ?", user.ID).Order("created_at asc").All(&factors); err != nil {
if errors.Cause(err) == sql.ErrNoRows {
return factors, nil
}
return nil, errors.Wrap(err, "Database error when finding MFA factors associated to user")
}
return factors, nil
}
9 changes: 3 additions & 6 deletions internal/api/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request)

func (a *API) generateAccessToken(ctx context.Context, tx *storage.Connection, user *models.User, sessionId *uuid.UUID, authenticationMethod models.AuthenticationMethod) (string, int64, error) {
config := a.config
aal, amr := models.AAL1.String(), []models.AMREntry{}
aal, amr := models.AAL1, []models.AMREntry{}
sid := ""
if sessionId != nil {
sid = sessionId.String()
Expand Down Expand Up @@ -324,7 +324,7 @@ func (a *API) generateAccessToken(ctx context.Context, tx *storage.Connection, u
UserMetaData: user.UserMetaData,
Role: user.Role,
SessionId: sid,
AuthenticatorAssuranceLevel: aal,
AuthenticatorAssuranceLevel: aal.String(),
AuthenticationMethodReference: amr,
IsAnonymous: user.IsAnonymous,
}
Expand Down Expand Up @@ -452,10 +452,7 @@ func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection,
return terr
}

if err := session.UpdateAssociatedFactor(tx, grantParams.FactorID); err != nil {
return err
}
if err := session.UpdateAssociatedAAL(tx, aal); err != nil {
if err := session.UpdateAALAndAssociatedFactor(tx, aal, grantParams.FactorID); err != nil {
return err
}

Expand Down
12 changes: 0 additions & 12 deletions internal/models/factor.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,18 +141,6 @@ func NewFactor(user *User, friendlyName string, factorType string, state FactorS
return factor
}

// FindFactorsByUser returns all factors belonging to a user ordered by timestamp
func FindFactorsByUser(tx *storage.Connection, user *User) ([]*Factor, error) {
factors := []*Factor{}
if err := tx.Q().Where("user_id = ?", user.ID).Order("created_at asc").All(&factors); err != nil {
if errors.Cause(err) == sql.ErrNoRows {
return factors, nil
}
return nil, errors.Wrap(err, "Database error when finding MFA factors associated to user")
}
return factors, nil
}

func FindFactorByFactorID(conn *storage.Connection, factorID uuid.UUID) (*Factor, error) {
var factor Factor
err := conn.Find(&factor, factorID)
Expand Down
17 changes: 7 additions & 10 deletions internal/models/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,21 +272,18 @@ func LogoutAllExceptMe(tx *storage.Connection, sessionId uuid.UUID, userID uuid.
return tx.RawQuery("DELETE FROM "+(&pop.Model{Value: Session{}}).TableName()+" WHERE id != ? AND user_id = ?", sessionId, userID).Exec()
}

func (s *Session) UpdateAssociatedFactor(tx *storage.Connection, factorID *uuid.UUID) error {
func (s *Session) UpdateAALAndAssociatedFactor(tx *storage.Connection, aal AuthenticatorAssuranceLevel, factorID *uuid.UUID) error {
s.FactorID = factorID
return tx.Update(s)
aalAsString := aal.String()
s.AAL = &aalAsString
return tx.UpdateOnly(s, "aal", "factor_id")
}

func (s *Session) UpdateAssociatedAAL(tx *storage.Connection, aal string) error {
s.AAL = &aal
return tx.Update(s)
}

func (s *Session) CalculateAALAndAMR(user *User) (aal string, amr []AMREntry, err error) {
amr, aal = []AMREntry{}, AAL1.String()
func (s *Session) CalculateAALAndAMR(user *User) (aal AuthenticatorAssuranceLevel, amr []AMREntry, err error) {
amr, aal = []AMREntry{}, AAL1
for _, claim := range s.AMRClaims {
if *claim.AuthenticationMethod == TOTPSignIn.String() {
aal = AAL2.String()
aal = AAL2
}
amr = append(amr, AMREntry{Method: claim.GetAuthenticationMethod(), Timestamp: claim.UpdatedAt.Unix()})
}
Expand Down
2 changes: 1 addition & 1 deletion internal/models/sessions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func (ts *SessionsTestSuite) TestCalculateAALAndAMR() {
aal, amr, err := session.CalculateAALAndAMR(u)
require.NoError(ts.T(), err)

require.Equal(ts.T(), AAL2.String(), aal)
require.Equal(ts.T(), AAL2, aal)
require.Equal(ts.T(), totalDistinctClaims, len(amr))

found := false
Expand Down

0 comments on commit 864809e

Please sign in to comment.