Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: refactor mfa and aal update methods #1503

Merged
merged 3 commits into from
Mar 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions internal/api/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -545,11 +545,7 @@ func (a *API) adminUserDeleteFactor(w http.ResponseWriter, r *http.Request) erro
func (a *API) adminUserGetFactors(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
user := getUser(ctx)
factors, terr := models.FindFactorsByUser(a.db, user)
if terr != nil {
return terr
}
return sendJSON(w, http.StatusOK, factors)
return sendJSON(w, http.StatusOK, user.Factors)
}

// adminUserUpdate updates a single factor object
Expand Down
33 changes: 24 additions & 9 deletions internal/api/mfa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@ import (

"github.com/gofrs/uuid"

"database/sql"
"github.com/pkg/errors"
"github.com/pquerna/otp"
"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/storage"
"github.com/supabase/auth/internal/utilities"

"github.com/jackc/pgx/v4"
Expand Down Expand Up @@ -143,7 +146,7 @@ func (ts *MFATestSuite) TestEnrollFactor() {
ts.Run(c.desc, func() {
w := performEnrollFlow(ts, token, c.friendlyName, c.factorType, c.issuer, c.expectedCode)

factors, err := models.FindFactorsByUser(ts.API.db, ts.TestUser)
factors, err := FindFactorsByUser(ts.API.db, ts.TestUser)
ts.Require().NoError(err)
addedFactor := factors[len(factors)-1]
require.False(ts.T(), addedFactor.IsVerified())
Expand Down Expand Up @@ -194,15 +197,15 @@ func (ts *MFATestSuite) TestMultipleEnrollsCleanupExpiredFactors() {
}

// All Factors except last factor should be expired
factors, err := models.FindFactorsByUser(ts.API.db, ts.TestUser)
factors, err := FindFactorsByUser(ts.API.db, ts.TestUser)
require.NoError(ts.T(), err)

// Make a challenge so last, unverified factor isn't deleted on next enroll (Factor 2)
_ = performChallengeFlow(ts, factors[len(factors)-1].ID, token)

// Enroll another Factor (Factor 3)
_ = performEnrollFlow(ts, token, "", models.TOTP, "https://issuer.com", http.StatusOK)
factors, err = models.FindFactorsByUser(ts.API.db, ts.TestUser)
factors, err = FindFactorsByUser(ts.API.db, ts.TestUser)
require.NoError(ts.T(), err)
require.Equal(ts.T(), 3, len(factors))
}
Expand Down Expand Up @@ -248,7 +251,7 @@ func (ts *MFATestSuite) TestMFAVerifyFactor() {
require.NoError(ts.T(), err)

sharedSecret := ts.TestOTPKey.Secret()
factors, err := models.FindFactorsByUser(ts.API.db, ts.TestUser)
factors, err := FindFactorsByUser(ts.API.db, ts.TestUser)
f := factors[0]
f.Secret = sharedSecret
require.NoError(ts.T(), err)
Expand Down Expand Up @@ -319,17 +322,17 @@ func (ts *MFATestSuite) TestUnenrollVerifiedFactor() {
for _, v := range cases {
ts.Run(v.desc, func() {
var buffer bytes.Buffer
if v.isAAL2 {
ts.TestSession.UpdateAssociatedAAL(ts.API.db, models.AAL2.String())
}

// Create Session to test behaviour which downgrades other sessions
factors, err := models.FindFactorsByUser(ts.API.db, ts.TestUser)
factors, err := FindFactorsByUser(ts.API.db, ts.TestUser)
require.NoError(ts.T(), err, "error finding factors")
f := factors[0]
f.Secret = ts.TestOTPKey.Secret()
require.NoError(ts.T(), f.UpdateStatus(ts.API.db, models.FactorStateVerified))
require.NoError(ts.T(), ts.API.db.Update(f), "Error updating new test factor")

if v.isAAL2 {
ts.TestSession.UpdateAALAndAssociatedFactor(ts.API.db, models.AAL2, &f.ID)
}
token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID)
w := ServeAuthenticatedRequest(ts, http.MethodDelete, fmt.Sprintf("/factors/%s", f.ID), token, buffer)
require.Equal(ts.T(), v.expectedHTTPCode, w.Code)
Expand Down Expand Up @@ -659,3 +662,15 @@ func cleanupHook(ts *MFATestSuite, hookName string) {
err := ts.API.db.RawQuery(cleanupHookSQL).Exec()
require.NoError(ts.T(), err)
}

// FindFactorsByUser returns all factors belonging to a user ordered by timestamp. Don't use this outside of tests.
func FindFactorsByUser(tx *storage.Connection, user *models.User) ([]*models.Factor, error) {
factors := []*models.Factor{}
if err := tx.Q().Where("user_id = ?", user.ID).Order("created_at asc").All(&factors); err != nil {
if errors.Cause(err) == sql.ErrNoRows {
return factors, nil
}
return nil, errors.Wrap(err, "Database error when finding MFA factors associated to user")
}
return factors, nil
}
9 changes: 3 additions & 6 deletions internal/api/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ func (a *API) PKCE(ctx context.Context, w http.ResponseWriter, r *http.Request)

func (a *API) generateAccessToken(ctx context.Context, tx *storage.Connection, user *models.User, sessionId *uuid.UUID, authenticationMethod models.AuthenticationMethod) (string, int64, error) {
config := a.config
aal, amr := models.AAL1.String(), []models.AMREntry{}
aal, amr := models.AAL1, []models.AMREntry{}
sid := ""
if sessionId != nil {
sid = sessionId.String()
Expand Down Expand Up @@ -324,7 +324,7 @@ func (a *API) generateAccessToken(ctx context.Context, tx *storage.Connection, u
UserMetaData: user.UserMetaData,
Role: user.Role,
SessionId: sid,
AuthenticatorAssuranceLevel: aal,
AuthenticatorAssuranceLevel: aal.String(),
AuthenticationMethodReference: amr,
IsAnonymous: user.IsAnonymous,
}
Expand Down Expand Up @@ -452,10 +452,7 @@ func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection,
return terr
}

if err := session.UpdateAssociatedFactor(tx, grantParams.FactorID); err != nil {
return err
}
if err := session.UpdateAssociatedAAL(tx, aal); err != nil {
if err := session.UpdateAALAndAssociatedFactor(tx, aal, grantParams.FactorID); err != nil {
return err
}

Expand Down
12 changes: 0 additions & 12 deletions internal/models/factor.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,18 +141,6 @@ func NewFactor(user *User, friendlyName string, factorType string, state FactorS
return factor
}

// FindFactorsByUser returns all factors belonging to a user ordered by timestamp
func FindFactorsByUser(tx *storage.Connection, user *User) ([]*Factor, error) {
factors := []*Factor{}
if err := tx.Q().Where("user_id = ?", user.ID).Order("created_at asc").All(&factors); err != nil {
if errors.Cause(err) == sql.ErrNoRows {
return factors, nil
}
return nil, errors.Wrap(err, "Database error when finding MFA factors associated to user")
}
return factors, nil
}

func FindFactorByFactorID(conn *storage.Connection, factorID uuid.UUID) (*Factor, error) {
var factor Factor
err := conn.Find(&factor, factorID)
Expand Down
17 changes: 7 additions & 10 deletions internal/models/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,21 +272,18 @@ func LogoutAllExceptMe(tx *storage.Connection, sessionId uuid.UUID, userID uuid.
return tx.RawQuery("DELETE FROM "+(&pop.Model{Value: Session{}}).TableName()+" WHERE id != ? AND user_id = ?", sessionId, userID).Exec()
}

func (s *Session) UpdateAssociatedFactor(tx *storage.Connection, factorID *uuid.UUID) error {
func (s *Session) UpdateAALAndAssociatedFactor(tx *storage.Connection, aal AuthenticatorAssuranceLevel, factorID *uuid.UUID) error {
s.FactorID = factorID
return tx.Update(s)
aalAsString := aal.String()
s.AAL = &aalAsString
return tx.UpdateOnly(s, "aal", "factor_id")
}

func (s *Session) UpdateAssociatedAAL(tx *storage.Connection, aal string) error {
s.AAL = &aal
return tx.Update(s)
}

func (s *Session) CalculateAALAndAMR(user *User) (aal string, amr []AMREntry, err error) {
amr, aal = []AMREntry{}, AAL1.String()
func (s *Session) CalculateAALAndAMR(user *User) (aal AuthenticatorAssuranceLevel, amr []AMREntry, err error) {
amr, aal = []AMREntry{}, AAL1
for _, claim := range s.AMRClaims {
if *claim.AuthenticationMethod == TOTPSignIn.String() {
aal = AAL2.String()
aal = AAL2
}
amr = append(amr, AMREntry{Method: claim.GetAuthenticationMethod(), Timestamp: claim.UpdatedAt.Unix()})
}
Expand Down
2 changes: 1 addition & 1 deletion internal/models/sessions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func (ts *SessionsTestSuite) TestCalculateAALAndAMR() {
aal, amr, err := session.CalculateAALAndAMR(u)
require.NoError(ts.T(), err)

require.Equal(ts.T(), AAL2.String(), aal)
require.Equal(ts.T(), AAL2, aal)
require.Equal(ts.T(), totalDistinctClaims, len(amr))

found := false
Expand Down
Loading