Skip to content

Commit

Permalink
fix: allow user to update phone number (supabase#421)
Browse files Browse the repository at this point in the history
* refactor: sendPhoneConfirmation should accept an otpType and smsProvider

* fix: allow update phone for users endpoint

* fix: add phone change verification method
  • Loading branch information
kangmingtay authored and LashaJini committed Nov 15, 2024
1 parent 9c92569 commit 9403e25
Show file tree
Hide file tree
Showing 10 changed files with 263 additions and 29 deletions.
1 change: 1 addition & 0 deletions api/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
// Common error messages during signup flow
var (
DuplicateEmailMsg = "A user with this email address has already been registered"
DuplicatePhoneMsg = "A user with this phone number has already been registered"
UserExistsError error = errors.New("User already exists")
)

Expand Down
8 changes: 6 additions & 2 deletions api/otp.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/http"
"strings"

"github.com/netlify/gotrue/api/sms_provider"
"github.com/netlify/gotrue/models"
"github.com/netlify/gotrue/storage"
"github.com/sethvargo/go-password/password"
Expand Down Expand Up @@ -103,8 +104,11 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error {
if err := models.NewAuditLogEntry(tx, instanceID, user, models.UserRecoveryRequestedAction, nil); err != nil {
return err
}

if err := a.sendPhoneConfirmation(ctx, tx, user, params.Phone); err != nil {
smsProvider, err := sms_provider.GetSmsProvider(*config)
if err != nil {
return err
}
if err := a.sendPhoneConfirmation(ctx, tx, user, params.Phone, phoneConfirmationOtp, smsProvider); err != nil {
return badRequestError("Error sending sms otp: %v", err)
}
return nil
Expand Down
45 changes: 31 additions & 14 deletions api/phone.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ import (
const e164Format = `^[1-9]\d{1,14}$`
const defaultSmsMessage = "Your code is %v"

const (
phoneChangeOtp = "phone_change"
phoneConfirmationOtp = "confirmation"
)

// validateE165Format checks if phone number follows the E.164 format
func (a *API) validateE164Format(phone string) bool {
// match should never fail as long as regexp is valid
Expand All @@ -29,39 +34,51 @@ func (a *API) formatPhoneNumber(phone string) string {
return strings.ReplaceAll(strings.Trim(phone, "+"), " ", "")
}

func (a *API) sendPhoneConfirmation(ctx context.Context, tx *storage.Connection, user *models.User, phone string) error {
// sendPhoneConfirmation sends an otp to the user's phone number
func (a *API) sendPhoneConfirmation(ctx context.Context, tx *storage.Connection, user *models.User, phone, otpType string, smsProvider sms_provider.SmsProvider) error {
config := a.getConfig(ctx)

if user.ConfirmationSentAt != nil && !user.ConfirmationSentAt.Add(config.Sms.MaxFrequency).Before(time.Now()) {
var token *string
var sentAt *time.Time
var tokenDbField, sentAtDbField string

if otpType == phoneConfirmationOtp {
token = &user.ConfirmationToken
sentAt = user.ConfirmationSentAt
tokenDbField, sentAtDbField = "confirmation_token", "confirmation_sent_at"
} else if otpType == phoneChangeOtp {
token = &user.PhoneChangeToken
sentAt = user.PhoneChangeSentAt
tokenDbField, sentAtDbField = "phone_change_token", "phone_change_sent_at"
} else {
return internalServerError("invalid otp type")
}

if sentAt != nil && !sentAt.Add(config.Sms.MaxFrequency).Before(time.Now()) {
return MaxFrequencyLimitError
}

oldToken := user.ConfirmationToken
oldToken := *token
otp, err := crypto.GenerateOtp(config.Sms.OtpLength)
if err != nil {
return internalServerError("error generating otp").WithInternalError(err)
}
user.ConfirmationToken = otp

smsProvider, err := sms_provider.GetSmsProvider(*config)
if err != nil {
return err
}
*token = otp

var message string
if config.Sms.Template == "" {
message = fmt.Sprintf(defaultSmsMessage, user.ConfirmationToken)
message = fmt.Sprintf(defaultSmsMessage, *token)
} else {
message = strings.Replace(config.Sms.Template, "{{ .Code }}", user.ConfirmationToken, -1)
message = strings.Replace(config.Sms.Template, "{{ .Code }}", *token, -1)
}

if serr := smsProvider.SendSms(phone, message); serr != nil {
user.ConfirmationToken = oldToken
*token = oldToken
return serr
}

now := time.Now()
user.ConfirmationSentAt = &now
sentAt = &now

return errors.Wrap(tx.UpdateOnly(user, "confirmation_token", "confirmation_sent_at"), "Database error updating user for confirmation")
return errors.Wrap(tx.UpdateOnly(user, tokenDbField, sentAtDbField), "Database error updating user for confirmation")
}
99 changes: 99 additions & 0 deletions api/phone_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package api

import (
"context"
"testing"

"github.com/gofrs/uuid"
"github.com/netlify/gotrue/conf"
"github.com/netlify/gotrue/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)

type PhoneTestSuite struct {
suite.Suite
API *API
Config *conf.Configuration

instanceID uuid.UUID
}

type TestSmsProvider struct {
mock.Mock
}

func (t TestSmsProvider) SendSms(phone string, message string) error {
return nil
}

func TestPhone(t *testing.T) {
api, config, instanceID, err := setupAPIForTestForInstance()
require.NoError(t, err)

ts := &PhoneTestSuite{
API: api,
Config: config,
instanceID: instanceID,
}
defer api.db.Close()

suite.Run(t, ts)
}

func (ts *PhoneTestSuite) SetupTest() {
models.TruncateAll(ts.API.db)

// Create user
u, err := models.NewUser(ts.instanceID, "", "password", ts.Config.JWT.Aud, nil)
u.Phone = "123456789"
require.NoError(ts.T(), err, "Error creating test user model")
require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user")
}

func (ts *PhoneTestSuite) TestValidateE164Format() {
isValid := ts.API.validateE164Format("0123456789")
assert.Equal(ts.T(), false, isValid)
}

func (ts *PhoneTestSuite) TestFormatPhoneNumber() {
actual := ts.API.formatPhoneNumber("+1 23456789 ")
assert.Equal(ts.T(), "123456789", actual)
}

func (ts *PhoneTestSuite) TestSendPhoneConfirmation() {
u, err := models.FindUserByPhoneAndAudience(ts.API.db, ts.instanceID, "123456789", ts.Config.JWT.Aud)
require.NoError(ts.T(), err)
ctx, err := WithInstanceConfig(context.Background(), ts.Config, ts.instanceID)
require.NoError(ts.T(), err)
cases := []struct {
desc string
otpType string
expected error
}{
{
"send confirmation otp",
phoneConfirmationOtp,
nil,
},
{
"send phone_change otp",
phoneChangeOtp,
nil,
},
{
"send invalid otp type ",
"invalid otp type",
internalServerError("invalid otp type"),
},
}

for _, c := range cases {
ts.Run(c.desc, func() {
err = ts.API.sendPhoneConfirmation(ctx, ts.API.db, u, "123456789", c.otpType, TestSmsProvider{})
require.Equal(ts.T(), c.expected, err)
})
}
}
7 changes: 6 additions & 1 deletion api/signup.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"time"

"github.com/gofrs/uuid"
"github.com/netlify/gotrue/api/sms_provider"
"github.com/netlify/gotrue/metering"
"github.com/netlify/gotrue/models"
"github.com/netlify/gotrue/storage"
Expand Down Expand Up @@ -159,7 +160,11 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error {
}); terr != nil {
return terr
}
if terr = a.sendPhoneConfirmation(ctx, tx, user, params.Phone); terr != nil {
smsProvider, err := sms_provider.GetSmsProvider(*config)
if err != nil {
return err
}
if terr = a.sendPhoneConfirmation(ctx, tx, user, params.Phone, phoneConfirmationOtp, smsProvider); terr != nil {
return badRequestError("Error sending confirmation sms: %v", terr)
}
}
Expand Down
25 changes: 25 additions & 0 deletions api/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"net/http"

"github.com/gofrs/uuid"
"github.com/netlify/gotrue/api/sms_provider"
"github.com/netlify/gotrue/models"
"github.com/netlify/gotrue/storage"
)
Expand Down Expand Up @@ -124,6 +125,30 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error {
}
}

if params.Phone != "" {
params.Phone = a.formatPhoneNumber(params.Phone)
if isValid := a.validateE164Format(params.Phone); !isValid {
return unprocessableEntityError("Invalid phone number format")
}
var exists bool
if exists, terr = models.IsDuplicatedPhone(tx, instanceID, params.Phone, user.Aud); terr != nil {
return internalServerError("Database error checking phone").WithInternalError(terr)
} else if exists {
return unprocessableEntityError(DuplicatePhoneMsg)
}
if config.Sms.Autoconfirm {
return user.UpdatePhone(tx, params.Phone)
} else {
smsProvider, terr := sms_provider.GetSmsProvider(*config)
if terr != nil {
return terr
}
if terr := a.sendPhoneConfirmation(ctx, tx, user, params.Phone, phoneChangeOtp, smsProvider); terr != nil {
return internalServerError("Error sending phone change otp").WithInternalError(terr)
}
}
}

if terr = models.NewAuditLogEntry(tx, instanceID, user, models.UserModifiedAction, nil); terr != nil {
return internalServerError("Error recording audit log entry").WithInternalError(terr)
}
Expand Down
52 changes: 50 additions & 2 deletions api/user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ func (ts *UserTestSuite) SetupTest() {

// Create user
u, err := models.NewUser(ts.instanceID, "test@example.com", "password", ts.Config.JWT.Aud, nil)
u.Phone = "123456789"
require.NoError(ts.T(), err, "Error creating test user model")
require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user")
}
Expand Down Expand Up @@ -72,7 +73,7 @@ func (ts *UserTestSuite) TestUserUpdateEmail() {
"User doesn't have an existing email",
map[string]string{
"email": "",
"phone": "123456789",
"phone": "",
},
false,
http.StatusOK,
Expand Down Expand Up @@ -110,9 +111,9 @@ func (ts *UserTestSuite) TestUserUpdateEmail() {
ts.Run(c.desc, func() {
u, err := models.NewUser(ts.instanceID, "", "", ts.Config.JWT.Aud, nil)
require.NoError(ts.T(), err, "Error creating test user model")
require.NoError(ts.T(), ts.API.db.Create(u), "Error saving test user")
require.NoError(ts.T(), u.SetEmail(ts.API.db, c.userData["email"]), "Error setting user email")
require.NoError(ts.T(), u.SetPhone(ts.API.db, c.userData["phone"]), "Error setting user phone")
require.NoError(ts.T(), ts.API.db.Create(u), "Error saving test user")

token, err := generateAccessToken(u, time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret)
require.NoError(ts.T(), err, "Error generating access token")
Expand All @@ -132,6 +133,53 @@ func (ts *UserTestSuite) TestUserUpdateEmail() {
})
}

}
func (ts *UserTestSuite) TestUserUpdatePhoneAutoconfirmEnabled() {
u, err := models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "test@example.com", ts.Config.JWT.Aud)
require.NoError(ts.T(), err)

cases := []struct {
desc string
userData map[string]string
expectedCode int
}{
{
"New phone number is the same as current phone number",
map[string]string{
"phone": "123456789",
},
http.StatusUnprocessableEntity,
},
{
"New phone number is different from current phone number",
map[string]string{
"phone": "234567890",
},
http.StatusOK,
},
}

ts.Config.Sms.Autoconfirm = true

for _, c := range cases {
ts.Run(c.desc, func() {
token, err := generateAccessToken(u, time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret)
require.NoError(ts.T(), err, "Error generating access token")

var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"phone": c.userData["phone"],
}))
req := httptest.NewRequest(http.MethodPut, "http://localhost/user", &buffer)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))

w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
require.Equal(ts.T(), c.expectedCode, w.Code)
})
}

}

func (ts *UserTestSuite) TestUserUpdatePassword() {
Expand Down
Loading

0 comments on commit 9403e25

Please sign in to comment.