diff --git a/internal/api/errorcodes.go b/internal/api/errorcodes.go index ce36bb65cb..b400f41f05 100644 --- a/internal/api/errorcodes.go +++ b/internal/api/errorcodes.go @@ -81,4 +81,5 @@ const ( ErrorCodeMFAPhoneVerifyDisabled ErrorCode = "mfa_phone_verify_not_enabled" ErrorCodeMFATOTPEnrollDisabled ErrorCode = "mfa_totp_enroll_not_enabled" ErrorCodeMFATOTPVerifyDisabled ErrorCode = "mfa_totp_verify_not_enabled" + ErrorCodeVerifiedFactorExists ErrorCode = "mfa_verified_factor_exists" ) diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 21a2fbc665..f62a418d4b 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -89,13 +89,37 @@ func (a *API) enrollPhoneFactor(w http.ResponseWriter, r *http.Request, params * if err := models.DeleteExpiredFactors(db, config.MFA.FactorExpiryDuration); err != nil { return err } + var factorsToDelete []models.Factor + for _, factor := range user.Factors { + switch { + case factor.FriendlyName == params.FriendlyName: + return unprocessableEntityError( + ErrorCodeMFAFactorNameConflict, + fmt.Sprintf("A factor with the friendly name %q for this user already exists", factor.FriendlyName), + ) + + case factor.IsPhoneFactor(): + if factor.Phone.String() == phone { + if factor.IsVerified() { + return unprocessableEntityError( + ErrorCodeVerifiedFactorExists, + "A verified phone factor already exists, unenroll the existing factor to continue", + ) + } else if factor.IsUnverified() { + factorsToDelete = append(factorsToDelete, factor) + } - for _, factor := range factors { - if factor.IsVerified() { - numVerifiedFactors += 1 + } + + case factor.IsVerified(): + numVerifiedFactors++ } } + if err := db.Destroy(&factorsToDelete); err != nil { + return internalServerError("Database error deleting unverified phone factors").WithInternalError(err) + } + if factorCount >= int(config.MFA.MaxEnrolledFactors) { return unprocessableEntityError(ErrorCodeTooManyEnrolledMFAFactors, "Maximum number of verified factors reached, unenroll to continue") } @@ -110,12 +134,7 @@ func (a *API) enrollPhoneFactor(w http.ResponseWriter, r *http.Request, params * factor := models.NewPhoneFactor(user, phone, params.FriendlyName) err = db.Transaction(func(tx *storage.Connection) error { if terr := tx.Create(factor); terr != nil { - pgErr := utilities.NewPostgresError(terr) - if pgErr.IsUniqueConstraintViolated() { - return unprocessableEntityError(ErrorCodeMFAFactorNameConflict, fmt.Sprintf("A factor with the friendly name %q for this user likely already exists", factor.FriendlyName)) - } return terr - } if terr := models.NewAuditLogEntry(r, tx, user, models.EnrollFactorAction, r.RemoteAddr, map[string]interface{}{ "factor_id": factor.ID, @@ -132,7 +151,7 @@ func (a *API) enrollPhoneFactor(w http.ResponseWriter, r *http.Request, params * ID: factor.ID, Type: models.Phone, FriendlyName: factor.FriendlyName, - Phone: string(factor.Phone), + Phone: params.Phone, }) } @@ -323,7 +342,7 @@ func (a *API) challengePhoneFactor(w http.ResponseWriter, r *http.Request) error return internalServerError("Failed to get SMS provider").WithInternalError(err) } // We omit messageID for now, can consider reinstating if there are requests. - if _, err = smsProvider.SendMessage(string(factor.Phone), message, channel, otp); err != nil { + if _, err = smsProvider.SendMessage(factor.Phone.String(), message, channel, otp); err != nil { return internalServerError("error sending message").WithInternalError(err) } } @@ -417,6 +436,7 @@ func (a *API) verifyTOTPFactor(w http.ResponseWriter, r *http.Request, params *V return internalServerError("Database error finding Challenge").WithInternalError(err) } + // Ambiguous so as not to leak whether there is a verified challenge if challenge.VerifiedAt != nil || challenge.IPAddress != currentIP { return unprocessableEntityError(ErrorCodeMFAIPAddressMismatch, "Challenge and verify IP addresses mismatch") } @@ -485,6 +505,7 @@ func (a *API) verifyTOTPFactor(w http.ResponseWriter, r *http.Request, params *V if terr = models.NewAuditLogEntry(r, tx, user, models.VerifyFactorAction, r.RemoteAddr, map[string]interface{}{ "factor_id": factor.ID, "challenge_id": challenge.ID, + "factor_type": factor.FactorType, }); terr != nil { return terr } @@ -524,7 +545,7 @@ func (a *API) verifyTOTPFactor(w http.ResponseWriter, r *http.Request, params *V if terr = models.InvalidateSessionsWithAALLessThan(tx, user.ID, models.AAL2.String()); terr != nil { return internalServerError("Failed to update sessions. %s", terr) } - if terr = models.DeleteUnverifiedFactors(tx, user); terr != nil { + if terr = models.DeleteUnverifiedFactors(tx, user, factor.FactorType); terr != nil { return internalServerError("Error removing unverified factors. %s", terr) } return nil @@ -643,7 +664,7 @@ func (a *API) verifyPhoneFactor(w http.ResponseWriter, r *http.Request, params * if terr = models.InvalidateSessionsWithAALLessThan(tx, user.ID, models.AAL2.String()); terr != nil { return internalServerError("Failed to update sessions. %s", terr) } - if terr = models.DeleteUnverifiedFactors(tx, user); terr != nil { + if terr = models.DeleteUnverifiedFactors(tx, user, factor.FactorType); terr != nil { return internalServerError("Error removing unverified factors. %s", terr) } return nil diff --git a/internal/api/mfa_test.go b/internal/api/mfa_test.go index d88eebb918..71691d8477 100644 --- a/internal/api/mfa_test.go +++ b/internal/api/mfa_test.go @@ -198,6 +198,87 @@ func (ts *MFATestSuite) TestEnrollFactor() { } } +func (ts *MFATestSuite) TestDuplicateEnrollPhoneFactor() { + testPhoneNumber := "+12345677889" + altPhoneNumber := "+987412444444" + friendlyName := "phone_factor" + altFriendlyName := "alt_phone_factor" + token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID) + + var cases = []struct { + desc string + earlierFactorName string + laterFactorName string + phone string + secondPhone string + expectedCode int + expectedNumberOfFactors int + }{ + { + desc: "Phone: Only the latest factor should persist when enrolling two unverified phone factors with the same number", + earlierFactorName: friendlyName, + laterFactorName: altFriendlyName, + phone: testPhoneNumber, + secondPhone: testPhoneNumber, + expectedNumberOfFactors: 1, + }, + + { + desc: "Phone: Both factors should persist when enrolling two different unverified numbers", + earlierFactorName: friendlyName, + laterFactorName: altFriendlyName, + phone: testPhoneNumber, + secondPhone: altPhoneNumber, + expectedNumberOfFactors: 2, + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + // Delete all test factors to start from clean slate + require.NoError(ts.T(), ts.API.db.Destroy(ts.TestUser.Factors)) + _ = performEnrollFlow(ts, token, c.earlierFactorName, models.Phone, ts.TestDomain, c.phone, http.StatusOK) + + w := performEnrollFlow(ts, token, c.laterFactorName, models.Phone, ts.TestDomain, c.secondPhone, http.StatusOK) + enrollResp := EnrollFactorResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&enrollResp)) + + laterFactor, err := models.FindFactorByFactorID(ts.API.db, enrollResp.ID) + require.NoError(ts.T(), err) + require.False(ts.T(), laterFactor.IsVerified()) + + require.NoError(ts.T(), ts.API.db.Eager("Factors").Find(ts.TestUser, ts.TestUser.ID)) + require.Equal(ts.T(), len(ts.TestUser.Factors), c.expectedNumberOfFactors) + + }) + } +} + +func (ts *MFATestSuite) TestDuplicateEnrollPhoneFactorWithVerified() { + testPhoneNumber := "+12345677889" + friendlyName := "phone_factor" + altFriendlyName := "alt_phone_factor" + token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID) + + ts.Run("Phone: Enrolling a factor with the same number as an existing verified phone factor should result in an error", func() { + require.NoError(ts.T(), ts.API.db.Destroy(ts.TestUser.Factors)) + + // Setup verified factor + w := performEnrollFlow(ts, token, friendlyName, models.Phone, ts.TestDomain, testPhoneNumber, http.StatusOK) + enrollResp := EnrollFactorResponse{} + require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&enrollResp)) + firstFactor, err := models.FindFactorByFactorID(ts.API.db, enrollResp.ID) + require.NoError(ts.T(), err) + require.NoError(ts.T(), firstFactor.UpdateStatus(ts.API.db, models.FactorStateVerified)) + + expectedStatusCode := http.StatusUnprocessableEntity + _ = performEnrollFlow(ts, token, altFriendlyName, models.Phone, ts.TestDomain, testPhoneNumber, expectedStatusCode) + + require.NoError(ts.T(), ts.API.db.Eager("Factors").Find(ts.TestUser, ts.TestUser.ID)) + require.Equal(ts.T(), len(ts.TestUser.Factors), 1) + }) +} + func (ts *MFATestSuite) TestDuplicateTOTPEnrollsReturnExpectedMessage() { friendlyName := "mary" issuer := "https://issuer.com" diff --git a/internal/models/factor.go b/internal/models/factor.go index 06a672583a..14e483b53a 100644 --- a/internal/models/factor.go +++ b/internal/models/factor.go @@ -117,7 +117,8 @@ func ParseAuthenticationMethod(authMethod string) (AuthenticationMethod, error) } type Factor struct { - ID uuid.UUID `json:"id" db:"id"` + ID uuid.UUID `json:"id" db:"id"` + // TODO: Consider removing this nested user field. We don't use it. User User `json:"-" belongs_to:"user"` UserID uuid.UUID `json:"-" db:"user_id"` CreatedAt time.Time `json:"created_at" db:"created_at"` @@ -196,8 +197,8 @@ func FindFactorByFactorID(conn *storage.Connection, factorID uuid.UUID) (*Factor return &factor, nil } -func DeleteUnverifiedFactors(tx *storage.Connection, user *User) error { - if err := tx.RawQuery("DELETE FROM "+(&pop.Model{Value: Factor{}}).TableName()+" WHERE user_id = ? and status = ?", user.ID, FactorStateUnverified.String()).Exec(); err != nil { +func DeleteUnverifiedFactors(tx *storage.Connection, user *User, factorType string) error { + if err := tx.RawQuery("DELETE FROM "+(&pop.Model{Value: Factor{}}).TableName()+" WHERE user_id = ? and status = ? and factor_type = ?", user.ID, FactorStateUnverified.String(), factorType).Exec(); err != nil { return err } @@ -263,6 +264,14 @@ func (f *Factor) IsVerified() bool { return f.Status == FactorStateVerified.String() } +func (f *Factor) IsUnverified() bool { + return f.Status == FactorStateUnverified.String() +} + +func (f *Factor) IsPhoneFactor() bool { + return f.FactorType == Phone +} + func (f *Factor) FindChallengeByID(conn *storage.Connection, challengeID uuid.UUID) (*Challenge, error) { var challenge Challenge err := conn.Q().Where("id = ? and factor_id = ?", challengeID, f.ID).First(&challenge)