diff --git a/api/errors.go b/api/errors.go index f839357def..be77a7eea5 100644 --- a/api/errors.go +++ b/api/errors.go @@ -14,6 +14,7 @@ import ( // Common error messages during signup flow var ( DuplicateEmailMsg = "A user with this email address has already been registered" + DuplicatePhoneMsg = "A user with this phone number has already been registered" UserExistsError error = errors.New("User already exists") ) diff --git a/api/otp.go b/api/otp.go index 6b8a86b812..ad2955b415 100644 --- a/api/otp.go +++ b/api/otp.go @@ -7,6 +7,7 @@ import ( "net/http" "strings" + "github.com/netlify/gotrue/api/sms_provider" "github.com/netlify/gotrue/models" "github.com/netlify/gotrue/storage" "github.com/sethvargo/go-password/password" @@ -103,8 +104,11 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { if err := models.NewAuditLogEntry(tx, instanceID, user, models.UserRecoveryRequestedAction, nil); err != nil { return err } - - if err := a.sendPhoneConfirmation(ctx, tx, user, params.Phone); err != nil { + smsProvider, err := sms_provider.GetSmsProvider(*config) + if err != nil { + return err + } + if err := a.sendPhoneConfirmation(ctx, tx, user, params.Phone, phoneConfirmationOtp, smsProvider); err != nil { return badRequestError("Error sending sms otp: %v", err) } return nil diff --git a/api/phone.go b/api/phone.go index c363973319..6c107fcc11 100644 --- a/api/phone.go +++ b/api/phone.go @@ -17,6 +17,11 @@ import ( const e164Format = `^[1-9]\d{1,14}$` const defaultSmsMessage = "Your code is %v" +const ( + phoneChangeOtp = "phone_change" + phoneConfirmationOtp = "confirmation" +) + // validateE165Format checks if phone number follows the E.164 format func (a *API) validateE164Format(phone string) bool { // match should never fail as long as regexp is valid @@ -29,39 +34,51 @@ func (a *API) formatPhoneNumber(phone string) string { return strings.ReplaceAll(strings.Trim(phone, "+"), " ", "") } -func (a *API) sendPhoneConfirmation(ctx context.Context, tx *storage.Connection, user *models.User, phone string) error { +// sendPhoneConfirmation sends an otp to the user's phone number +func (a *API) sendPhoneConfirmation(ctx context.Context, tx *storage.Connection, user *models.User, phone, otpType string, smsProvider sms_provider.SmsProvider) error { config := a.getConfig(ctx) - if user.ConfirmationSentAt != nil && !user.ConfirmationSentAt.Add(config.Sms.MaxFrequency).Before(time.Now()) { + var token *string + var sentAt *time.Time + var tokenDbField, sentAtDbField string + + if otpType == phoneConfirmationOtp { + token = &user.ConfirmationToken + sentAt = user.ConfirmationSentAt + tokenDbField, sentAtDbField = "confirmation_token", "confirmation_sent_at" + } else if otpType == phoneChangeOtp { + token = &user.PhoneChangeToken + sentAt = user.PhoneChangeSentAt + tokenDbField, sentAtDbField = "phone_change_token", "phone_change_sent_at" + } else { + return internalServerError("invalid otp type") + } + + if sentAt != nil && !sentAt.Add(config.Sms.MaxFrequency).Before(time.Now()) { return MaxFrequencyLimitError } - oldToken := user.ConfirmationToken + oldToken := *token otp, err := crypto.GenerateOtp(config.Sms.OtpLength) if err != nil { return internalServerError("error generating otp").WithInternalError(err) } - user.ConfirmationToken = otp - - smsProvider, err := sms_provider.GetSmsProvider(*config) - if err != nil { - return err - } + *token = otp var message string if config.Sms.Template == "" { - message = fmt.Sprintf(defaultSmsMessage, user.ConfirmationToken) + message = fmt.Sprintf(defaultSmsMessage, *token) } else { - message = strings.Replace(config.Sms.Template, "{{ .Code }}", user.ConfirmationToken, -1) + message = strings.Replace(config.Sms.Template, "{{ .Code }}", *token, -1) } if serr := smsProvider.SendSms(phone, message); serr != nil { - user.ConfirmationToken = oldToken + *token = oldToken return serr } now := time.Now() - user.ConfirmationSentAt = &now + sentAt = &now - return errors.Wrap(tx.UpdateOnly(user, "confirmation_token", "confirmation_sent_at"), "Database error updating user for confirmation") + return errors.Wrap(tx.UpdateOnly(user, tokenDbField, sentAtDbField), "Database error updating user for confirmation") } diff --git a/api/phone_test.go b/api/phone_test.go new file mode 100644 index 0000000000..3e959b981c --- /dev/null +++ b/api/phone_test.go @@ -0,0 +1,99 @@ +package api + +import ( + "context" + "testing" + + "github.com/gofrs/uuid" + "github.com/netlify/gotrue/conf" + "github.com/netlify/gotrue/models" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type PhoneTestSuite struct { + suite.Suite + API *API + Config *conf.Configuration + + instanceID uuid.UUID +} + +type TestSmsProvider struct { + mock.Mock +} + +func (t TestSmsProvider) SendSms(phone string, message string) error { + return nil +} + +func TestPhone(t *testing.T) { + api, config, instanceID, err := setupAPIForTestForInstance() + require.NoError(t, err) + + ts := &PhoneTestSuite{ + API: api, + Config: config, + instanceID: instanceID, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *PhoneTestSuite) SetupTest() { + models.TruncateAll(ts.API.db) + + // Create user + u, err := models.NewUser(ts.instanceID, "", "password", ts.Config.JWT.Aud, nil) + u.Phone = "123456789" + require.NoError(ts.T(), err, "Error creating test user model") + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") +} + +func (ts *PhoneTestSuite) TestValidateE164Format() { + isValid := ts.API.validateE164Format("0123456789") + assert.Equal(ts.T(), false, isValid) +} + +func (ts *PhoneTestSuite) TestFormatPhoneNumber() { + actual := ts.API.formatPhoneNumber("+1 23456789 ") + assert.Equal(ts.T(), "123456789", actual) +} + +func (ts *PhoneTestSuite) TestSendPhoneConfirmation() { + u, err := models.FindUserByPhoneAndAudience(ts.API.db, ts.instanceID, "123456789", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + ctx, err := WithInstanceConfig(context.Background(), ts.Config, ts.instanceID) + require.NoError(ts.T(), err) + cases := []struct { + desc string + otpType string + expected error + }{ + { + "send confirmation otp", + phoneConfirmationOtp, + nil, + }, + { + "send phone_change otp", + phoneChangeOtp, + nil, + }, + { + "send invalid otp type ", + "invalid otp type", + internalServerError("invalid otp type"), + }, + } + + for _, c := range cases { + ts.Run(c.desc, func() { + err = ts.API.sendPhoneConfirmation(ctx, ts.API.db, u, "123456789", c.otpType, TestSmsProvider{}) + require.Equal(ts.T(), c.expected, err) + }) + } +} diff --git a/api/signup.go b/api/signup.go index a35d1b716f..270f6f9489 100644 --- a/api/signup.go +++ b/api/signup.go @@ -8,6 +8,7 @@ import ( "time" "github.com/gofrs/uuid" + "github.com/netlify/gotrue/api/sms_provider" "github.com/netlify/gotrue/metering" "github.com/netlify/gotrue/models" "github.com/netlify/gotrue/storage" @@ -159,7 +160,11 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { }); terr != nil { return terr } - if terr = a.sendPhoneConfirmation(ctx, tx, user, params.Phone); terr != nil { + smsProvider, err := sms_provider.GetSmsProvider(*config) + if err != nil { + return err + } + if terr = a.sendPhoneConfirmation(ctx, tx, user, params.Phone, phoneConfirmationOtp, smsProvider); terr != nil { return badRequestError("Error sending confirmation sms: %v", terr) } } diff --git a/api/user.go b/api/user.go index 56296378e5..53fe2527e5 100644 --- a/api/user.go +++ b/api/user.go @@ -5,6 +5,7 @@ import ( "net/http" "github.com/gofrs/uuid" + "github.com/netlify/gotrue/api/sms_provider" "github.com/netlify/gotrue/models" "github.com/netlify/gotrue/storage" ) @@ -124,6 +125,30 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { } } + if params.Phone != "" { + params.Phone = a.formatPhoneNumber(params.Phone) + if isValid := a.validateE164Format(params.Phone); !isValid { + return unprocessableEntityError("Invalid phone number format") + } + var exists bool + if exists, terr = models.IsDuplicatedPhone(tx, instanceID, params.Phone, user.Aud); terr != nil { + return internalServerError("Database error checking phone").WithInternalError(terr) + } else if exists { + return unprocessableEntityError(DuplicatePhoneMsg) + } + if config.Sms.Autoconfirm { + return user.UpdatePhone(tx, params.Phone) + } else { + smsProvider, terr := sms_provider.GetSmsProvider(*config) + if terr != nil { + return terr + } + if terr := a.sendPhoneConfirmation(ctx, tx, user, params.Phone, phoneChangeOtp, smsProvider); terr != nil { + return internalServerError("Error sending phone change otp").WithInternalError(terr) + } + } + } + if terr = models.NewAuditLogEntry(tx, instanceID, user, models.UserModifiedAction, nil); terr != nil { return internalServerError("Error recording audit log entry").WithInternalError(terr) } diff --git a/api/user_test.go b/api/user_test.go index 8d8566b1e7..b14d511544 100644 --- a/api/user_test.go +++ b/api/user_test.go @@ -43,6 +43,7 @@ func (ts *UserTestSuite) SetupTest() { // Create user u, err := models.NewUser(ts.instanceID, "test@example.com", "password", ts.Config.JWT.Aud, nil) + u.Phone = "123456789" require.NoError(ts.T(), err, "Error creating test user model") require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user") } @@ -72,7 +73,7 @@ func (ts *UserTestSuite) TestUserUpdateEmail() { "User doesn't have an existing email", map[string]string{ "email": "", - "phone": "123456789", + "phone": "", }, false, http.StatusOK, @@ -110,9 +111,9 @@ func (ts *UserTestSuite) TestUserUpdateEmail() { ts.Run(c.desc, func() { u, err := models.NewUser(ts.instanceID, "", "", ts.Config.JWT.Aud, nil) require.NoError(ts.T(), err, "Error creating test user model") - require.NoError(ts.T(), ts.API.db.Create(u), "Error saving test user") require.NoError(ts.T(), u.SetEmail(ts.API.db, c.userData["email"]), "Error setting user email") require.NoError(ts.T(), u.SetPhone(ts.API.db, c.userData["phone"]), "Error setting user phone") + require.NoError(ts.T(), ts.API.db.Create(u), "Error saving test user") token, err := generateAccessToken(u, time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret) require.NoError(ts.T(), err, "Error generating access token") @@ -132,6 +133,53 @@ func (ts *UserTestSuite) TestUserUpdateEmail() { }) } +} +func (ts *UserTestSuite) TestUserUpdatePhoneAutoconfirmEnabled() { + u, err := models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "test@example.com", ts.Config.JWT.Aud) + require.NoError(ts.T(), err) + + cases := []struct { + desc string + userData map[string]string + expectedCode int + }{ + { + "New phone number is the same as current phone number", + map[string]string{ + "phone": "123456789", + }, + http.StatusUnprocessableEntity, + }, + { + "New phone number is different from current phone number", + map[string]string{ + "phone": "234567890", + }, + http.StatusOK, + }, + } + + ts.Config.Sms.Autoconfirm = true + + for _, c := range cases { + ts.Run(c.desc, func() { + token, err := generateAccessToken(u, time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret) + require.NoError(ts.T(), err, "Error generating access token") + + var buffer bytes.Buffer + require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{ + "phone": c.userData["phone"], + })) + req := httptest.NewRequest(http.MethodPut, "http://localhost/user", &buffer) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + + w := httptest.NewRecorder() + ts.API.handler.ServeHTTP(w, req) + require.Equal(ts.T(), c.expectedCode, w.Code) + }) + } + } func (ts *UserTestSuite) TestUserUpdatePassword() { diff --git a/api/verify.go b/api/verify.go index 06bfe05a2a..5150cde953 100644 --- a/api/verify.go +++ b/api/verify.go @@ -26,6 +26,7 @@ const ( magicLinkVerification = "magiclink" emailChangeVerification = "email_change" smsVerification = "sms" + phoneChangeVerification = "phone_change" ) const ( @@ -98,8 +99,8 @@ func (a *API) Verify(w http.ResponseWriter, r *http.Request) error { http.Redirect(w, r, rurl, http.StatusSeeOther) return nil } - case smsVerification: - user, terr = a.smsVerify(ctx, tx, user) + case smsVerification, phoneChangeVerification: + user, terr = a.smsVerify(ctx, tx, user, params.Type) default: return unprocessableEntityError("Verify requires a verification type") } @@ -231,7 +232,7 @@ func (a *API) recoverVerify(ctx context.Context, conn *storage.Connection, user return user, nil } -func (a *API) smsVerify(ctx context.Context, conn *storage.Connection, user *models.User) (*models.User, error) { +func (a *API) smsVerify(ctx context.Context, conn *storage.Connection, user *models.User, otpType string) (*models.User, error) { instanceID := getInstanceID(ctx) config := a.getConfig(ctx) @@ -245,8 +246,14 @@ func (a *API) smsVerify(ctx context.Context, conn *storage.Connection, user *mod return terr } - if terr = user.ConfirmPhone(tx); terr != nil { - return internalServerError("Error confirming user").WithInternalError(terr) + if otpType == smsVerification { + if terr = user.ConfirmPhone(tx); terr != nil { + return internalServerError("Error confirming user").WithInternalError(terr) + } + } else if otpType == phoneChangeVerification { + if terr = user.ConfirmPhoneChange(tx); terr != nil { + return internalServerError("Error confirming user").WithInternalError(terr) + } } return nil }) @@ -340,7 +347,7 @@ func (a *API) verifyUserAndToken(ctx context.Context, conn *storage.Connection, case emailChangeVerification: user, err = models.FindUserByEmailChangeToken(conn, params.Token) } - } else if params.Type == smsVerification { + } else if params.Type == smsVerification || params.Type == phoneChangeVerification { if params.Phone == "" { return nil, unprocessableEntityError("Sms Verification requires a phone number") } @@ -383,6 +390,8 @@ func (a *API) verifyUserAndToken(ctx context.Context, conn *storage.Connection, user.EmailChangeConfirmStatus = zeroConfirmation err = conn.UpdateOnly(user, "email_change_confirm_status") } + case phoneChangeVerification: + isValid = isOtpValid(params.Token, user.PhoneChangeToken, user.PhoneChangeSentAt.Add(smsOtpExpiresAt)) case smsVerification: isValid = isOtpValid(params.Token, user.ConfirmationToken, user.ConfirmationSentAt.Add(smsOtpExpiresAt)) } @@ -399,5 +408,6 @@ func isOtpValid(actual, expected string, expiresAt time.Time) bool { } func isUrlVerification(params *VerifyParams) bool { - return params.Type != smsVerification && params.Email == "" + isPhoneVerification := params.Type == smsVerification || params.Type == phoneChangeVerification + return !isPhoneVerification && params.Email == "" } diff --git a/api/verify_test.go b/api/verify_test.go index 65c4dc7fd9..597fbd3b6d 100644 --- a/api/verify_test.go +++ b/api/verify_test.go @@ -193,8 +193,10 @@ func (ts *VerifyTestSuite) TestInvalidOtp() { u, err := models.FindUserByPhoneAndAudience(ts.API.db, ts.instanceID, "12345678", ts.Config.JWT.Aud) require.NoError(ts.T(), err) u.ConfirmationToken = "123456" + u.PhoneChangeToken = "123456" sentTime := time.Now().Add(-48 * time.Hour) u.ConfirmationSentAt = &sentTime + u.PhoneChangeSentAt = &sentTime require.NoError(ts.T(), ts.API.db.Update(u)) type ResponseBody struct { @@ -233,6 +235,16 @@ func (ts *VerifyTestSuite) TestInvalidOtp() { }, expected: expectedResponse, }, + { + desc: "Invalid Phone Change OTP", + sentTime: time.Now(), + body: map[string]interface{}{ + "type": phoneChangeVerification, + "token": "invalid_otp", + "phone": u.GetPhone(), + }, + expected: expectedResponse, + }, { desc: "Invalid Email OTP", sentTime: time.Now(), @@ -588,6 +600,16 @@ func (ts *VerifyTestSuite) TestVerifyValidOtp() { code: http.StatusSeeOther, }, }, + { + desc: "Valid Phone Change OTP", + sentTime: time.Now(), + body: map[string]interface{}{ + "type": phoneChangeVerification, + "token": "123456", + "phone": "12345678", + }, + expected: expectedResponse, + }, } for _, c := range cases { @@ -596,9 +618,11 @@ func (ts *VerifyTestSuite) TestVerifyValidOtp() { u.ConfirmationSentAt = &c.sentTime u.RecoverySentAt = &c.sentTime u.EmailChangeSentAt = &c.sentTime - u.ConfirmationToken, _ = c.body["token"].(string) - u.RecoveryToken, _ = c.body["token"].(string) - u.EmailChangeTokenCurrent, _ = c.body["token"].(string) + u.PhoneChangeSentAt = &c.sentTime + u.ConfirmationToken = c.body["token"].(string) + u.RecoveryToken = c.body["token"].(string) + u.EmailChangeTokenCurrent = c.body["token"].(string) + u.PhoneChangeToken = c.body["token"].(string) require.NoError(ts.T(), ts.API.db.Update(u)) var buffer bytes.Buffer diff --git a/go.sum b/go.sum index 2fd375e267..f248244695 100644 --- a/go.sum +++ b/go.sum @@ -539,6 +539,7 @@ github.com/spf13/viper v1.4.0/go.mod h1:PTJ7Z/lr49W6bUbkmS1V3by4uWynFiR9p7+dSq/y github.com/spf13/viper v1.7.0/go.mod h1:8WkrPz2fc9jxqZNCJI/76HCieCp4Q8HaLFoCha5qpdg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.2.0 h1:Hbg2NidpLE8veEBkEZTL3CvlkUIVzuU9jDplZO54c48= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=