From 46d4d44911bbaa42559d9fb6e2d1d23e4331bbff Mon Sep 17 00:00:00 2001 From: Stojan Dimitrovski Date: Sun, 28 Apr 2024 11:38:10 +0200 Subject: [PATCH] feat: refactor one-time tokens for performance --- internal/api/mail.go | 70 ++- internal/api/phone.go | 23 +- internal/api/verify.go | 21 +- internal/api/verify_test.go | 517 +++++++++++------- internal/models/errors.go | 2 + internal/models/one_time_token.go | 330 +++++++++++ internal/models/user.go | 121 ++-- ...427152123_add_one_time_tokens_table.up.sql | 29 + 8 files changed, 828 insertions(+), 285 deletions(-) create mode 100644 internal/models/one_time_token.go create mode 100644 migrations/20240427152123_add_one_time_tokens_table.up.sql diff --git a/internal/api/mail.go b/internal/api/mail.go index 6f0b618d60..4a3988b3ff 100644 --- a/internal/api/mail.go +++ b/internal/api/mail.go @@ -1,12 +1,13 @@ package api import ( - "github.com/supabase/auth/internal/hooks" - mail "github.com/supabase/auth/internal/mailer" "net/http" "strings" "time" + "github.com/supabase/auth/internal/hooks" + mail "github.com/supabase/auth/internal/mailer" + "github.com/badoux/checkmail" "github.com/fatih/structs" "github.com/pkg/errors" @@ -123,6 +124,13 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { terr = tx.UpdateOnly(user, "recovery_token", "recovery_sent_at") if terr != nil { terr = errors.Wrap(terr, "Database error updating user for recovery") + return terr + } + + terr = models.CreateOneTimeToken(tx, user.ID, user.GetEmail(), user.RecoveryToken, models.RecoveryToken) + if terr != nil { + terr = errors.Wrap(terr, "Database error reating recovery token in admin") + return terr } case mail.InviteVerification: if user != nil { @@ -170,6 +178,12 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { terr = tx.UpdateOnly(user, "confirmation_token", "confirmation_sent_at", "invited_at") if terr != nil { terr = errors.Wrap(terr, "Database error updating user for invite") + return terr + } + terr = models.CreateOneTimeToken(tx, user.ID, user.GetEmail(), user.ConfirmationToken, models.ConfirmationToken) + if terr != nil { + terr = errors.Wrap(terr, "Database error reating confirmation token for invite in admin") + return terr } case mail.SignupVerification: if user != nil { @@ -202,6 +216,12 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { terr = tx.UpdateOnly(user, "confirmation_token", "confirmation_sent_at") if terr != nil { terr = errors.Wrap(terr, "Database error updating user for confirmation") + return terr + } + terr = models.CreateOneTimeToken(tx, user.ID, user.GetEmail(), user.ConfirmationToken, models.ConfirmationToken) + if terr != nil { + terr = errors.Wrap(terr, "Database error reating confirmation token for signup in admin") + return terr } case mail.EmailChangeCurrentVerification, mail.EmailChangeNewVerification: if !config.Mailer.SecureEmailChangeEnabled && params.Type == "email_change_current" { @@ -228,6 +248,17 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { terr = tx.UpdateOnly(user, "email_change_token_current", "email_change_token_new", "email_change", "email_change_sent_at", "email_change_confirm_status") if terr != nil { terr = errors.Wrap(terr, "Database error updating user for email change") + return terr + } + terr = models.CreateOneTimeToken(tx, user.ID, user.GetEmail(), user.EmailChangeTokenCurrent, models.EmailChangeTokenCurrent) + if terr != nil { + terr = errors.Wrap(terr, "Database error creating email change token current in admin") + return terr + } + terr = models.CreateOneTimeToken(tx, user.ID, user.EmailChange, user.EmailChangeTokenNew, models.EmailChangeTokenNew) + if terr != nil { + terr = errors.Wrap(terr, "Database error creating email change token new in admin") + return terr } default: return badRequestError(ErrorCodeValidationFailed, "Invalid email action link type requested: %v", params.Type) @@ -290,6 +321,11 @@ func (a *API) sendConfirmation(r *http.Request, tx *storage.Connection, u *model return errors.Wrap(err, "Database error updating user for confirmation") } + err = models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.ConfirmationToken, models.ConfirmationToken) + if err != nil { + return errors.Wrap(err, "Database error creating confirmation token") + } + return nil } @@ -317,6 +353,11 @@ func (a *API) sendInvite(r *http.Request, tx *storage.Connection, u *models.User return errors.Wrap(err, "Database error updating user for invite") } + err = models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.ConfirmationToken, models.ConfirmationToken) + if err != nil { + return errors.Wrap(err, "Database error creating confirmation token for invite") + } + return nil } @@ -349,6 +390,11 @@ func (a *API) sendPasswordRecovery(r *http.Request, tx *storage.Connection, u *m return errors.Wrap(err, "Database error updating user for recovery") } + err = models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.RecoveryToken, models.RecoveryToken) + if err != nil { + return errors.Wrap(err, "Database error creating recovery token") + } + return nil } @@ -381,6 +427,11 @@ func (a *API) sendReauthenticationOtp(r *http.Request, tx *storage.Connection, u return errors.Wrap(err, "Database error updating user for reauthentication") } + err = models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.ReauthenticationToken, models.ReauthenticationToken) + if err != nil { + return errors.Wrap(err, "Database error creating reauthentication token") + } + return nil } @@ -416,6 +467,11 @@ func (a *API) sendMagicLink(r *http.Request, tx *storage.Connection, u *models.U return errors.Wrap(err, "Database error updating user for recovery") } + err = models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.RecoveryToken, models.RecoveryToken) + if err != nil { + return errors.Wrap(err, "Database error creating recovery token") + } + return nil } @@ -469,6 +525,16 @@ func (a *API) sendEmailChange(r *http.Request, tx *storage.Connection, u *models return errors.Wrap(err, "Database error updating user for email change") } + err = models.CreateOneTimeToken(tx, u.ID, u.GetEmail(), u.EmailChangeTokenCurrent, models.EmailChangeTokenCurrent) + if err != nil { + return errors.Wrap(err, "Database error creating email change token current") + } + + err = models.CreateOneTimeToken(tx, u.ID, u.EmailChange, u.EmailChangeTokenNew, models.EmailChangeTokenNew) + if err != nil { + return errors.Wrap(err, "Database error creating email change token new") + } + return nil } diff --git a/internal/api/phone.go b/internal/api/phone.go index 2144749635..df83b18329 100644 --- a/internal/api/phone.go +++ b/internal/api/phone.go @@ -95,6 +95,7 @@ func (a *API) sendPhoneConfirmation(ctx context.Context, r *http.Request, tx *st if err != nil { return "", err } + if config.Hook.SendSMS.Enabled { input := hooks.SendSMSInput{ User: user, @@ -108,7 +109,6 @@ func (a *API) sendPhoneConfirmation(ctx context.Context, r *http.Request, tx *st return "", err } } else { - messageID, err = smsProvider.SendMessage(phone, message, channel, otp) if err != nil { return messageID, err @@ -127,7 +127,26 @@ func (a *API) sendPhoneConfirmation(ctx context.Context, r *http.Request, tx *st user.ReauthenticationSentAt = &now } - return messageID, errors.Wrap(tx.UpdateOnly(user, includeFields...), "Database error updating user for confirmation") + if err := tx.UpdateOnly(user, includeFields...); err != nil { + return messageID, errors.Wrap(err, "Database error updating user for phone") + } + + switch otpType { + case phoneConfirmationOtp: + if err := models.CreateOneTimeToken(tx, user.ID, user.GetPhone(), user.ConfirmationToken, models.ConfirmationToken); err != nil { + return messageID, errors.Wrap(err, "Database error creating confirmation token for phone") + } + case phoneChangeVerification: + if err := models.CreateOneTimeToken(tx, user.ID, user.PhoneChange, user.PhoneChangeToken, models.PhoneChangeToken); err != nil { + return messageID, errors.Wrap(err, "Database error creating phone change token") + } + case phoneReauthenticationOtp: + if err := models.CreateOneTimeToken(tx, user.ID, user.GetPhone(), user.ReauthenticationToken, models.ReauthenticationToken); err != nil { + return messageID, errors.Wrap(err, "Database error creating reauthentication token for phone") + } + } + + return messageID, nil } func generateSMSFromTemplate(SMSTemplate *template.Template, otp string) (string, error) { diff --git a/internal/api/verify.go b/internal/api/verify.go index ef16bf178e..f48494b0df 100644 --- a/internal/api/verify.go +++ b/internal/api/verify.go @@ -479,11 +479,28 @@ func (a *API) emailChangeVerify(r *http.Request, conn *storage.Connection, param config := a.config if config.Mailer.SecureEmailChangeEnabled && user.EmailChangeConfirmStatus == zeroConfirmation && user.GetEmail() != "" { err := conn.Transaction(func(tx *storage.Connection) error { + currentOTT, terr := models.FindOneTimeToken(tx, params.TokenHash, models.EmailChangeTokenCurrent) + if terr != nil && !models.IsNotFoundError(terr) { + return terr + } + + newOTT, terr := models.FindOneTimeToken(tx, params.TokenHash, models.EmailChangeTokenNew) + if terr != nil && !models.IsNotFoundError(terr) { + return terr + } + user.EmailChangeConfirmStatus = singleConfirmation - if params.Token == user.EmailChangeTokenCurrent || params.TokenHash == user.EmailChangeTokenCurrent { + + if params.Token == user.EmailChangeTokenCurrent || params.TokenHash == user.EmailChangeTokenCurrent || (currentOTT != nil && params.TokenHash == currentOTT.TokenHash) { user.EmailChangeTokenCurrent = "" - } else if params.Token == user.EmailChangeTokenNew || params.TokenHash == user.EmailChangeTokenNew { + if terr := models.ClearOneTimeTokenForUser(tx, user.ID, models.EmailChangeTokenCurrent); terr != nil { + return terr + } + } else if params.Token == user.EmailChangeTokenNew || params.TokenHash == user.EmailChangeTokenNew || (newOTT != nil && params.TokenHash == newOTT.TokenHash) { user.EmailChangeTokenNew = "" + if terr := models.ClearOneTimeTokenForUser(tx, user.ID, models.EmailChangeTokenNew); terr != nil { + return terr + } } if terr := tx.UpdateOnly(user, "email_change_confirm_status", "email_change_token_current", "email_change_token_new"); terr != nil { return terr diff --git a/internal/api/verify_test.go b/internal/api/verify_test.go index a2be255dea..1b4c64efd5 100644 --- a/internal/api/verify_test.go +++ b/internal/api/verify_test.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/json" "fmt" - mail "github.com/supabase/auth/internal/mailer" "io" "net/http" "net/http/httptest" @@ -12,6 +11,8 @@ import ( "testing" "time" + mail "github.com/supabase/auth/internal/mailer" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" @@ -20,6 +21,26 @@ import ( "github.com/supabase/auth/internal/models" ) +type VerifyVariant int + +const ( + VerifyWithoutOTT VerifyVariant = iota + VerifyWithOTT +) + +func (v VerifyVariant) String() string { + switch v { + case VerifyWithoutOTT: + return "WithoutOTT" + + case VerifyWithOTT: + return "WithOTT" + + default: + panic("VerifyVariant: unreachable code") + } +} + type VerifyTestSuite struct { suite.Suite API *API @@ -48,6 +69,21 @@ func (ts *VerifyTestSuite) SetupTest() { require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") } +func (ts *VerifyTestSuite) VerifyWithVariants(fn func(variant VerifyVariant)) { + variants := []VerifyVariant{ + VerifyWithoutOTT, + VerifyWithOTT, + } + + for _, v := range variants { + variant := v + + ts.Run(variant.String(), func() { + fn(variant) + }) + } +} + func (ts *VerifyTestSuite) TestVerifyPasswordRecovery() { // modify config so we don't hit rate limit from requesting recovery twice in 60s ts.Config.SMTP.MaxFrequency = 60 @@ -81,50 +117,60 @@ func (ts *VerifyTestSuite) TestVerifyPasswordRecovery() { }, } - for _, c := range cases { - ts.Run(c.desc, func() { - // Reset user - u.EmailConfirmedAt = nil - require.NoError(ts.T(), ts.API.db.Update(u)) - // Request body - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) + ts.VerifyWithVariants(func(variant VerifyVariant) { + for _, c := range cases { + ts.Run(c.desc, func() { + // Reset user + u.EmailConfirmedAt = nil + require.NoError(ts.T(), ts.API.db.Update(u)) + require.NoError(ts.T(), models.ClearAllOneTimeTokensForUser(ts.API.db, u.ID)) - // Setup request - req := httptest.NewRequest(http.MethodPost, "http://localhost/recover", &buffer) - req.Header.Set("Content-Type", "application/json") + // Request body + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) - // Setup response recorder - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), http.StatusOK, w.Code) + // Setup request + req := httptest.NewRequest(http.MethodPost, "http://localhost/recover", &buffer) + req.Header.Set("Content-Type", "application/json") - u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) - require.NoError(ts.T(), err) + // Setup response recorder + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusOK, w.Code) - assert.WithinDuration(ts.T(), time.Now(), *u.RecoverySentAt, 1*time.Second) - assert.False(ts.T(), u.IsConfirmed()) + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) - reqURL := fmt.Sprintf("http://localhost/verify?type=%s&token=%s", mail.RecoveryVerification, u.RecoveryToken) - req = httptest.NewRequest(http.MethodGet, reqURL, nil) + assert.WithinDuration(ts.T(), time.Now(), *u.RecoverySentAt, 1*time.Second) + assert.False(ts.T(), u.IsConfirmed()) - w = httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), http.StatusSeeOther, w.Code) + recoveryToken := u.RecoveryToken - u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - assert.True(ts.T(), u.IsConfirmed()) + if variant == VerifyWithoutOTT { + require.NoError(ts.T(), models.ClearAllOneTimeTokensForUser(ts.API.db, u.ID)) + } + + reqURL := fmt.Sprintf("http://localhost/verify?type=%s&token=%s", mail.RecoveryVerification, recoveryToken) + req = httptest.NewRequest(http.MethodGet, reqURL, nil) - if c.isPKCE { - rURL, _ := w.Result().Location() + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusSeeOther, w.Code) - f, err := url.ParseQuery(rURL.RawQuery) + u, err = models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) require.NoError(ts.T(), err) - assert.NotEmpty(ts.T(), f.Get("code")) - } - }) - } + assert.True(ts.T(), u.IsConfirmed()) + + if c.isPKCE { + rURL, _ := w.Result().Location() + + f, err := url.ParseQuery(rURL.RawQuery) + require.NoError(ts.T(), err) + assert.NotEmpty(ts.T(), f.Get("code")) + } + }) + } + }) } func (ts *VerifyTestSuite) TestVerifySecureEmailChange() { @@ -162,106 +208,124 @@ func (ts *VerifyTestSuite) TestVerifySecureEmailChange() { }, } - for _, c := range cases { - u, err := models.FindUserByEmailAndAudience(ts.API.db, c.currentEmail, ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - - u.EmailChangeSentAt = &time.Time{} - require.NoError(ts.T(), ts.API.db.Update(u)) - - // Request body - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) - - // Setup request - req := httptest.NewRequest(http.MethodPut, "http://localhost/user", &buffer) - req.Header.Set("Content-Type", "application/json") - - // Generate access token for request and a mock session - var token string - session, err := models.NewSession(u.ID, nil) - require.NoError(ts.T(), err) - require.NoError(ts.T(), ts.API.db.Create(session)) - - token, _, err = ts.API.generateAccessToken(req, ts.API.db, u, &session.ID, models.MagicLink) - require.NoError(ts.T(), err) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) - - // Setup response recorder - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), http.StatusOK, w.Code) - - u, err = models.FindUserByEmailAndAudience(ts.API.db, c.currentEmail, ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - - assert.WithinDuration(ts.T(), time.Now(), *u.EmailChangeSentAt, 1*time.Second) - assert.False(ts.T(), u.IsConfirmed()) - - // Verify new email - reqURL := fmt.Sprintf("http://localhost/verify?type=%s&token=%s", mail.EmailChangeVerification, u.EmailChangeTokenNew) - req = httptest.NewRequest(http.MethodGet, reqURL, nil) - - w = httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - - require.Equal(ts.T(), http.StatusSeeOther, w.Code) - urlVal, err := url.Parse(w.Result().Header.Get("Location")) - ts.Require().NoError(err, "redirect url parse failed") - var v url.Values - if !c.isPKCE { - v, err = url.ParseQuery(urlVal.Fragment) - ts.Require().NoError(err) - ts.Require().NotEmpty(v.Get("message")) - } else if c.isPKCE { - v, err = url.ParseQuery(urlVal.RawQuery) - ts.Require().NoError(err) - ts.Require().NotEmpty(v.Get("message")) - - v, err = url.ParseQuery(urlVal.Fragment) - ts.Require().NoError(err) - ts.Require().NotEmpty(v.Get("message")) - } + ts.VerifyWithVariants(func(variant VerifyVariant) { + for _, c := range cases { + u, err := models.FindUserByEmailAndAudience(ts.API.db, c.currentEmail, ts.Config.JWT.Aud) + require.NoError(ts.T(), err) - u, err = models.FindUserByEmailAndAudience(ts.API.db, c.currentEmail, ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - assert.Equal(ts.T(), singleConfirmation, u.EmailChangeConfirmStatus) - - // Verify old email - reqURL = fmt.Sprintf("http://localhost/verify?type=%s&token=%s", mail.EmailChangeVerification, u.EmailChangeTokenCurrent) - req = httptest.NewRequest(http.MethodGet, reqURL, nil) - - w = httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - require.Equal(ts.T(), http.StatusSeeOther, w.Code) - - urlVal, err = url.Parse(w.Header().Get("Location")) - ts.Require().NoError(err, "redirect url parse failed") - if !c.isPKCE { - v, err = url.ParseQuery(urlVal.Fragment) - ts.Require().NoError(err) - ts.Require().NotEmpty(v.Get("access_token")) - ts.Require().NotEmpty(v.Get("expires_in")) - ts.Require().NotEmpty(v.Get("refresh_token")) - } else if c.isPKCE { - v, err = url.ParseQuery(urlVal.RawQuery) - ts.Require().NoError(err) - ts.Require().NotEmpty(v.Get("code")) - } + // reset user + u.EmailChangeSentAt = nil + u.EmailChangeTokenCurrent = "" + u.EmailChangeTokenNew = "" + require.NoError(ts.T(), ts.API.db.Update(u)) + require.NoError(ts.T(), models.ClearAllOneTimeTokensForUser(ts.API.db, u.ID)) + + // Request body + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) - // user's email should've been updated to newEmail - u, err = models.FindUserByEmailAndAudience(ts.API.db, c.newEmail, ts.Config.JWT.Aud) - require.NoError(ts.T(), err) - require.Equal(ts.T(), zeroConfirmation, u.EmailChangeConfirmStatus) + // Setup request + req := httptest.NewRequest(http.MethodPut, "http://localhost/user", &buffer) + req.Header.Set("Content-Type", "application/json") - // Reset confirmation status after each test - u.EmailConfirmedAt = nil - require.NoError(ts.T(), ts.API.db.Update(u)) + // Generate access token for request and a mock session + var token string + session, err := models.NewSession(u.ID, nil) + require.NoError(ts.T(), err) + require.NoError(ts.T(), ts.API.db.Create(session)) - } + token, _, err = ts.API.generateAccessToken(req, ts.API.db, u, &session.ID, models.MagicLink) + require.NoError(ts.T(), err) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + // Setup response recorder + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), http.StatusOK, w.Code) + + u, err = models.FindUserByEmailAndAudience(ts.API.db, c.currentEmail, ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + currentTokenHash := u.EmailChangeTokenCurrent + newTokenHash := u.EmailChangeTokenNew + + if variant == VerifyWithoutOTT { + require.NoError(ts.T(), models.ClearAllOneTimeTokensForUser(ts.API.db, u.ID)) + } + + u, err = models.FindUserByEmailAndAudience(ts.API.db, c.currentEmail, ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + assert.WithinDuration(ts.T(), time.Now(), *u.EmailChangeSentAt, 1*time.Second) + assert.False(ts.T(), u.IsConfirmed()) + + // Verify new email + reqURL := fmt.Sprintf("http://localhost/verify?type=%s&token=%s", mail.EmailChangeVerification, newTokenHash) + req = httptest.NewRequest(http.MethodGet, reqURL, nil) + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + + require.Equal(ts.T(), http.StatusSeeOther, w.Code) + urlVal, err := url.Parse(w.Result().Header.Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + var v url.Values + if !c.isPKCE { + v, err = url.ParseQuery(urlVal.Fragment) + ts.Require().NoError(err) + ts.Require().NotEmpty(v.Get("message")) + } else if c.isPKCE { + v, err = url.ParseQuery(urlVal.RawQuery) + ts.Require().NoError(err) + ts.Require().NotEmpty(v.Get("message")) + + v, err = url.ParseQuery(urlVal.Fragment) + ts.Require().NoError(err) + ts.Require().NotEmpty(v.Get("message")) + } + + u, err = models.FindUserByEmailAndAudience(ts.API.db, c.currentEmail, ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + assert.Equal(ts.T(), singleConfirmation, u.EmailChangeConfirmStatus) + + // Verify old email + reqURL = fmt.Sprintf("http://localhost/verify?type=%s&token=%s", mail.EmailChangeVerification, currentTokenHash) + req = httptest.NewRequest(http.MethodGet, reqURL, nil) + + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), http.StatusSeeOther, w.Code) + + urlVal, err = url.Parse(w.Header().Get("Location")) + ts.Require().NoError(err, "redirect url parse failed") + if !c.isPKCE { + v, err = url.ParseQuery(urlVal.Fragment) + ts.Require().NoError(err) + ts.Require().NotEmpty(v.Get("access_token")) + ts.Require().NotEmpty(v.Get("expires_in")) + ts.Require().NotEmpty(v.Get("refresh_token")) + } else if c.isPKCE { + v, err = url.ParseQuery(urlVal.RawQuery) + ts.Require().NoError(err) + ts.Require().NotEmpty(v.Get("code")) + } + + // user's email should've been updated to newEmail + u, err = models.FindUserByEmailAndAudience(ts.API.db, c.newEmail, ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + require.Equal(ts.T(), zeroConfirmation, u.EmailChangeConfirmStatus) + + // Reset confirmation status after each test + u.EmailConfirmedAt = nil + require.NoError(ts.T(), ts.API.db.Update(u)) + } + }) } func (ts *VerifyTestSuite) TestExpiredConfirmationToken() { + // verify variant testing not necessary in this test as it's testing + // the ConfirmationSentAt behavior, not the ConfirmationToken behavior + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) require.NoError(ts.T(), err) u.ConfirmationToken = "asdf3" @@ -391,6 +455,9 @@ func (ts *VerifyTestSuite) TestInvalidOtp() { } func (ts *VerifyTestSuite) TestExpiredRecoveryToken() { + // verify variant testing not necessary in this test as it's testing + // the RecoverySentAt behavior, not the RecoveryToken behavior + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) require.NoError(ts.T(), err) u.RecoveryToken = "asdf3" @@ -411,6 +478,9 @@ func (ts *VerifyTestSuite) TestExpiredRecoveryToken() { } func (ts *VerifyTestSuite) TestVerifyPermitedCustomUri() { + // verify variant testing not necessary in this test as it's testing + // the redirect URL behavior, not the RecoveryToken behavior + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) require.NoError(ts.T(), err) u.RecoverySentAt = &time.Time{} @@ -454,6 +524,9 @@ func (ts *VerifyTestSuite) TestVerifyPermitedCustomUri() { } func (ts *VerifyTestSuite) TestVerifyNotPermitedCustomUri() { + // verify variant testing not necessary in this test as it's testing + // the redirect URL behavior, not the RecoveryToken behavior + u, err := models.FindUserByEmailAndAudience(ts.API.db, "test@example.com", ts.Config.JWT.Aud) require.NoError(ts.T(), err) u.RecoverySentAt = &time.Time{} @@ -497,7 +570,10 @@ func (ts *VerifyTestSuite) TestVerifyNotPermitedCustomUri() { assert.True(ts.T(), u.IsConfirmed()) } -func (ts *VerifyTestSuite) TestVerifySignupWithredirectURLContainedPath() { +func (ts *VerifyTestSuite) TestVerifySignupWithRedirectURLContainedPath() { + // verify variant testing not necessary in this test as it's testing + // the redirect URL behavior, not the RecoveryToken behavior + testCases := []struct { desc string siteURL string @@ -792,10 +868,9 @@ func (ts *VerifyTestSuite) TestVerifyValidOtp() { desc: "Valid SMS OTP", sentTime: time.Now(), body: map[string]interface{}{ - "type": smsVerification, - "tokenHash": crypto.GenerateTokenHash(u.GetPhone(), "123456"), - "token": "123456", - "phone": u.GetPhone(), + "type": smsVerification, + "token": "123456", + "phone": u.GetPhone(), }, expected: expected{ code: http.StatusOK, @@ -806,10 +881,9 @@ func (ts *VerifyTestSuite) TestVerifyValidOtp() { desc: "Valid Confirmation OTP", sentTime: time.Now(), body: map[string]interface{}{ - "type": mail.SignupVerification, - "tokenHash": crypto.GenerateTokenHash(u.GetEmail(), "123456"), - "token": "123456", - "email": u.GetEmail(), + "type": mail.SignupVerification, + "token": "123456", + "email": u.GetEmail(), }, expected: expected{ code: http.StatusOK, @@ -820,10 +894,9 @@ func (ts *VerifyTestSuite) TestVerifyValidOtp() { desc: "Valid Recovery OTP", sentTime: time.Now(), body: map[string]interface{}{ - "type": mail.RecoveryVerification, - "tokenHash": crypto.GenerateTokenHash(u.GetEmail(), "123456"), - "token": "123456", - "email": u.GetEmail(), + "type": mail.RecoveryVerification, + "token": "123456", + "email": u.GetEmail(), }, expected: expected{ code: http.StatusOK, @@ -834,10 +907,9 @@ func (ts *VerifyTestSuite) TestVerifyValidOtp() { desc: "Valid Email OTP", sentTime: time.Now(), body: map[string]interface{}{ - "type": mail.EmailOTPVerification, - "tokenHash": crypto.GenerateTokenHash(u.GetEmail(), "123456"), - "token": "123456", - "email": u.GetEmail(), + "type": mail.EmailOTPVerification, + "token": "123456", + "email": u.GetEmail(), }, expected: expected{ code: http.StatusOK, @@ -848,10 +920,9 @@ func (ts *VerifyTestSuite) TestVerifyValidOtp() { desc: "Valid Email Change OTP", sentTime: time.Now(), body: map[string]interface{}{ - "type": mail.EmailChangeVerification, - "tokenHash": crypto.GenerateTokenHash(u.EmailChange, "123456"), - "token": "123456", - "email": u.EmailChange, + "type": mail.EmailChangeVerification, + "token": "123456", + "email": u.EmailChange, }, expected: expected{ code: http.StatusOK, @@ -862,10 +933,9 @@ func (ts *VerifyTestSuite) TestVerifyValidOtp() { desc: "Valid Phone Change OTP", sentTime: time.Now(), body: map[string]interface{}{ - "type": phoneChangeVerification, - "tokenHash": crypto.GenerateTokenHash(u.PhoneChange, "123456"), - "token": "123456", - "phone": u.PhoneChange, + "type": phoneChangeVerification, + "token": "123456", + "phone": u.PhoneChange, }, expected: expected{ code: http.StatusOK, @@ -910,33 +980,48 @@ func (ts *VerifyTestSuite) TestVerifyValidOtp() { }, } - for _, caseItem := range cases { - c := caseItem - ts.Run(c.desc, func() { - // create user - u.ConfirmationSentAt = &c.sentTime - u.RecoverySentAt = &c.sentTime - u.EmailChangeSentAt = &c.sentTime - u.PhoneChangeSentAt = &c.sentTime - u.ConfirmationToken = c.expected.tokenHash - u.RecoveryToken = c.expected.tokenHash - u.EmailChangeTokenNew = c.expected.tokenHash - u.PhoneChangeToken = c.expected.tokenHash - require.NoError(ts.T(), ts.API.db.Update(u)) - - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) - - // Setup request - req := httptest.NewRequest(http.MethodPost, "http://localhost/verify", &buffer) - req.Header.Set("Content-Type", "application/json") - - // Setup response recorder - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), c.expected.code, w.Code) - }) - } + ts.VerifyWithVariants(func(variant VerifyVariant) { + for _, caseItem := range cases { + c := caseItem + ts.Run(c.desc, func() { + // create user + require.NoError(ts.T(), models.ClearAllOneTimeTokensForUser(ts.API.db, u.ID)) + + u.ConfirmationSentAt = &c.sentTime + u.RecoverySentAt = &c.sentTime + u.EmailChangeSentAt = &c.sentTime + u.PhoneChangeSentAt = &c.sentTime + + u.ConfirmationToken = c.expected.tokenHash + u.RecoveryToken = c.expected.tokenHash + u.EmailChangeTokenNew = c.expected.tokenHash + u.PhoneChangeToken = c.expected.tokenHash + + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, "relates_to not used", u.ConfirmationToken, models.ConfirmationToken)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, "relates_to not used", u.RecoveryToken, models.RecoveryToken)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, "relates_to not used", u.EmailChangeTokenNew, models.EmailChangeTokenNew)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, "relates_to not used", u.PhoneChangeToken, models.PhoneChangeToken)) + + if variant == VerifyWithoutOTT { + require.NoError(ts.T(), models.ClearAllOneTimeTokensForUser(ts.API.db, u.ID)) + } + + require.NoError(ts.T(), ts.API.db.Update(u)) + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.body)) + + // Setup request + req := httptest.NewRequest(http.MethodPost, "http://localhost/verify", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Setup response recorder + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), c.expected.code, w.Code) + }) + } + }) } func (ts *VerifyTestSuite) TestSecureEmailChangeWithTokenHash() { @@ -980,40 +1065,48 @@ func (ts *VerifyTestSuite) TestSecureEmailChangeWithTokenHash() { expectedStatus: http.StatusForbidden, }, } - for _, c := range cases { - ts.Run(c.desc, func() { - // Set the corresponding email change tokens - u.EmailChangeTokenCurrent = currentEmailChangeToken - u.EmailChangeTokenNew = newEmailChangeToken - - currentTime := time.Now() - u.EmailChangeSentAt = ¤tTime - require.NoError(ts.T(), ts.API.db.Update(u)) - - var buffer bytes.Buffer - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.firstVerificationBody)) - - // Setup request - req := httptest.NewRequest(http.MethodPost, "http://localhost/verify", &buffer) - req.Header.Set("Content-Type", "application/json") - // Setup response recorder - w := httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.secondVerificationBody)) - - // Setup second request - req = httptest.NewRequest(http.MethodPost, "http://localhost/verify", &buffer) - req.Header.Set("Content-Type", "application/json") - - // Setup second response recorder - w = httptest.NewRecorder() - ts.API.handler.ServeHTTP(w, req) - assert.Equal(ts.T(), c.expectedStatus, w.Code) - }) - - } + ts.VerifyWithVariants(func(variant VerifyVariant) { + for _, c := range cases { + ts.Run(c.desc, func() { + // Set the corresponding email change tokens + u.EmailChangeTokenCurrent = currentEmailChangeToken + u.EmailChangeTokenNew = newEmailChangeToken + require.NoError(ts.T(), models.ClearAllOneTimeTokensForUser(ts.API.db, u.ID)) + + if variant == VerifyWithOTT { + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, "relates_to not used", currentEmailChangeToken, models.EmailChangeTokenCurrent)) + require.NoError(ts.T(), models.CreateOneTimeToken(ts.API.db, u.ID, "relates_to not used", newEmailChangeToken, models.EmailChangeTokenNew)) + } + + currentTime := time.Now() + u.EmailChangeSentAt = ¤tTime + require.NoError(ts.T(), ts.API.db.Update(u)) + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.firstVerificationBody)) + + // Setup request + req := httptest.NewRequest(http.MethodPost, "http://localhost/verify", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Setup response recorder + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(c.secondVerificationBody)) + + // Setup second request + req = httptest.NewRequest(http.MethodPost, "http://localhost/verify", &buffer) + req.Header.Set("Content-Type", "application/json") + + // Setup second response recorder + w = httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + assert.Equal(ts.T(), c.expectedStatus, w.Code) + }) + } + }) } func (ts *VerifyTestSuite) TestPrepRedirectURL() { diff --git a/internal/models/errors.go b/internal/models/errors.go index 7746badbd4..96f8319695 100644 --- a/internal/models/errors.go +++ b/internal/models/errors.go @@ -25,6 +25,8 @@ func IsNotFoundError(err error) bool { return true case FlowStateNotFoundError, *FlowStateNotFoundError: return true + case OneTimeTokenNotFoundError, *OneTimeTokenNotFoundError: + return true } return false } diff --git a/internal/models/one_time_token.go b/internal/models/one_time_token.go new file mode 100644 index 0000000000..f999439e80 --- /dev/null +++ b/internal/models/one_time_token.go @@ -0,0 +1,330 @@ +package models + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "strings" + "time" + + "github.com/gofrs/uuid" + "github.com/pkg/errors" + "github.com/supabase/auth/internal/storage" +) + +type OneTimeTokenType int + +const ( + ConfirmationToken OneTimeTokenType = iota + ReauthenticationToken + RecoveryToken + EmailChangeTokenNew + EmailChangeTokenCurrent + PhoneChangeToken +) + +func (t OneTimeTokenType) String() string { + switch t { + case ConfirmationToken: + return "confirmation_token" + + case ReauthenticationToken: + return "reauthentication_token" + + case RecoveryToken: + return "recovery_token" + + case EmailChangeTokenNew: + return "email_change_token_new" + + case EmailChangeTokenCurrent: + return "email_change_token_current" + + case PhoneChangeToken: + return "phone_change_token" + + default: + panic("OneTimeToken: unreachable case") + } +} + +func ParseOneTimeTokenType(s string) (OneTimeTokenType, error) { + switch s { + case "confirmation_token": + return ConfirmationToken, nil + + case "reauthentication_token": + return ReauthenticationToken, nil + + case "recovery_token": + return RecoveryToken, nil + + case "email_change_token_new": + return EmailChangeTokenNew, nil + + case "email_change_token_current": + return EmailChangeTokenCurrent, nil + + case "phone_change_token": + return PhoneChangeToken, nil + + default: + return 0, fmt.Errorf("OneTimeTokenType: unrecognized string %q", s) + } +} + +func (t OneTimeTokenType) Value() (driver.Value, error) { + return t.String(), nil +} + +func (t *OneTimeTokenType) Scan(src interface{}) error { + s, ok := src.(string) + if !ok { + return fmt.Errorf("OneTimeTokenType: scan type is not string but is %T", src) + } + + parsed, err := ParseOneTimeTokenType(s) + if err != nil { + return err + } + + *t = parsed + return nil +} + +type OneTimeTokenNotFoundError struct { +} + +func (e OneTimeTokenNotFoundError) Error() string { + return "One-time token not found" +} + +type OneTimeToken struct { + ID uuid.UUID `json:"id" db:"id"` + + UserID uuid.UUID `json:"user_id" db:"user_id"` + TokenType OneTimeTokenType `json:"token_type" db:"token_type"` + + TokenHash string `json:"token_hash" db:"token_hash"` + RelatesTo string `json:"relates_to" db:"relates_to"` + + CreatedAt time.Time `json:"created_at" db:"created_at"` + UpdatedAt time.Time `json:"updated_at" db:"updated_at"` +} + +func (OneTimeToken) TableName() string { + return "one_time_tokens" +} + +func ClearAllOneTimeTokensForUser(tx *storage.Connection, userID uuid.UUID) error { + return tx.Q().Where("user_id = ?", userID).Delete(OneTimeToken{}) +} + +func ClearOneTimeTokenForUser(tx *storage.Connection, userID uuid.UUID, tokenType OneTimeTokenType) error { + if err := tx.Q().Where("token_type = ? and user_id = ?", tokenType, userID).Delete(OneTimeToken{}); err != nil { + return err + } + + return nil +} + +func CreateOneTimeToken(tx *storage.Connection, userID uuid.UUID, relatesTo, tokenHash string, tokenType OneTimeTokenType) error { + if err := ClearOneTimeTokenForUser(tx, userID, tokenType); err != nil { + return err + } + + oneTimeToken := &OneTimeToken{ + ID: uuid.Must(uuid.NewV4()), + UserID: userID, + TokenType: tokenType, + TokenHash: tokenHash, + RelatesTo: strings.ToLower(relatesTo), + } + + if err := tx.Eager().Create(oneTimeToken); err != nil { + return err + } + + return nil +} + +func FindOneTimeToken(tx *storage.Connection, tokenHash string, tokenTypes ...OneTimeTokenType) (*OneTimeToken, error) { + oneTimeToken := &OneTimeToken{} + + query := tx.Eager().Q() + + switch len(tokenTypes) { + case 2: + query = query.Where("(token_type = ? or token_type = ?) and token_hash = ?", tokenTypes[0], tokenTypes[1], tokenHash) + + case 1: + query = query.Where("token_type = ? and token_hash = ?", tokenTypes[0], tokenHash) + + default: + panic("FindOneTimeToken accepts only 3 or 4 arguments") + } + + if err := query.First(oneTimeToken); err != nil { + if errors.Cause(err) == sql.ErrNoRows { + return nil, OneTimeTokenNotFoundError{} + } + + return nil, errors.Wrap(err, "error finding one time token") + } + + return oneTimeToken, nil +} + +// FindUserByConfirmationToken finds users with the matching confirmation token. +func FindUserByConfirmationOrRecoveryToken(tx *storage.Connection, token string) (*User, error) { + ott, err := FindOneTimeToken(tx, token, ConfirmationToken, RecoveryToken) + if err != nil && !IsNotFoundError(err) { + return nil, err + } + + if ott == nil { + user, err := findUser(tx, "(confirmation_token = ? or recovery_token = ?) and is_sso_user = false", token, token) + if err != nil { + if IsNotFoundError(err) { + return nil, ConfirmationOrRecoveryTokenNotFoundError{} + } else { + return nil, err + } + } + + return user, nil + } + + return FindUserByID(tx, ott.UserID) +} + +// FindUserByConfirmationToken finds users with the matching confirmation token. +func FindUserByConfirmationToken(tx *storage.Connection, token string) (*User, error) { + ott, err := FindOneTimeToken(tx, token, ConfirmationToken) + if err != nil && !IsNotFoundError(err) { + return nil, err + } + + if ott == nil { + user, err := findUser(tx, "confirmation_token = ? and is_sso_user = false", token) + if err != nil { + if IsNotFoundError(err) { + return nil, ConfirmationTokenNotFoundError{} + } else { + return nil, err + } + } + + return user, nil + } + + return FindUserByID(tx, ott.UserID) +} + +// FindUserByRecoveryToken finds a user with the matching recovery token. +func FindUserByRecoveryToken(tx *storage.Connection, token string) (*User, error) { + ott, err := FindOneTimeToken(tx, token, RecoveryToken) + if err != nil && !IsNotFoundError(err) { + return nil, err + } + + if ott == nil { + return findUser(tx, "recovery_token = ? and is_sso_user = false", token) + } + + return FindUserByID(tx, ott.UserID) +} + +// FindUserByEmailChangeToken finds a user with the matching email change token. +func FindUserByEmailChangeToken(tx *storage.Connection, token string) (*User, error) { + ott, err := FindOneTimeToken(tx, token, EmailChangeTokenCurrent, EmailChangeTokenNew) + if err != nil && !IsNotFoundError(err) { + return nil, err + } + + if ott == nil { + return findUser(tx, "is_sso_user = false and (email_change_token_current = ? or email_change_token_new = ?)", token, token) + } + + return FindUserByID(tx, ott.UserID) +} + +// FindUserByEmailChangeCurrentAndAudience finds a user with the matching email change and audience. +func FindUserByEmailChangeCurrentAndAudience(tx *storage.Connection, email, token, aud string) (*User, error) { + ott, err := FindOneTimeToken(tx, token, EmailChangeTokenCurrent) + if err != nil && !IsNotFoundError(err) { + return nil, err + } + + if ott == nil { + ott, err = FindOneTimeToken(tx, "pkce_"+token, EmailChangeTokenCurrent) + if err != nil && !IsNotFoundError(err) { + return nil, err + } + } + + if ott == nil { + return findUser( + tx, + "instance_id = ? and LOWER(email) = ? and aud = ? and is_sso_user = false and (email_change_token_current = 'pkce_' || ? or email_change_token_current = ?)", + uuid.Nil, strings.ToLower(email), aud, token, token, + ) + } + + user, err := FindUserByID(tx, ott.UserID) + if err != nil { + return nil, err + } + + if user.Aud != aud && strings.EqualFold(user.GetEmail(), email) { + return nil, UserNotFoundError{} + } + + return user, nil +} + +// FindUserByEmailChangeNewAndAudience finds a user with the matching email change and audience. +func FindUserByEmailChangeNewAndAudience(tx *storage.Connection, email, token, aud string) (*User, error) { + ott, err := FindOneTimeToken(tx, token, EmailChangeTokenNew) + if err != nil && !IsNotFoundError(err) { + return nil, err + } + + if ott == nil { + ott, err = FindOneTimeToken(tx, "pkce_"+token, EmailChangeTokenNew) + if err != nil && !IsNotFoundError(err) { + return nil, err + } + } + + if ott == nil { + return findUser( + tx, + "instance_id = ? and LOWER(email_change) = ? and aud = ? and is_sso_user = false and (email_change_token_new = 'pkce_' || ? or email_change_token_new = ?)", + uuid.Nil, strings.ToLower(email), aud, token, token, + ) + } + + user, err := FindUserByID(tx, ott.UserID) + if err != nil { + return nil, err + } + + if user.Aud != aud && strings.EqualFold(user.EmailChange, email) { + return nil, UserNotFoundError{} + } + + return user, nil +} + +// FindUserForEmailChange finds a user requesting for an email change +func FindUserForEmailChange(tx *storage.Connection, email, token, aud string, secureEmailChangeEnabled bool) (*User, error) { + if secureEmailChangeEnabled { + if user, err := FindUserByEmailChangeCurrentAndAudience(tx, email, token, aud); err == nil { + return user, err + } else if !IsNotFoundError(err) { + return nil, err + } + } + return FindUserByEmailChangeNewAndAudience(tx, email, token, aud) +} diff --git a/internal/models/user.go b/internal/models/user.go index be70b13baa..270484e080 100644 --- a/internal/models/user.go +++ b/internal/models/user.go @@ -318,6 +318,10 @@ func (u *User) UpdatePassword(tx *storage.Connection, sessionID *uuid.UUID) erro return err } + if err := ClearAllOneTimeTokensForUser(tx, u.ID); err != nil { + return err + } + if sessionID == nil { // log out user from all sessions to ensure reauthentication after password change return Logout(tx, u.ID) @@ -336,7 +340,15 @@ func (u *User) Authenticate(ctx context.Context, password string) bool { // ConfirmReauthentication resets the reauthentication token func (u *User) ConfirmReauthentication(tx *storage.Connection) error { u.ReauthenticationToken = "" - return tx.UpdateOnly(u, "reauthentication_token") + if err := tx.UpdateOnly(u, "reauthentication_token"); err != nil { + return err + } + + if err := ClearAllOneTimeTokensForUser(tx, u.ID); err != nil { + return err + } + + return nil } // Confirm resets the confimation token and sets the confirm timestamp @@ -344,7 +356,16 @@ func (u *User) Confirm(tx *storage.Connection) error { u.ConfirmationToken = "" now := time.Now() u.EmailConfirmedAt = &now - return tx.UpdateOnly(u, "confirmation_token", "email_confirmed_at") + + if err := tx.UpdateOnly(u, "confirmation_token", "email_confirmed_at"); err != nil { + return err + } + + if err := ClearAllOneTimeTokensForUser(tx, u.ID); err != nil { + return err + } + + return nil } // ConfirmPhone resets the confimation token and sets the confirm timestamp @@ -352,7 +373,11 @@ func (u *User) ConfirmPhone(tx *storage.Connection) error { u.ConfirmationToken = "" now := time.Now() u.PhoneConfirmedAt = &now - return tx.UpdateOnly(u, "confirmation_token", "phone_confirmed_at") + if err := tx.UpdateOnly(u, "confirmation_token", "phone_confirmed_at"); err != nil { + return nil + } + + return ClearAllOneTimeTokensForUser(tx, u.ID) } // UpdateLastSignInAt update field last_sign_in_at for user according to specified field @@ -381,6 +406,10 @@ func (u *User) ConfirmEmailChange(tx *storage.Connection, status int) error { return err } + if err := ClearAllOneTimeTokensForUser(tx, u.ID); err != nil { + return err + } + if !u.IsConfirmed() { if err := u.Confirm(tx); err != nil { return err @@ -426,6 +455,10 @@ func (u *User) ConfirmPhoneChange(tx *storage.Connection) error { return err } + if err := ClearAllOneTimeTokensForUser(tx, u.ID); err != nil { + return err + } + identity, err := FindIdentityByIdAndProvider(tx, u.ID.String(), "phone") if err != nil { if IsNotFoundError(err) { @@ -450,7 +483,11 @@ func (u *User) ConfirmPhoneChange(tx *storage.Connection) error { // Recover resets the recovery token func (u *User) Recover(tx *storage.Connection) error { u.RecoveryToken = "" - return tx.UpdateOnly(u, "recovery_token") + if err := tx.UpdateOnly(u, "recovery_token"); err != nil { + return err + } + + return ClearAllOneTimeTokensForUser(tx, u.ID) } // CountOtherUsers counts how many other users exist besides the one provided @@ -471,24 +508,6 @@ func findUser(tx *storage.Connection, query string, args ...interface{}) (*User, return obj, nil } -// FindUserByConfirmationToken finds users with the matching confirmation token. -func FindUserByConfirmationOrRecoveryToken(tx *storage.Connection, token string) (*User, error) { - user, err := findUser(tx, "(confirmation_token = ? or recovery_token = ?) and is_sso_user = false", token, token) - if err != nil { - return nil, ConfirmationOrRecoveryTokenNotFoundError{} - } - return user, nil -} - -// FindUserByConfirmationToken finds users with the matching confirmation token. -func FindUserByConfirmationToken(tx *storage.Connection, token string) (*User, error) { - user, err := findUser(tx, "confirmation_token = ? and is_sso_user = false", token) - if err != nil { - return nil, ConfirmationTokenNotFoundError{} - } - return user, nil -} - // FindUserByEmailAndAudience finds a user with the matching email and audience. func FindUserByEmailAndAudience(tx *storage.Connection, email, aud string) (*User, error) { return findUser(tx, "instance_id = ? and LOWER(email) = ? and aud = ? and is_sso_user = false", uuid.Nil, strings.ToLower(email), aud) @@ -504,16 +523,6 @@ func FindUserByID(tx *storage.Connection, id uuid.UUID) (*User, error) { return findUser(tx, "instance_id = ? and id = ?", uuid.Nil, id) } -// FindUserByRecoveryToken finds a user with the matching recovery token. -func FindUserByRecoveryToken(tx *storage.Connection, token string) (*User, error) { - return findUser(tx, "recovery_token = ? and is_sso_user = false", token) -} - -// FindUserByEmailChangeToken finds a user with the matching email change token. -func FindUserByEmailChangeToken(tx *storage.Connection, token string) (*User, error) { - return findUser(tx, "is_sso_user = false and (email_change_token_current = ? or email_change_token_new = ?)", token, token) -} - // FindUserWithRefreshToken finds a user from the provided refresh token. If // forUpdate is set to true, then the SELECT statement used by the query has // the form SELECT ... FOR UPDATE SKIP LOCKED. This means that a FOR UPDATE @@ -601,41 +610,6 @@ func FindUsersInAudience(tx *storage.Connection, aud string, pageParams *Paginat return users, err } -// FindUserByEmailChangeCurrentAndAudience finds a user with the matching email change and audience. -func FindUserByEmailChangeCurrentAndAudience(tx *storage.Connection, email, token, aud string) (*User, error) { - return findUser( - tx, - "instance_id = ? and LOWER(email) = ? and aud = ? and is_sso_user = false and (email_change_token_current = 'pkce_' || ? or email_change_token_current = ?)", - uuid.Nil, strings.ToLower(email), aud, token, token, - ) -} - -// FindUserByEmailChangeNewAndAudience finds a user with the matching email change and audience. -func FindUserByEmailChangeNewAndAudience(tx *storage.Connection, email, token, aud string) (*User, error) { - return findUser( - tx, - "instance_id = ? and LOWER(email_change) = ? and aud = ? and is_sso_user = false and (email_change_token_new = 'pkce_' || ? or email_change_token_new = ?)", - uuid.Nil, strings.ToLower(email), aud, token, token, - ) -} - -// FindUserForEmailChange finds a user requesting for an email change -func FindUserForEmailChange(tx *storage.Connection, email, token, aud string, secureEmailChangeEnabled bool) (*User, error) { - if secureEmailChangeEnabled { - if user, err := FindUserByEmailChangeCurrentAndAudience(tx, email, token, aud); err == nil { - return user, err - } else if !IsNotFoundError(err) { - return nil, err - } - } - return FindUserByEmailChangeNewAndAudience(tx, email, token, aud) -} - -// FindUserByPhoneChangeAndAudience finds a user with the matching phone change and audience. -func FindUserByPhoneChangeAndAudience(tx *storage.Connection, phone, aud string) (*User, error) { - return findUser(tx, "instance_id = ? and phone_change = ? and aud = ? and is_sso_user = false", uuid.Nil, phone, aud) -} - // IsDuplicatedEmail returns whether a user exists with a matching email and audience. // If a currentUser is provided, we will need to filter out any identities that belong to the current user. func IsDuplicatedEmail(tx *storage.Connection, email, aud string, currentUser *User) (*User, error) { @@ -788,6 +762,10 @@ func (u *User) SoftDeleteUser(tx *storage.Connection) error { return err } + if err := ClearAllOneTimeTokensForUser(tx, u.ID); err != nil { + return err + } + // set raw_user_meta_data to {} userMetaDataUpdates := map[string]interface{}{} for k := range u.UserMetaData { @@ -808,6 +786,10 @@ func (u *User) SoftDeleteUser(tx *storage.Connection) error { return err } + if err := Logout(tx, u.ID); err != nil { + return err + } + return nil } @@ -859,3 +841,8 @@ func obfuscatePhone(u *User, phone string) string { func obfuscateIdentityProviderId(identity *Identity) string { return obfuscateValue(identity.UserID, identity.Provider+":"+identity.ProviderID) } + +// FindUserByPhoneChangeAndAudience finds a user with the matching phone change and audience. +func FindUserByPhoneChangeAndAudience(tx *storage.Connection, phone, aud string) (*User, error) { + return findUser(tx, "instance_id = ? and phone_change = ? and aud = ? and is_sso_user = false", uuid.Nil, phone, aud) +} diff --git a/migrations/20240427152123_add_one_time_tokens_table.up.sql b/migrations/20240427152123_add_one_time_tokens_table.up.sql new file mode 100644 index 0000000000..439b6ecd7e --- /dev/null +++ b/migrations/20240427152123_add_one_time_tokens_table.up.sql @@ -0,0 +1,29 @@ +do $$ begin + create type one_time_token_type as enum ( + 'confirmation_token', + 'reauthentication_token', + 'recovery_token', + 'email_change_token_new', + 'email_change_token_current', + 'phone_change_token' + ); +exception + when duplicate_object then null; +end $$; + + +do $$ begin + create table if not exists {{ index .Options "Namespace" }}.one_time_tokens ( + id uuid primary key, + user_id uuid not null references {{ index .Options "Namespace" }}.users on delete cascade, + token_type one_time_token_type not null, + token_hash text not null, + relates_to text not null, + created_at timestamp without time zone not null default now(), + updated_at timestamp without time zone not null default now() + ); + + create index if not exists one_time_tokens_token_hash_hash_idx on {{ index .Options "Namespace" }}.one_time_tokens using hash (token_hash); + create index if not exists one_time_tokens_relates_to_hash_idx on {{ index .Options "Namespace" }}.one_time_tokens using hash (relates_to); + create unique index if not exists one_time_tokens_user_id_token_type_key on {{ index .Options "Namespace" }}.one_time_tokens (user_id, token_type); +end $$;