Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: allow user to update phone number #421

Merged
merged 3 commits into from
Mar 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions api/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
// Common error messages during signup flow
var (
DuplicateEmailMsg = "A user with this email address has already been registered"
DuplicatePhoneMsg = "A user with this phone number has already been registered"
UserExistsError error = errors.New("User already exists")
)

Expand Down
8 changes: 6 additions & 2 deletions api/otp.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"net/http"
"strings"

"github.com/netlify/gotrue/api/sms_provider"
"github.com/netlify/gotrue/models"
"github.com/netlify/gotrue/storage"
"github.com/sethvargo/go-password/password"
Expand Down Expand Up @@ -103,8 +104,11 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error {
if err := models.NewAuditLogEntry(tx, instanceID, user, models.UserRecoveryRequestedAction, nil); err != nil {
return err
}

if err := a.sendPhoneConfirmation(ctx, tx, user, params.Phone); err != nil {
smsProvider, err := sms_provider.GetSmsProvider(*config)
if err != nil {
return err
}
if err := a.sendPhoneConfirmation(ctx, tx, user, params.Phone, phoneConfirmationOtp, smsProvider); err != nil {
return badRequestError("Error sending sms otp: %v", err)
}
return nil
Expand Down
45 changes: 31 additions & 14 deletions api/phone.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ import (
const e164Format = `^[1-9]\d{1,14}$`
const defaultSmsMessage = "Your code is %v"

const (
phoneChangeOtp = "phone_change"
phoneConfirmationOtp = "confirmation"
)

// validateE165Format checks if phone number follows the E.164 format
func (a *API) validateE164Format(phone string) bool {
// match should never fail as long as regexp is valid
Expand All @@ -29,39 +34,51 @@ func (a *API) formatPhoneNumber(phone string) string {
return strings.ReplaceAll(strings.Trim(phone, "+"), " ", "")
}

func (a *API) sendPhoneConfirmation(ctx context.Context, tx *storage.Connection, user *models.User, phone string) error {
// sendPhoneConfirmation sends an otp to the user's phone number
func (a *API) sendPhoneConfirmation(ctx context.Context, tx *storage.Connection, user *models.User, phone, otpType string, smsProvider sms_provider.SmsProvider) error {
Copy link
Member Author

@kangmingtay kangmingtay Mar 21, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed the function signature here to take in an otpType and a smsProvider. The smsProvider was added to the function params to make it easier to test this function with a mock sms provider that implements a stub SendSms method. See TestSendPhoneConfirmation.

config := a.getConfig(ctx)

if user.ConfirmationSentAt != nil && !user.ConfirmationSentAt.Add(config.Sms.MaxFrequency).Before(time.Now()) {
var token *string
var sentAt *time.Time
var tokenDbField, sentAtDbField string

if otpType == phoneConfirmationOtp {
token = &user.ConfirmationToken
sentAt = user.ConfirmationSentAt
tokenDbField, sentAtDbField = "confirmation_token", "confirmation_sent_at"
} else if otpType == phoneChangeOtp {
token = &user.PhoneChangeToken
sentAt = user.PhoneChangeSentAt
tokenDbField, sentAtDbField = "phone_change_token", "phone_change_sent_at"
} else {
return internalServerError("invalid otp type")
}

if sentAt != nil && !sentAt.Add(config.Sms.MaxFrequency).Before(time.Now()) {
J0 marked this conversation as resolved.
Show resolved Hide resolved
return MaxFrequencyLimitError
}

oldToken := user.ConfirmationToken
oldToken := *token
otp, err := crypto.GenerateOtp(config.Sms.OtpLength)
if err != nil {
return internalServerError("error generating otp").WithInternalError(err)
}
user.ConfirmationToken = otp

smsProvider, err := sms_provider.GetSmsProvider(*config)
if err != nil {
return err
}
*token = otp

var message string
if config.Sms.Template == "" {
message = fmt.Sprintf(defaultSmsMessage, user.ConfirmationToken)
message = fmt.Sprintf(defaultSmsMessage, *token)
} else {
message = strings.Replace(config.Sms.Template, "{{ .Code }}", user.ConfirmationToken, -1)
message = strings.Replace(config.Sms.Template, "{{ .Code }}", *token, -1)
}

if serr := smsProvider.SendSms(phone, message); serr != nil {
user.ConfirmationToken = oldToken
*token = oldToken
return serr
}

now := time.Now()
user.ConfirmationSentAt = &now
sentAt = &now

return errors.Wrap(tx.UpdateOnly(user, "confirmation_token", "confirmation_sent_at"), "Database error updating user for confirmation")
return errors.Wrap(tx.UpdateOnly(user, tokenDbField, sentAtDbField), "Database error updating user for confirmation")
}
99 changes: 99 additions & 0 deletions api/phone_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
package api

import (
"context"
"testing"

"github.com/gofrs/uuid"
"github.com/netlify/gotrue/conf"
"github.com/netlify/gotrue/models"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)

type PhoneTestSuite struct {
suite.Suite
API *API
Config *conf.Configuration

instanceID uuid.UUID
}

type TestSmsProvider struct {
mock.Mock
}

func (t TestSmsProvider) SendSms(phone string, message string) error {
return nil
}

func TestPhone(t *testing.T) {
api, config, instanceID, err := setupAPIForTestForInstance()
require.NoError(t, err)

ts := &PhoneTestSuite{
API: api,
Config: config,
instanceID: instanceID,
}
defer api.db.Close()

suite.Run(t, ts)
}

func (ts *PhoneTestSuite) SetupTest() {
models.TruncateAll(ts.API.db)

// Create user
u, err := models.NewUser(ts.instanceID, "", "password", ts.Config.JWT.Aud, nil)
u.Phone = "123456789"
require.NoError(ts.T(), err, "Error creating test user model")
require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user")
}

func (ts *PhoneTestSuite) TestValidateE164Format() {
isValid := ts.API.validateE164Format("0123456789")
assert.Equal(ts.T(), false, isValid)
}

func (ts *PhoneTestSuite) TestFormatPhoneNumber() {
actual := ts.API.formatPhoneNumber("+1 23456789 ")
assert.Equal(ts.T(), "123456789", actual)
}

func (ts *PhoneTestSuite) TestSendPhoneConfirmation() {
u, err := models.FindUserByPhoneAndAudience(ts.API.db, ts.instanceID, "123456789", ts.Config.JWT.Aud)
require.NoError(ts.T(), err)
ctx, err := WithInstanceConfig(context.Background(), ts.Config, ts.instanceID)
require.NoError(ts.T(), err)
cases := []struct {
desc string
otpType string
expected error
}{
{
"send confirmation otp",
phoneConfirmationOtp,
nil,
},
{
"send phone_change otp",
phoneChangeOtp,
nil,
},
{
"send invalid otp type ",
"invalid otp type",
internalServerError("invalid otp type"),
},
}

for _, c := range cases {
ts.Run(c.desc, func() {
err = ts.API.sendPhoneConfirmation(ctx, ts.API.db, u, "123456789", c.otpType, TestSmsProvider{})
require.Equal(ts.T(), c.expected, err)
})
}
}
7 changes: 6 additions & 1 deletion api/signup.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"time"

"github.com/gofrs/uuid"
"github.com/netlify/gotrue/api/sms_provider"
"github.com/netlify/gotrue/metering"
"github.com/netlify/gotrue/models"
"github.com/netlify/gotrue/storage"
Expand Down Expand Up @@ -159,7 +160,11 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error {
}); terr != nil {
return terr
}
if terr = a.sendPhoneConfirmation(ctx, tx, user, params.Phone); terr != nil {
smsProvider, err := sms_provider.GetSmsProvider(*config)
if err != nil {
return err
}
if terr = a.sendPhoneConfirmation(ctx, tx, user, params.Phone, phoneConfirmationOtp, smsProvider); terr != nil {
return badRequestError("Error sending confirmation sms: %v", terr)
}
}
Expand Down
25 changes: 25 additions & 0 deletions api/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"net/http"

"github.com/gofrs/uuid"
"github.com/netlify/gotrue/api/sms_provider"
"github.com/netlify/gotrue/models"
"github.com/netlify/gotrue/storage"
)
Expand Down Expand Up @@ -124,6 +125,30 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error {
}
}

if params.Phone != "" {
params.Phone = a.formatPhoneNumber(params.Phone)
if isValid := a.validateE164Format(params.Phone); !isValid {
return unprocessableEntityError("Invalid phone number format")
}
var exists bool
if exists, terr = models.IsDuplicatedPhone(tx, instanceID, params.Phone, user.Aud); terr != nil {
return internalServerError("Database error checking phone").WithInternalError(terr)
} else if exists {
return unprocessableEntityError(DuplicatePhoneMsg)
}
if config.Sms.Autoconfirm {
return user.UpdatePhone(tx, params.Phone)
} else {
smsProvider, terr := sms_provider.GetSmsProvider(*config)
if terr != nil {
return terr
}
if terr := a.sendPhoneConfirmation(ctx, tx, user, params.Phone, phoneChangeOtp, smsProvider); terr != nil {
return internalServerError("Error sending phone change otp").WithInternalError(terr)
}
}
}

if terr = models.NewAuditLogEntry(tx, instanceID, user, models.UserModifiedAction, nil); terr != nil {
return internalServerError("Error recording audit log entry").WithInternalError(terr)
}
Expand Down
52 changes: 50 additions & 2 deletions api/user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ func (ts *UserTestSuite) SetupTest() {

// Create user
u, err := models.NewUser(ts.instanceID, "test@example.com", "password", ts.Config.JWT.Aud, nil)
u.Phone = "123456789"
require.NoError(ts.T(), err, "Error creating test user model")
require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user")
}
Expand Down Expand Up @@ -72,7 +73,7 @@ func (ts *UserTestSuite) TestUserUpdateEmail() {
"User doesn't have an existing email",
map[string]string{
"email": "",
"phone": "123456789",
"phone": "",
},
false,
http.StatusOK,
Expand Down Expand Up @@ -110,9 +111,9 @@ func (ts *UserTestSuite) TestUserUpdateEmail() {
ts.Run(c.desc, func() {
u, err := models.NewUser(ts.instanceID, "", "", ts.Config.JWT.Aud, nil)
require.NoError(ts.T(), err, "Error creating test user model")
require.NoError(ts.T(), ts.API.db.Create(u), "Error saving test user")
require.NoError(ts.T(), u.SetEmail(ts.API.db, c.userData["email"]), "Error setting user email")
require.NoError(ts.T(), u.SetPhone(ts.API.db, c.userData["phone"]), "Error setting user phone")
require.NoError(ts.T(), ts.API.db.Create(u), "Error saving test user")

token, err := generateAccessToken(u, time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret)
require.NoError(ts.T(), err, "Error generating access token")
Expand All @@ -132,6 +133,53 @@ func (ts *UserTestSuite) TestUserUpdateEmail() {
})
}

}
func (ts *UserTestSuite) TestUserUpdatePhoneAutoconfirmEnabled() {
u, err := models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "test@example.com", ts.Config.JWT.Aud)
require.NoError(ts.T(), err)

cases := []struct {
desc string
userData map[string]string
expectedCode int
}{
{
"New phone number is the same as current phone number",
map[string]string{
"phone": "123456789",
},
http.StatusUnprocessableEntity,
},
{
"New phone number is different from current phone number",
map[string]string{
"phone": "234567890",
},
http.StatusOK,
},
}

ts.Config.Sms.Autoconfirm = true

for _, c := range cases {
ts.Run(c.desc, func() {
token, err := generateAccessToken(u, time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret)
require.NoError(ts.T(), err, "Error generating access token")

var buffer bytes.Buffer
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
"phone": c.userData["phone"],
}))
req := httptest.NewRequest(http.MethodPut, "http://localhost/user", &buffer)
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))

w := httptest.NewRecorder()
ts.API.handler.ServeHTTP(w, req)
require.Equal(ts.T(), c.expectedCode, w.Code)
})
}

}

func (ts *UserTestSuite) TestUserUpdatePassword() {
Expand Down
Loading