Skip to content

Commit

Permalink
feat: calculate aal without transaction (#1437)
Browse files Browse the repository at this point in the history
## What kind of change does this PR introduce?

First of a few refactoring PRs:
- Change `NewSession` to take in a `UserID` and `FactorID` 
- Load identities outside of AAL and AMR calculation so
`CalculateAALAndAMR` doesn't touch the database.

This will help with ensuring that the transaction doesn't run for too
long and aid with the Hooks Implementation.

---------

Co-authored-by: joel <joel@joels-MacBook-Pro.local>
  • Loading branch information
J0 and joel authored Feb 19, 2024
1 parent 7e10d45 commit 8dae661
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 64 deletions.
3 changes: 1 addition & 2 deletions internal/api/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,8 @@ func (ts *AuthTestSuite) TestMaybeLoadUserOrSession() {
u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud)
require.NoError(ts.T(), err)

s, err := models.NewSession()
s, err := models.NewSession(u.ID, nil)
require.NoError(ts.T(), err)
s.UserID = u.ID
require.NoError(ts.T(), ts.API.db.Create(s))

require.NoError(ts.T(), ts.API.db.Load(s))
Expand Down
16 changes: 4 additions & 12 deletions internal/api/mfa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,8 @@ func (ts *MFATestSuite) SetupTest() {
f := models.NewFactor(u, "test_factor", models.TOTP, models.FactorStateUnverified, "secretkey")
require.NoError(ts.T(), ts.API.db.Create(f), "Error saving new test factor")
// Create corresponding session
s, err := models.NewSession()
s, err := models.NewSession(u.ID, &f.ID)
require.NoError(ts.T(), err, "Error creating test session")
s.UserID = u.ID
s.FactorID = &f.ID
require.NoError(ts.T(), ts.API.db.Create(s), "Error saving test session")

u, err = models.FindUserByEmailAndAudience(ts.API.db, ts.TestEmail, ts.Config.JWT.Aud)
Expand Down Expand Up @@ -223,10 +221,8 @@ func (ts *MFATestSuite) TestMFAVerifyFactor() {
require.NoError(ts.T(), ts.API.db.Update(f), "Error updating new test factor")

// Create session to be invalidated
secondarySession, err := models.NewSession()
secondarySession, err := models.NewSession(ts.TestUser.ID, &f.ID)
require.NoError(ts.T(), err, "Error creating test session")
secondarySession.UserID = ts.TestUser.ID
secondarySession.FactorID = &f.ID
require.NoError(ts.T(), ts.API.db.Create(secondarySession), "Error saving test session")

token := ts.generateToken(ts.TestUser, r.SessionId)
Expand Down Expand Up @@ -304,10 +300,8 @@ func (ts *MFATestSuite) TestUnenrollVerifiedFactor() {
factors, err := models.FindFactorsByUser(ts.API.db, ts.TestUser)
require.NoError(ts.T(), err, "error finding factors")
f := factors[0]
secondarySession, err = models.NewSession()
secondarySession, err = models.NewSession(ts.TestUser.ID, &f.ID)
require.NoError(ts.T(), err, "Error creating test session")
secondarySession.UserID = ts.TestUser.ID
secondarySession.FactorID = &f.ID
require.NoError(ts.T(), ts.API.db.Create(secondarySession), "Error saving test session")

sharedSecret := ts.TestOTPKey.Secret()
Expand Down Expand Up @@ -342,10 +336,8 @@ func (ts *MFATestSuite) TestUnenrollVerifiedFactor() {
func (ts *MFATestSuite) TestUnenrollUnverifiedFactor() {
var secondarySession *models.Session
f := ts.TestUser.Factors[0]
secondarySession, err := models.NewSession()
secondarySession, err := models.NewSession(ts.TestUser.ID, &f.ID)
require.NoError(ts.T(), err, "Error creating test session")
secondarySession.UserID = ts.TestUser.ID
secondarySession.FactorID = &f.ID
require.NoError(ts.T(), ts.API.db.Create(secondarySession), "Error saving test session")

sharedSecret := ts.TestOTPKey.Secret()
Expand Down
7 changes: 5 additions & 2 deletions internal/api/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ func (a *API) generateAccessToken(ctx context.Context, tx *storage.Connection, u
if terr != nil {
return "", 0, terr
}
aal, amr, terr = session.CalculateAALAndAMR(tx)
aal, amr, terr = session.CalculateAALAndAMR(user)
if terr != nil {
return "", 0, terr
}
Expand Down Expand Up @@ -447,12 +447,15 @@ func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection,
if terr != nil {
return terr
}
if err := tx.Load(user, "Identities"); err != nil {
return err
}
// Swap to ensure current token is the latest one
refreshToken, terr = models.GrantRefreshTokenSwap(r, tx, user, currentToken)
if terr != nil {
return terr
}
aal, _, terr := session.CalculateAALAndAMR(tx)
aal, _, terr := session.CalculateAALAndAMR(user)
if terr != nil {
return terr
}
Expand Down
8 changes: 1 addition & 7 deletions internal/models/refresh_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,17 +127,11 @@ func createRefreshToken(tx *storage.Connection, user *User, oldToken *RefreshTok
}

if token.SessionId == nil {
session, err := NewSession()
session, err := NewSession(user.ID, params.FactorID)
if err != nil {
return nil, errors.Wrap(err, "error instantiating new session object")
}

session.UserID = user.ID

if params.FactorID != nil {
session.FactorID = params.FactorID
}

if params.SessionNotAfter != nil {
session.NotAfter = params.SessionNotAfter
}
Expand Down
54 changes: 19 additions & 35 deletions internal/models/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,14 +154,16 @@ func (s *Session) DetermineTag(tags []string) string {
return tags[0]
}

func NewSession() (*Session, error) {
func NewSession(userID uuid.UUID, factorID *uuid.UUID) (*Session, error) {
id := uuid.Must(uuid.NewV4())

defaultAAL := AAL1.String()

session := &Session{
ID: id,
AAL: &defaultAAL,
ID: id,
AAL: &defaultAAL,
UserID: userID,
FactorID: factorID,
}

return session, nil
Expand Down Expand Up @@ -280,7 +282,7 @@ func (s *Session) UpdateAssociatedAAL(tx *storage.Connection, aal string) error
return tx.Update(s)
}

func (s *Session) CalculateAALAndAMR(tx *storage.Connection) (aal string, amr []AMREntry, err error) {
func (s *Session) CalculateAALAndAMR(user *User) (aal string, amr []AMREntry, err error) {
amr, aal = []AMREntry{}, AAL1.String()
for _, claim := range s.AMRClaims {
if *claim.AuthenticationMethod == TOTPSignIn.String() {
Expand All @@ -290,40 +292,22 @@ func (s *Session) CalculateAALAndAMR(tx *storage.Connection) (aal string, amr []
}

// makes sure that the AMR claims are always ordered most-recent first

// sort in ascending order
sort.Sort(sortAMREntries{
Array: amr,
})

// now reverse for descending order
_ = sort.Reverse(sortAMREntries{
sort.Sort(sort.Reverse(sortAMREntries{
Array: amr,
})

lastIndex := len(amr) - 1

if lastIndex > -1 && amr[lastIndex].Method == SSOSAML.String() {
// initial AMR claim is from sso/saml, we need to add information
// about the provider that was used for the authentication
identities, err := FindIdentitiesByUserID(tx, s.UserID)
if err != nil {
return aal, amr, err
}
}))

if len(identities) == 1 {
identity := identities[0]

if strings.HasPrefix(identity.Provider, "sso:") {
amr[lastIndex].Provider = strings.TrimPrefix(identity.Provider, "sso:")
}
}

// otherwise we can't identify that this user account has only
// one SSO identity, so we are not encoding the provider at
// this time
if len(amr) > 0 && amr[len(amr)-1].Method == SSOSAML.String() {
return aal, amr, nil
}

// initial AMR claim is from sso/saml, we need to add information
// about the provider that was used for the authentication
identities := user.Identities
if len(identities) == 1 && identities[0].IsForSSOProvider() {
amr[len(amr)-1].Provider = strings.TrimPrefix(identities[0].Provider, "sso:")
}
// otherwise we can't identify that this user account has only
// one SSO identity, so we are not encoding the provider at
// this time
return aal, amr, nil
}

Expand Down
10 changes: 4 additions & 6 deletions internal/models/sessions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,8 @@ func TestSession(t *testing.T) {
func (ts *SessionsTestSuite) TestFindBySessionIDWithForUpdate() {
u, err := FindUserByEmailAndAudience(ts.db, "test@example.com", ts.Config.JWT.Aud)
require.NoError(ts.T(), err)
session, err := NewSession()
session, err := NewSession(u.ID, nil)
require.NoError(ts.T(), err)
session.UserID = u.ID
require.NoError(ts.T(), ts.db.Create(session))

found, err := FindSessionByID(ts.db, session.ID, true)
Expand All @@ -58,9 +57,8 @@ func (ts *SessionsTestSuite) TestCalculateAALAndAMR() {
totalDistinctClaims := 2
u, err := FindUserByEmailAndAudience(ts.db, "test@example.com", ts.Config.JWT.Aud)
require.NoError(ts.T(), err)
session, err := NewSession()
session, err := NewSession(u.ID, nil)
require.NoError(ts.T(), err)
session.UserID = u.ID
require.NoError(ts.T(), ts.db.Create(session))

err = AddClaimToSession(ts.db, session.ID, PasswordGrant)
Expand All @@ -72,7 +70,7 @@ func (ts *SessionsTestSuite) TestCalculateAALAndAMR() {
session, err = FindSessionByID(ts.db, session.ID, false)
require.NoError(ts.T(), err)

aal, amr, err := session.CalculateAALAndAMR(ts.db)
aal, amr, err := session.CalculateAALAndAMR(u)
require.NoError(ts.T(), err)
require.Equal(ts.T(), AAL2.String(), aal)
require.Equal(ts.T(), totalDistinctClaims, len(amr))
Expand All @@ -83,7 +81,7 @@ func (ts *SessionsTestSuite) TestCalculateAALAndAMR() {
session, err = FindSessionByID(ts.db, session.ID, false)
require.NoError(ts.T(), err)

aal, amr, err = session.CalculateAALAndAMR(ts.db)
aal, amr, err = session.CalculateAALAndAMR(u)
require.NoError(ts.T(), err)

require.Equal(ts.T(), AAL2.String(), aal)
Expand Down

0 comments on commit 8dae661

Please sign in to comment.