Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: clean up test setup in MFA tests #1452

Merged
merged 3 commits into from
Mar 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 44 additions & 60 deletions internal/api/mfa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,15 @@ import (

type MFATestSuite struct {
suite.Suite
API *API
Config *conf.GlobalConfiguration
TestDomain string
TestEmail string
TestOTPKey *otp.Key
TestPassword string
TestUser *models.User
TestSession *models.Session
API *API
Config *conf.GlobalConfiguration
TestDomain string
TestEmail string
TestOTPKey *otp.Key
TestPassword string
TestUser *models.User
TestSession *models.Session
TestSecondarySession *models.Session
}

func TestMFA(t *testing.T) {
Expand Down Expand Up @@ -70,6 +71,12 @@ func (ts *MFATestSuite) SetupTest() {
ts.TestUser = u
ts.TestSession = s

secondarySession, err := models.NewSession(ts.TestUser.ID, &f.ID)
require.NoError(ts.T(), err, "Error creating test session")
require.NoError(ts.T(), ts.API.db.Create(secondarySession), "Error saving test session")

ts.TestSecondarySession = secondarySession

// Generate TOTP related settings
testDomain := strings.Split(ts.TestEmail, "@")[1]
ts.TestDomain = testDomain
Expand All @@ -83,7 +90,7 @@ func (ts *MFATestSuite) SetupTest() {

}

func (ts *MFATestSuite) generateToken(user *models.User, sessionId *uuid.UUID) string {
func (ts *MFATestSuite) generateAAL1Token(user *models.User, sessionId *uuid.UUID) string {
token, _, err := ts.API.generateAccessToken(context.Background(), ts.API.db, user, sessionId, models.TOTPSignIn)
require.NoError(ts.T(), err, "Error generating access token")
return token
Expand All @@ -93,8 +100,7 @@ func (ts *MFATestSuite) TestEnrollFactor() {
testFriendlyName := "bob"
alternativeFriendlyName := "john"

token, _, err := ts.API.generateAccessToken(context.Background(), ts.API.db, ts.TestUser, nil, models.TOTPSignIn)
require.NoError(ts.T(), err)
token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID)

var cases = []struct {
desc string
Expand Down Expand Up @@ -134,15 +140,14 @@ func (ts *MFATestSuite) TestEnrollFactor() {
}
for _, c := range cases {
ts.Run(c.desc, func() {

w := performEnrollFlow(ts, token, c.friendlyName, c.factorType, c.issuer, c.expectedCode)

factors, err := models.FindFactorsByUser(ts.API.db, ts.TestUser)
ts.Require().NoError(err)
latestFactor := factors[len(factors)-1]
require.False(ts.T(), latestFactor.IsVerified())
addedFactor := factors[len(factors)-1]
require.False(ts.T(), addedFactor.IsVerified())
if c.friendlyName != "" && c.expectedCode == http.StatusOK {
require.Equal(ts.T(), c.friendlyName, latestFactor.FriendlyName)
require.Equal(ts.T(), c.friendlyName, addedFactor.FriendlyName)
}
if w.Code == http.StatusOK {
enrollResp := EnrollFactorResponse{}
Expand All @@ -158,13 +163,13 @@ func (ts *MFATestSuite) TestEnrollFactor() {

func (ts *MFATestSuite) TestDuplicateEnrollsReturnExpectedMessage() {
friendlyName := "mary"
token, _, err := ts.API.generateAccessToken(context.Background(), ts.API.db, ts.TestUser, nil, models.TOTPSignIn)
require.NoError(ts.T(), err)
_ = performEnrollFlow(ts, token, friendlyName, models.TOTP, "https://issuer.com", http.StatusOK)
response := performEnrollFlow(ts, token, friendlyName, models.TOTP, "https://issuer.com", http.StatusBadRequest)
issuer := "https://issuer.com"
token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID)
_ = performEnrollFlow(ts, token, friendlyName, models.TOTP, issuer, http.StatusOK)
response := performEnrollFlow(ts, token, friendlyName, models.TOTP, issuer, http.StatusBadRequest)

var errorResponse HTTPError
err = json.NewDecoder(response.Body).Decode(&errorResponse)
err := json.NewDecoder(response.Body).Decode(&errorResponse)
require.NoError(ts.T(), err)

// Convert the response body to a string and check for the expected error message
Expand All @@ -175,7 +180,7 @@ func (ts *MFATestSuite) TestDuplicateEnrollsReturnExpectedMessage() {

func (ts *MFATestSuite) TestChallengeFactor() {
f := ts.TestUser.Factors[0]
token := ts.generateToken(ts.TestUser, nil)
token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID)
w := performChallengeFlow(ts, f.ID, token)
require.Equal(ts.T(), http.StatusOK, w.Code)
}
Expand Down Expand Up @@ -209,7 +214,6 @@ func (ts *MFATestSuite) TestMFAVerifyFactor() {
for _, v := range cases {
ts.Run(v.desc, func() {
// Authenticate users and set secret

var buffer bytes.Buffer
r, err := models.GrantAuthenticatedUser(ts.API.db, ts.TestUser, models.GrantParams{})
require.NoError(ts.T(), err)
Expand All @@ -220,12 +224,7 @@ func (ts *MFATestSuite) TestMFAVerifyFactor() {
require.NoError(ts.T(), err)
require.NoError(ts.T(), ts.API.db.Update(f), "Error updating new test factor")

// Create session to be invalidated
secondarySession, err := models.NewSession(ts.TestUser.ID, &f.ID)
require.NoError(ts.T(), err, "Error creating test session")
require.NoError(ts.T(), ts.API.db.Create(secondarySession), "Error saving test session")

token := ts.generateToken(ts.TestUser, r.SessionId)
token := ts.generateAAL1Token(ts.TestUser, r.SessionId)

w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodPost, fmt.Sprintf("/factors/%s/verify", f.ID), &buffer)
Expand Down Expand Up @@ -258,7 +257,7 @@ func (ts *MFATestSuite) TestMFAVerifyFactor() {

if v.expectedHTTPCode == http.StatusOK {
// Ensure alternate session has been deleted
_, err = models.FindSessionByID(ts.API.db, secondarySession.ID, false)
_, err = models.FindSessionByID(ts.API.db, ts.TestSecondarySession.ID, false)
require.EqualError(ts.T(), err, models.SessionNotFoundError{}.Error())
}
if !v.validChallenge {
Expand All @@ -271,7 +270,6 @@ func (ts *MFATestSuite) TestMFAVerifyFactor() {
}

func (ts *MFATestSuite) TestUnenrollVerifiedFactor() {

cases := []struct {
desc string
isAAL2 bool
Expand All @@ -290,29 +288,20 @@ func (ts *MFATestSuite) TestUnenrollVerifiedFactor() {
}
for _, v := range cases {
ts.Run(v.desc, func() {
// Create User
var buffer bytes.Buffer
if v.isAAL2 {
ts.TestSession.UpdateAssociatedAAL(ts.API.db, models.AAL2.String())
}
var secondarySession *models.Session

// Create Session to test behaviour which downgrades other sessions
factors, err := models.FindFactorsByUser(ts.API.db, ts.TestUser)
require.NoError(ts.T(), err, "error finding factors")
f := factors[0]
secondarySession, err = models.NewSession(ts.TestUser.ID, &f.ID)
require.NoError(ts.T(), err, "Error creating test session")
require.NoError(ts.T(), ts.API.db.Create(secondarySession), "Error saving test session")

sharedSecret := ts.TestOTPKey.Secret()
f.Secret = sharedSecret
err = f.UpdateStatus(ts.API.db, models.FactorStateVerified)
require.NoError(ts.T(), err)
f.Secret = ts.TestOTPKey.Secret()
require.NoError(ts.T(), f.UpdateStatus(ts.API.db, models.FactorStateVerified))
require.NoError(ts.T(), ts.API.db.Update(f), "Error updating new test factor")

var buffer bytes.Buffer

token := ts.generateToken(ts.TestUser, &ts.TestSession.ID)
token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID)

w := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodDelete, fmt.Sprintf("/factors/%s/", f.ID), &buffer)
Expand All @@ -323,7 +312,7 @@ func (ts *MFATestSuite) TestUnenrollVerifiedFactor() {
if v.expectedHTTPCode == http.StatusOK {
_, err = models.FindFactorByFactorID(ts.API.db, f.ID)
require.EqualError(ts.T(), err, models.FactorNotFoundError{}.Error())
session, _ := models.FindSessionByID(ts.API.db, secondarySession.ID, false)
session, _ := models.FindSessionByID(ts.API.db, ts.TestSecondarySession.ID, false)
require.Equal(ts.T(), models.AAL1.String(), session.GetAAL())
require.Nil(ts.T(), session.FactorID)

Expand All @@ -334,19 +323,11 @@ func (ts *MFATestSuite) TestUnenrollVerifiedFactor() {
}

func (ts *MFATestSuite) TestUnenrollUnverifiedFactor() {
var secondarySession *models.Session
f := ts.TestUser.Factors[0]
secondarySession, err := models.NewSession(ts.TestUser.ID, &f.ID)
require.NoError(ts.T(), err, "Error creating test session")
require.NoError(ts.T(), ts.API.db.Create(secondarySession), "Error saving test session")

sharedSecret := ts.TestOTPKey.Secret()
f.Secret = sharedSecret

var buffer bytes.Buffer
f := ts.TestUser.Factors[0]
f.Secret = ts.TestOTPKey.Secret()

token := ts.generateToken(ts.TestUser, &ts.TestSession.ID)
require.NoError(ts.T(), err)
token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID)
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"factor_id": f.ID,
}))
Expand All @@ -356,21 +337,22 @@ func (ts *MFATestSuite) TestUnenrollUnverifiedFactor() {
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
ts.API.handler.ServeHTTP(w, req)
require.Equal(ts.T(), http.StatusOK, w.Code)
_, err = models.FindFactorByFactorID(ts.API.db, f.ID)

_, err := models.FindFactorByFactorID(ts.API.db, f.ID)
require.EqualError(ts.T(), err, models.FactorNotFoundError{}.Error())
session, _ := models.FindSessionByID(ts.API.db, secondarySession.ID, false)
session, _ := models.FindSessionByID(ts.API.db, ts.TestSecondarySession.ID, false)
require.Equal(ts.T(), models.AAL1.String(), session.GetAAL())
require.Nil(ts.T(), session.FactorID)

}

// Integration Tests
func (ts *MFATestSuite) TestSessionsMaintainAALOnRefresh() {
ts.Config.Security.RefreshTokenRotationEnabled = true
resp := performTestSignupAndVerify(ts, ts.TestEmail, ts.TestPassword, true /* <- requireStatusOK */)
accessTokenResp := &AccessTokenResponse{}
require.NoError(ts.T(), json.NewDecoder(resp.Body).Decode(&accessTokenResp))

ts.Config.Security.RefreshTokenRotationEnabled = true
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"refresh_token": accessTokenResp.RefreshToken,
Expand All @@ -394,11 +376,11 @@ func (ts *MFATestSuite) TestSessionsMaintainAALOnRefresh() {

// Performing MFA Verification followed by a sign in should return an AAL1 session and an AAL2 session
func (ts *MFATestSuite) TestMFAFollowedByPasswordSignIn() {
ts.Config.Security.RefreshTokenRotationEnabled = true
resp := performTestSignupAndVerify(ts, ts.TestEmail, ts.TestPassword, true /* <- requireStatusOK */)
accessTokenResp := &AccessTokenResponse{}
require.NoError(ts.T(), json.NewDecoder(resp.Body).Decode(&accessTokenResp))

ts.Config.Security.RefreshTokenRotationEnabled = true
var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"email": ts.TestEmail,
Expand All @@ -414,15 +396,18 @@ func (ts *MFATestSuite) TestMFAFollowedByPasswordSignIn() {
require.NoError(ts.T(), json.NewDecoder(w.Body).Decode(&data))
ctx, err := ts.API.parseJWTClaims(data.Token, req)
require.NoError(ts.T(), err)

ctx, err = ts.API.maybeLoadUserOrSession(ctx)
require.NoError(ts.T(), err)

require.Equal(ts.T(), models.AAL1.String(), getSession(ctx).GetAAL())
session, err := models.FindSessionByUserID(ts.API.db, accessTokenResp.User.ID)
require.NoError(ts.T(), err)
require.True(ts.T(), session.IsAAL2())
}

func signUp(ts *MFATestSuite, email, password string) (signUpResp AccessTokenResponse) {
ts.API.config.Mailer.Autoconfirm = true
var buffer bytes.Buffer

require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
Expand All @@ -433,7 +418,6 @@ func signUp(ts *MFATestSuite, email, password string) (signUpResp AccessTokenRes
// Setup request
req := httptest.NewRequest(http.MethodPost, "http://localhost/signup", &buffer)
req.Header.Set("Content-Type", "application/json")
ts.API.config.Mailer.Autoconfirm = true
w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
require.Equal(ts.T(), http.StatusOK, w.Code)
Expand Down
37 changes: 22 additions & 15 deletions internal/models/sessions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,45 +53,52 @@ func (ts *SessionsTestSuite) TestFindBySessionIDWithForUpdate() {
require.Equal(ts.T(), session.ID, found.ID)
}

func (ts *SessionsTestSuite) AddClaimAndReloadSession(session *Session, claim AuthenticationMethod) *Session {
err := AddClaimToSession(ts.db, session.ID, claim)
require.NoError(ts.T(), err)
session, err = FindSessionByID(ts.db, session.ID, false)
require.NoError(ts.T(), err)
return session
}

func (ts *SessionsTestSuite) TestCalculateAALAndAMR() {
totalDistinctClaims := 2
totalDistinctClaims := 3
u, err := FindUserByEmailAndAudience(ts.db, "test@example.com", ts.Config.JWT.Aud)
require.NoError(ts.T(), err)
session, err := NewSession(u.ID, nil)
require.NoError(ts.T(), err)
require.NoError(ts.T(), ts.db.Create(session))

err = AddClaimToSession(ts.db, session.ID, PasswordGrant)
require.NoError(ts.T(), err)
session = ts.AddClaimAndReloadSession(session, PasswordGrant)

firstClaimAddedTime := time.Now()
err = AddClaimToSession(ts.db, session.ID, TOTPSignIn)
require.NoError(ts.T(), err)
session, err = FindSessionByID(ts.db, session.ID, false)
require.NoError(ts.T(), err)
session = ts.AddClaimAndReloadSession(session, TOTPSignIn)

aal, amr, err := session.CalculateAALAndAMR(u)
_, _, err = session.CalculateAALAndAMR(u)
require.NoError(ts.T(), err)
require.Equal(ts.T(), AAL2.String(), aal)
require.Equal(ts.T(), totalDistinctClaims, len(amr))

err = AddClaimToSession(ts.db, session.ID, TOTPSignIn)
require.NoError(ts.T(), err)
session = ts.AddClaimAndReloadSession(session, TOTPSignIn)

session, err = FindSessionByID(ts.db, session.ID, false)
require.NoError(ts.T(), err)
session = ts.AddClaimAndReloadSession(session, SSOSAML)

aal, amr, err = session.CalculateAALAndAMR(u)
aal, amr, err := session.CalculateAALAndAMR(u)
require.NoError(ts.T(), err)

require.Equal(ts.T(), AAL2.String(), aal)
require.Equal(ts.T(), totalDistinctClaims, len(amr))

found := false
for _, claim := range session.AMRClaims {
if claim.GetAuthenticationMethod() == TOTPSignIn.String() {
require.True(ts.T(), firstClaimAddedTime.Before(claim.UpdatedAt))
found = true
}
}

for _, claim := range amr {
if claim.Method == SSOSAML.String() {
require.NotNil(ts.T(), claim.Provider)
}
}
require.True(ts.T(), found)
}
Loading