diff --git a/internal/api/admin_test.go b/internal/api/admin_test.go index 615bba6c70..c576594147 100644 --- a/internal/api/admin_test.go +++ b/internal/api/admin_test.go @@ -594,6 +594,11 @@ func (ts *AdminTestSuite) TestAdminUserSoftDeletion() { "provider": "email", } require.NoError(ts.T(), ts.API.db.Create(u)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.ConfirmationToken, models.ConfirmationToken)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.RecoveryToken, models.RecoveryToken)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.EmailChangeTokenCurrent, models.EmailChangeTokenCurrent)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.EmailChangeTokenNew, models.EmailChangeTokenNew)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetPhone(), u.PhoneChangeToken, models.PhoneChangeToken)) // create user identities _, err = ts.API.createNewIdentity(ts.API.db, u, "email", map[string]interface{}{ diff --git a/internal/api/external_test.go b/internal/api/external_test.go index ca82aede3c..09fdcc433b 100644 --- a/internal/api/external_test.go +++ b/internal/api/external_test.go @@ -56,6 +56,10 @@ func (ts *ExternalTestSuite) createUser(providerId string, email string, name st ts.Require().NoError(err, "Error making new user") ts.Require().NoError(ts.API.db.Create(u), "Error creating user") + if confirmationToken != "" { + ts.Require().NoError(models.CreateOneTimeToken(ts.API.db, u.ID, email, u.ConfirmationToken, models.ConfirmationToken), "Error creating one-time confirmation/invite token") + } + i, err := models.NewIdentity(u, "email", map[string]interface{}{ "sub": u.ID.String(), "email": email, diff --git a/internal/api/invite_test.go b/internal/api/invite_test.go index 1ced4caeb2..1d502adc86 100644 --- a/internal/api/invite_test.go +++ b/internal/api/invite_test.go @@ -211,6 +211,7 @@ func (ts *InviteTestSuite) TestVerifyInvite() { user.ConfirmationToken = crypto.GenerateTokenHash(c.email, c.requestBody["token"].(string)) require.NoError(ts.T(), err) require.NoError(ts.T(), ts.API.db.Create(user)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, user.ID, user.GetEmail(), user.ConfirmationToken, models.ConfirmationToken)) // Find test user _, err = models.FindUserByEmailAndAudience(ts.API.db, c.email, ts.Config.JWT.Aud) diff --git a/internal/api/resend_test.go b/internal/api/resend_test.go index 122415dd78..83c58c4e49 100644 --- a/internal/api/resend_test.go +++ b/internal/api/resend_test.go @@ -128,6 +128,8 @@ func (ts *ResendTestSuite) TestResendSuccess() { u.EmailChangeSentAt = &now u.EmailChangeTokenNew = "123456" require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.ConfirmationToken, models.ConfirmationToken)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.EmailChange, u.EmailChangeTokenNew, models.EmailChangeTokenNew)) phoneUser, err := models.NewUser("1234567890", "", "password", ts.Config.JWT.Aud, nil) require.NoError(ts.T(), err, "Error creating test user model") @@ -135,6 +137,7 @@ func (ts *ResendTestSuite) TestResendSuccess() { phoneUser.EmailChangeSentAt = &now phoneUser.EmailChangeTokenNew = "123456" require.NoError(ts.T(), ts.API.db.Create(phoneUser), "Error saving new test user") + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, phoneUser.ID, phoneUser.EmailChange, phoneUser.EmailChangeTokenNew, models.EmailChangeTokenNew)) emailUser, err := models.NewUser("", "bar@example.com", "password", ts.Config.JWT.Aud, nil) require.NoError(ts.T(), err, "Error creating test user model") @@ -142,6 +145,7 @@ func (ts *ResendTestSuite) TestResendSuccess() { phoneUser.PhoneChangeSentAt = &now phoneUser.PhoneChangeToken = "123456" require.NoError(ts.T(), ts.API.db.Create(emailUser), "Error saving new test user") + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, phoneUser.ID, phoneUser.PhoneChange, phoneUser.PhoneChangeToken, models.PhoneChangeToken)) cases := []struct { desc string diff --git a/internal/api/signup_test.go b/internal/api/signup_test.go index 36b8feb66a..3f47832616 100644 --- a/internal/api/signup_test.go +++ b/internal/api/signup_test.go @@ -4,13 +4,14 @@ import ( "bytes" "encoding/json" "fmt" - mail "github.com/supabase/auth/internal/mailer" "net/http" "net/http/httptest" "net/url" "testing" "time" + mail "github.com/supabase/auth/internal/mailer" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -127,6 +128,7 @@ func (ts *SignupTestSuite) TestVerifySignup() { user.ConfirmationSentAt = &now require.NoError(ts.T(), err) require.NoError(ts.T(), ts.API.db.Create(user)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, user.ID, user.GetEmail(), user.ConfirmationToken, models.ConfirmationToken)) // Find test user u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) diff --git a/internal/api/verify_test.go b/internal/api/verify_test.go index 1b4c64efd5..73386bebe6 100644 --- a/internal/api/verify_test.go +++ b/internal/api/verify_test.go @@ -21,26 +21,6 @@ import ( "github.com/supabase/auth/internal/models" ) -type VerifyVariant int - -const ( - VerifyWithoutOTT VerifyVariant = iota - VerifyWithOTT -) - -func (v VerifyVariant) String() string { - switch v { - case VerifyWithoutOTT: - return "WithoutOTT" - - case VerifyWithOTT: - return "WithOTT" - - default: - panic("VerifyVariant: unreachable code") - } -} - type VerifyTestSuite struct { suite.Suite API *API @@ -69,21 +49,6 @@ func (ts *VerifyTestSuite) SetupTest() { require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") } -func (ts *VerifyTestSuite) VerifyWithVariants(fn func(variant VerifyVariant)) { - variants := []VerifyVariant{ - VerifyWithoutOTT, - VerifyWithOTT, - } - - for _, v := range variants { - variant := v - - ts.Run(variant.String(), func() { - fn(variant) - }) - } -} - func (ts *VerifyTestSuite) TestVerifyPasswordRecovery() { // modify config so we don't hit rate limit from requesting recovery twice in 60s ts.Config.SMTP.MaxFrequency = 60 @@ -117,60 +82,54 @@ func (ts *VerifyTestSuite) TestVerifyPasswordRecovery() { }, } - ts.VerifyWithVariants(func(variant VerifyVariant) { - for _, c := range cases { - ts.Run(c.desc, func() { - // Reset user - u.EmailConfirmedAt = nil - require.NoError(ts.T(), ts.API.db.Update(u)) - require.NoError(ts.T(), models.ClearAllOneTimeTokensForUser(ts.API.db, u.ID)) - - // Request body - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) + for _, c := range cases { + ts.Run(c.desc, func() { + // Reset user + u.EmailConfirmedAt = nil + require.NoError(ts.T(), ts.API.db.Update(u)) + require.NoError(ts.T(), models.ClearAllOneTimeTokensForUser(ts.API.db, u.ID)) - // Setup request - req := httptest.NewRequest(http.MethodPost, "http://localhost/recover", &buffer) - req.Header.Set("Content-Type", "application/json") + // Request body + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) - // Setup response recorder - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), http.StatusOK, w.Code) + // Setup request + req := httptest.NewRequest(http.MethodPost, "http://localhost/recover", &buffer) + req.Header.Set("Content-Type", "application/json") - u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) - require.NoError(ts.T(), err) + // Setup response recorder + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusOK, w.Code) - assert.WithinDuration(ts.T(), time.Now(), *u.RecoverySentAt, 1*time.Second) - assert.False(ts.T(), u.IsConfirmed()) + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) - recoveryToken := u.RecoveryToken + assert.WithinDuration(ts.T(), time.Now(), *u.RecoverySentAt, 1*time.Second) + assert.False(ts.T(), u.IsConfirmed()) - if variant == VerifyWithoutOTT { - require.NoError(ts.T(), models.ClearAllOneTimeTokensForUser(ts.API.db, u.ID)) - } + recoveryToken := u.RecoveryToken - reqURL := fmt.Sprintf("http://localhost/verify?type=%s&token=%s", mail.RecoveryVerification, recoveryToken) - req = httptest.NewRequest(http.MethodGet, reqURL, nil) + reqURL := fmt.Sprintf("http://localhost/verify?type=%s&token=%s", mail.RecoveryVerification, recoveryToken) + req = httptest.NewRequest(http.MethodGet, reqURL, nil) - w = httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), http.StatusSeeOther, w.Code) + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusSeeOther, w.Code) - u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - assert.True(ts.T(), u.IsConfirmed()) + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + assert.True(ts.T(), u.IsConfirmed()) - if c.isPKCE { - rURL, _ := w.Result().Location() + if c.isPKCE { + rURL, _ := w.Result().Location() - f, err := url.ParseQuery(rURL.RawQuery) - require.NoError(ts.T(), err) - assert.NotEmpty(ts.T(), f.Get("code")) - } - }) - } - }) + f, err := url.ParseQuery(rURL.RawQuery) + require.NoError(ts.T(), err) + assert.NotEmpty(ts.T(), f.Get("code")) + } + }) + } } func (ts *VerifyTestSuite) TestVerifySecureEmailChange() { @@ -208,118 +167,112 @@ func (ts *VerifyTestSuite) TestVerifySecureEmailChange() { }, } - ts.VerifyWithVariants(func(variant VerifyVariant) { - for _, c := range cases { - u, err := models.FindUserByEmailAndAudience(ts.API.db, c.currentEmail, ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - - // reset user - u.EmailChangeSentAt = nil - u.EmailChangeTokenCurrent = "" - u.EmailChangeTokenNew = "" - require.NoError(ts.T(), ts.API.db.Update(u)) - require.NoError(ts.T(), models.ClearAllOneTimeTokensForUser(ts.API.db, u.ID)) - - // Request body - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) - - // Setup request - req := httptest.NewRequest(http.MethodPut, "http://localhost/user", &buffer) - req.Header.Set("Content-Type", "application/json") - - // Generate access token for request and a mock session - var token string - session, err := models.NewSession(u.ID, nil) - require.NoError(ts.T(), err) - require.NoError(ts.T(), ts.API.db.Create(session)) - - token, _, err = ts.API.generateAccessToken(req, ts.API.db, u, &session.ID, models.MagicLink) - require.NoError(ts.T(), err) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) - - // Setup response recorder - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), http.StatusOK, w.Code) - - u, err = models.FindUserByEmailAndAudience(ts.API.db, c.currentEmail, ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - - currentTokenHash := u.EmailChangeTokenCurrent - newTokenHash := u.EmailChangeTokenNew - - if variant == VerifyWithoutOTT { - require.NoError(ts.T(), models.ClearAllOneTimeTokensForUser(ts.API.db, u.ID)) - } - - u, err = models.FindUserByEmailAndAudience(ts.API.db, c.currentEmail, ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - - assert.WithinDuration(ts.T(), time.Now(), *u.EmailChangeSentAt, 1*time.Second) - assert.False(ts.T(), u.IsConfirmed()) - - // Verify new email - reqURL := fmt.Sprintf("http://localhost/verify?type=%s&token=%s", mail.EmailChangeVerification, newTokenHash) - req = httptest.NewRequest(http.MethodGet, reqURL, nil) - - w = httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - - require.Equal(ts.T(), http.StatusSeeOther, w.Code) - urlVal, err := url.Parse(w.Result().Header.Get("Location")) - ts.Require().NoError(err, "redirect url parse failed") - var v url.Values - if !c.isPKCE { - v, err = url.ParseQuery(urlVal.Fragment) - ts.Require().NoError(err) - ts.Require().NotEmpty(v.Get("message")) - } else if c.isPKCE { - v, err = url.ParseQuery(urlVal.RawQuery) - ts.Require().NoError(err) - ts.Require().NotEmpty(v.Get("message")) - - v, err = url.ParseQuery(urlVal.Fragment) - ts.Require().NoError(err) - ts.Require().NotEmpty(v.Get("message")) - } - - u, err = models.FindUserByEmailAndAudience(ts.API.db, c.currentEmail, ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - assert.Equal(ts.T(), singleConfirmation, u.EmailChangeConfirmStatus) - - // Verify old email - reqURL = fmt.Sprintf("http://localhost/verify?type=%s&token=%s", mail.EmailChangeVerification, currentTokenHash) - req = httptest.NewRequest(http.MethodGet, reqURL, nil) + for _, c := range cases { + u, err := models.FindUserByEmailAndAudience(ts.API.db, c.currentEmail, ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + // reset user + u.EmailChangeSentAt = nil + u.EmailChangeTokenCurrent = "" + u.EmailChangeTokenNew = "" + require.NoError(ts.T(), ts.API.db.Update(u)) + require.NoError(ts.T(), models.ClearAllOneTimeTokensForUser(ts.API.db, u.ID)) + + // Request body + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) + + // Setup request + req := httptest.NewRequest(http.MethodPut, "http://localhost/user", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Generate access token for request and a mock session + var token string + session, err := models.NewSession(u.ID, nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(session)) + + token, _, err = ts.API.generateAccessToken(req, ts.API.db, u, &session.ID, models.MagicLink) + require.NoError(ts.T(), err) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + // Setup response recorder + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusOK, w.Code) + + u, err = models.FindUserByEmailAndAudience(ts.API.db, c.currentEmail, ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + currentTokenHash := u.EmailChangeTokenCurrent + newTokenHash := u.EmailChangeTokenNew + + u, err = models.FindUserByEmailAndAudience(ts.API.db, c.currentEmail, ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + assert.WithinDuration(ts.T(), time.Now(), *u.EmailChangeSentAt, 1*time.Second) + assert.False(ts.T(), u.IsConfirmed()) + + // Verify new email + reqURL := fmt.Sprintf("http://localhost/verify?type=%s&token=%s", mail.EmailChangeVerification, newTokenHash) + req = httptest.NewRequest(http.MethodGet, reqURL, nil) + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusSeeOther, w.Code) + urlVal, err := url.Parse(w.Result().Header.Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + var v url.Values + if !c.isPKCE { + v, err = url.ParseQuery(urlVal.Fragment) + ts.Require().NoError(err) + ts.Require().NotEmpty(v.Get("message")) + } else if c.isPKCE { + v, err = url.ParseQuery(urlVal.RawQuery) + ts.Require().NoError(err) + ts.Require().NotEmpty(v.Get("message")) + + v, err = url.ParseQuery(urlVal.Fragment) + ts.Require().NoError(err) + ts.Require().NotEmpty(v.Get("message")) + } - w = httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - require.Equal(ts.T(), http.StatusSeeOther, w.Code) - - urlVal, err = url.Parse(w.Header().Get("Location")) - ts.Require().NoError(err, "redirect url parse failed") - if !c.isPKCE { - v, err = url.ParseQuery(urlVal.Fragment) - ts.Require().NoError(err) - ts.Require().NotEmpty(v.Get("access_token")) - ts.Require().NotEmpty(v.Get("expires_in")) - ts.Require().NotEmpty(v.Get("refresh_token")) - } else if c.isPKCE { - v, err = url.ParseQuery(urlVal.RawQuery) - ts.Require().NoError(err) - ts.Require().NotEmpty(v.Get("code")) - } + u, err = models.FindUserByEmailAndAudience(ts.API.db, c.currentEmail, ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + assert.Equal(ts.T(), singleConfirmation, u.EmailChangeConfirmStatus) + + // Verify old email + reqURL = fmt.Sprintf("http://localhost/verify?type=%s&token=%s", mail.EmailChangeVerification, currentTokenHash) + req = httptest.NewRequest(http.MethodGet, reqURL, nil) + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusSeeOther, w.Code) + + urlVal, err = url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + if !c.isPKCE { + v, err = url.ParseQuery(urlVal.Fragment) + ts.Require().NoError(err) + ts.Require().NotEmpty(v.Get("access_token")) + ts.Require().NotEmpty(v.Get("expires_in")) + ts.Require().NotEmpty(v.Get("refresh_token")) + } else if c.isPKCE { + v, err = url.ParseQuery(urlVal.RawQuery) + ts.Require().NoError(err) + ts.Require().NotEmpty(v.Get("code")) + } - // user's email should've been updated to newEmail - u, err = models.FindUserByEmailAndAudience(ts.API.db, c.newEmail, ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - require.Equal(ts.T(), zeroConfirmation, u.EmailChangeConfirmStatus) + // user's email should've been updated to newEmail + u, err = models.FindUserByEmailAndAudience(ts.API.db, c.newEmail, ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + require.Equal(ts.T(), zeroConfirmation, u.EmailChangeConfirmStatus) - // Reset confirmation status after each test - u.EmailConfirmedAt = nil - require.NoError(ts.T(), ts.API.db.Update(u)) - } - }) + // Reset confirmation status after each test + u.EmailConfirmedAt = nil + require.NoError(ts.T(), ts.API.db.Update(u)) + } } func (ts *VerifyTestSuite) TestExpiredConfirmationToken() { @@ -332,6 +285,7 @@ func (ts *VerifyTestSuite) TestExpiredConfirmationToken() { sentTime := time.Now().Add(-48 * time.Hour) u.ConfirmationSentAt = &sentTime require.NoError(ts.T(), ts.API.db.Update(u)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.ConfirmationToken, models.ConfirmationToken)) // Setup request reqURL := fmt.Sprintf("http://localhost/verify?type=%s&token=%s", mail.SignupVerification, u.ConfirmationToken) @@ -363,6 +317,8 @@ func (ts *VerifyTestSuite) TestInvalidOtp() { u.PhoneChangeToken = "123456" u.PhoneChangeSentAt = &sentTime require.NoError(ts.T(), ts.API.db.Update(u)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.ConfirmationToken, models.ConfirmationToken)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.PhoneChange, u.PhoneChangeToken, models.PhoneChangeToken)) type ResponseBody struct { Code int `json:"code"` @@ -685,6 +641,7 @@ func (ts *VerifyTestSuite) TestVerifySignupWithRedirectURLContainedPath() { sendTime := time.Now().Add(time.Hour) u.ConfirmationSentAt = &sendTime require.NoError(ts.T(), ts.API.db.Update(u)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.ConfirmationToken, models.ConfirmationToken)) reqURL := fmt.Sprintf("http://localhost/verify?type=%s&token=%s&redirect_to=%s", "signup", u.ConfirmationToken, redirectURL) req := httptest.NewRequest(http.MethodGet, reqURL, nil) @@ -705,13 +662,10 @@ func (ts *VerifyTestSuite) TestVerifySignupWithRedirectURLContainedPath() { func (ts *VerifyTestSuite) TestVerifyPKCEOTP() { u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) require.NoError(ts.T(), err) - u.ConfirmationToken = "pkce_confirmation_token" - u.RecoveryToken = "pkce_recovery_token" t := time.Now() u.ConfirmationSentAt = &t u.RecoverySentAt = &t u.EmailChangeSentAt = &t - require.NoError(ts.T(), ts.API.db.Update(u)) cases := []struct { @@ -720,10 +674,10 @@ func (ts *VerifyTestSuite) TestVerifyPKCEOTP() { authenticationMethod models.AuthenticationMethod }{ { - desc: "Verify banned user on signup", + desc: "Verify user on signup", payload: &VerifyParams{ Type: "signup", - Token: u.ConfirmationToken, + Token: "pkce_confirmation_token", }, authenticationMethod: models.EmailSignup, }, @@ -731,7 +685,7 @@ func (ts *VerifyTestSuite) TestVerifyPKCEOTP() { desc: "Verify magiclink", payload: &VerifyParams{ Type: "magiclink", - Token: u.RecoveryToken, + Token: "pkce_recovery_token", }, authenticationMethod: models.MagicLink, }, @@ -739,8 +693,16 @@ func (ts *VerifyTestSuite) TestVerifyPKCEOTP() { for _, c := range cases { ts.Run(c.desc, func() { var buffer bytes.Buffer + // since the test user is the same, the tokens are being cleared after each successful verification attempt + // so we create them on each run + if c.payload.Type == "signup" { + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), c.payload.Token, models.ConfirmationToken)) + } else if c.payload.Type == "magiclink" { + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), c.payload.Token, models.RecoveryToken)) + } + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.payload)) - codeChallenge := "codechallengecodechallengcodechallengcodechallengcodechallenge" + c.payload.Type + codeChallenge := "codechallengecodechallengcodechallengcodechallengcodechallenge" flowState := models.NewFlowState(c.authenticationMethod.String(), codeChallenge, models.SHA256, c.authenticationMethod, &u.ID) require.NoError(ts.T(), ts.API.db.Create(flowState)) @@ -780,6 +742,10 @@ func (ts *VerifyTestSuite) TestVerifyBannedUser() { t = time.Now().Add(24 * time.Hour) u.BannedUntil = &t require.NoError(ts.T(), ts.API.db.Update(u)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.ConfirmationToken, models.ConfirmationToken)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.RecoveryToken, models.RecoveryToken)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.EmailChangeTokenCurrent, models.EmailChangeTokenCurrent)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, u.GetEmail(), u.EmailChangeTokenNew, models.EmailChangeTokenNew)) cases := []struct { desc string @@ -980,48 +946,42 @@ func (ts *VerifyTestSuite) TestVerifyValidOtp() { }, } - ts.VerifyWithVariants(func(variant VerifyVariant) { - for _, caseItem := range cases { - c := caseItem - ts.Run(c.desc, func() { - // create user - require.NoError(ts.T(), models.ClearAllOneTimeTokensForUser(ts.API.db, u.ID)) - - u.ConfirmationSentAt = &c.sentTime - u.RecoverySentAt = &c.sentTime - u.EmailChangeSentAt = &c.sentTime - u.PhoneChangeSentAt = &c.sentTime - - u.ConfirmationToken = c.expected.tokenHash - u.RecoveryToken = c.expected.tokenHash - u.EmailChangeTokenNew = c.expected.tokenHash - u.PhoneChangeToken = c.expected.tokenHash - - require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, "relates_to not used", u.ConfirmationToken, models.ConfirmationToken)) - require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, "relates_to not used", u.RecoveryToken, models.RecoveryToken)) - require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, "relates_to not used", u.EmailChangeTokenNew, models.EmailChangeTokenNew)) - require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, "relates_to not used", u.PhoneChangeToken, models.PhoneChangeToken)) - - if variant == VerifyWithoutOTT { - require.NoError(ts.T(), models.ClearAllOneTimeTokensForUser(ts.API.db, u.ID)) - } - - require.NoError(ts.T(), ts.API.db.Update(u)) - - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) - - // Setup request - req := httptest.NewRequest(http.MethodPost, "http://localhost/verify", &buffer) - req.Header.Set("Content-Type", "application/json") - - // Setup response recorder - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), c.expected.code, w.Code) - }) - } - }) + for _, caseItem := range cases { + c := caseItem + ts.Run(c.desc, func() { + // create user + require.NoError(ts.T(), models.ClearAllOneTimeTokensForUser(ts.API.db, u.ID)) + + u.ConfirmationSentAt = &c.sentTime + u.RecoverySentAt = &c.sentTime + u.EmailChangeSentAt = &c.sentTime + u.PhoneChangeSentAt = &c.sentTime + + u.ConfirmationToken = c.expected.tokenHash + u.RecoveryToken = c.expected.tokenHash + u.EmailChangeTokenNew = c.expected.tokenHash + u.PhoneChangeToken = c.expected.tokenHash + + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, "relates_to not used", u.ConfirmationToken, models.ConfirmationToken)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, "relates_to not used", u.RecoveryToken, models.RecoveryToken)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, "relates_to not used", u.EmailChangeTokenNew, models.EmailChangeTokenNew)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, "relates_to not used", u.PhoneChangeToken, models.PhoneChangeToken)) + + require.NoError(ts.T(), ts.API.db.Update(u)) + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) + + // Setup request + req := httptest.NewRequest(http.MethodPost, "http://localhost/verify", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Setup response recorder + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), c.expected.code, w.Code) + }) + } } func (ts *VerifyTestSuite) TestSecureEmailChangeWithTokenHash() { @@ -1066,47 +1026,42 @@ func (ts *VerifyTestSuite) TestSecureEmailChangeWithTokenHash() { }, } - ts.VerifyWithVariants(func(variant VerifyVariant) { - for _, c := range cases { - ts.Run(c.desc, func() { - // Set the corresponding email change tokens - u.EmailChangeTokenCurrent = currentEmailChangeToken - u.EmailChangeTokenNew = newEmailChangeToken - require.NoError(ts.T(), models.ClearAllOneTimeTokensForUser(ts.API.db, u.ID)) - - if variant == VerifyWithOTT { - require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, "relates_to not used", currentEmailChangeToken, models.EmailChangeTokenCurrent)) - require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, "relates_to not used", newEmailChangeToken, models.EmailChangeTokenNew)) - } - - currentTime := time.Now() - u.EmailChangeSentAt = ¤tTime - require.NoError(ts.T(), ts.API.db.Update(u)) - - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.firstVerificationBody)) - - // Setup request - req := httptest.NewRequest(http.MethodPost, "http://localhost/verify", &buffer) - req.Header.Set("Content-Type", "application/json") - - // Setup response recorder - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.secondVerificationBody)) - - // Setup second request - req = httptest.NewRequest(http.MethodPost, "http://localhost/verify", &buffer) - req.Header.Set("Content-Type", "application/json") - - // Setup second response recorder - w = httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), c.expectedStatus, w.Code) - }) + for _, c := range cases { + ts.Run(c.desc, func() { + // Set the corresponding email change tokens + u.EmailChangeTokenCurrent = currentEmailChangeToken + u.EmailChangeTokenNew = newEmailChangeToken + require.NoError(ts.T(), models.ClearAllOneTimeTokensForUser(ts.API.db, u.ID)) - } - }) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, "relates_to not used", currentEmailChangeToken, models.EmailChangeTokenCurrent)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, "relates_to not used", newEmailChangeToken, models.EmailChangeTokenNew)) + + currentTime := time.Now() + u.EmailChangeSentAt = ¤tTime + require.NoError(ts.T(), ts.API.db.Update(u)) + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.firstVerificationBody)) + + // Setup request + req := httptest.NewRequest(http.MethodPost, "http://localhost/verify", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Setup response recorder + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.secondVerificationBody)) + + // Setup second request + req = httptest.NewRequest(http.MethodPost, "http://localhost/verify", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Setup second response recorder + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), c.expectedStatus, w.Code) + }) + } } func (ts *VerifyTestSuite) TestPrepRedirectURL() { diff --git a/internal/models/connection.go b/internal/models/connection.go index 95fe874156..80acccc57e 100644 --- a/internal/models/connection.go +++ b/internal/models/connection.go @@ -48,6 +48,7 @@ func TruncateAll(conn *storage.Connection) error { (&pop.Model{Value: SAMLProvider{}}).TableName(), (&pop.Model{Value: SAMLRelayState{}}).TableName(), (&pop.Model{Value: FlowState{}}).TableName(), + (&pop.Model{Value: OneTimeToken{}}).TableName(), } for _, tableName := range tables { diff --git a/internal/models/one_time_token.go b/internal/models/one_time_token.go index 18417af846..c5a2049024 100644 --- a/internal/models/one_time_token.go +++ b/internal/models/one_time_token.go @@ -178,99 +178,57 @@ func FindOneTimeToken(tx *storage.Connection, tokenHash string, tokenTypes ...On // FindUserByConfirmationToken finds users with the matching confirmation token. func FindUserByConfirmationOrRecoveryToken(tx *storage.Connection, token string) (*User, error) { ott, err := FindOneTimeToken(tx, token, ConfirmationToken, RecoveryToken) - if err != nil && !IsNotFoundError(err) { + if err != nil { return nil, err } - if ott == nil { - user, err := findUser(tx, "(confirmation_token = ? or recovery_token = ?) and is_sso_user = false", token, token) - if err != nil { - if IsNotFoundError(err) { - return nil, ConfirmationOrRecoveryTokenNotFoundError{} - } else { - return nil, err - } - } - - return user, nil - } - return FindUserByID(tx, ott.UserID) } // FindUserByConfirmationToken finds users with the matching confirmation token. func FindUserByConfirmationToken(tx *storage.Connection, token string) (*User, error) { ott, err := FindOneTimeToken(tx, token, ConfirmationToken) - if err != nil && !IsNotFoundError(err) { + if err != nil { return nil, err } - if ott == nil { - user, err := findUser(tx, "confirmation_token = ? and is_sso_user = false", token) - if err != nil { - if IsNotFoundError(err) { - return nil, ConfirmationTokenNotFoundError{} - } else { - return nil, err - } - } - - return user, nil - } - return FindUserByID(tx, ott.UserID) } // FindUserByRecoveryToken finds a user with the matching recovery token. func FindUserByRecoveryToken(tx *storage.Connection, token string) (*User, error) { ott, err := FindOneTimeToken(tx, token, RecoveryToken) - if err != nil && !IsNotFoundError(err) { + if err != nil { return nil, err } - if ott == nil { - return findUser(tx, "recovery_token = ? and is_sso_user = false", token) - } - return FindUserByID(tx, ott.UserID) } // FindUserByEmailChangeToken finds a user with the matching email change token. func FindUserByEmailChangeToken(tx *storage.Connection, token string) (*User, error) { ott, err := FindOneTimeToken(tx, token, EmailChangeTokenCurrent, EmailChangeTokenNew) - if err != nil && !IsNotFoundError(err) { + if err != nil { return nil, err } - if ott == nil { - return findUser(tx, "is_sso_user = false and (email_change_token_current = ? or email_change_token_new = ?)", token, token) - } - return FindUserByID(tx, ott.UserID) } // FindUserByEmailChangeCurrentAndAudience finds a user with the matching email change and audience. func FindUserByEmailChangeCurrentAndAudience(tx *storage.Connection, email, token, aud string) (*User, error) { ott, err := FindOneTimeToken(tx, token, EmailChangeTokenCurrent) - if err != nil && !IsNotFoundError(err) { + if err != nil { return nil, err } if ott == nil { ott, err = FindOneTimeToken(tx, "pkce_"+token, EmailChangeTokenCurrent) - if err != nil && !IsNotFoundError(err) { + if err != nil { return nil, err } } - if ott == nil { - return findUser( - tx, - "instance_id = ? and LOWER(email) = ? and aud = ? and is_sso_user = false and (email_change_token_current = 'pkce_' || ? or email_change_token_current = ?)", - uuid.Nil, strings.ToLower(email), aud, token, token, - ) - } - user, err := FindUserByID(tx, ott.UserID) if err != nil { return nil, err @@ -297,14 +255,6 @@ func FindUserByEmailChangeNewAndAudience(tx *storage.Connection, email, token, a } } - if ott == nil { - return findUser( - tx, - "instance_id = ? and LOWER(email_change) = ? and aud = ? and is_sso_user = false and (email_change_token_new = 'pkce_' || ? or email_change_token_new = ?)", - uuid.Nil, strings.ToLower(email), aud, token, token, - ) - } - user, err := FindUserByID(tx, ott.UserID) if err != nil { return nil, err diff --git a/internal/models/user_test.go b/internal/models/user_test.go index 3b3438608c..6c915f6afb 100644 --- a/internal/models/user_test.go +++ b/internal/models/user_test.go @@ -83,8 +83,10 @@ func (ts *UserTestSuite) TestUpdateUserMetadata() { func (ts *UserTestSuite) TestFindUserByConfirmationToken() { u := ts.createUser() + tokenHash := "test_confirmation_token" + require.NoError(ts.T(), CreateOneTimeToken(ts.db, u.ID, "relates_to not used", tokenHash, ConfirmationToken)) - n, err := FindUserByConfirmationToken(ts.db, u.ConfirmationToken) + n, err := FindUserByConfirmationToken(ts.db, tokenHash) require.NoError(ts.T(), err) require.Equal(ts.T(), u.ID, n.ID) } @@ -136,14 +138,11 @@ func (ts *UserTestSuite) TestFindUserByID() { func (ts *UserTestSuite) TestFindUserByRecoveryToken() { u := ts.createUser() - u.RecoveryToken = "asdf" + tokenHash := "test_recovery_token" + require.NoError(ts.T(), CreateOneTimeToken(ts.db, u.ID, "relates_to not used", tokenHash, RecoveryToken)) - err := ts.db.Update(u) + n, err := FindUserByRecoveryToken(ts.db, tokenHash) require.NoError(ts.T(), err) - - n, err := FindUserByRecoveryToken(ts.db, u.RecoveryToken) - require.NoError(ts.T(), err) - require.Equal(ts.T(), u.ID, n.ID) }