From 9403e2597aa281ad268661d8c4a98bc9438bb6d2 Mon Sep 17 00:00:00 2001
From: Kang Ming <kang.ming1996@gmail.com>
Date: Tue, 22 Mar 2022 20:12:35 +0100
Subject: [PATCH] fix: allow user to update phone number (#421)

* refactor: sendPhoneConfirmation should accept an otpType and smsProvider

* fix: allow update phone for users endpoint

* fix: add phone change verification method
---
 api/errors.go      |  1 +
 api/otp.go         |  8 +++-
 api/phone.go       | 45 ++++++++++++++-------
 api/phone_test.go  | 99 ++++++++++++++++++++++++++++++++++++++++++++++
 api/signup.go      |  7 +++-
 api/user.go        | 25 ++++++++++++
 api/user_test.go   | 52 +++++++++++++++++++++++-
 api/verify.go      | 24 +++++++----
 api/verify_test.go | 30 ++++++++++++--
 go.sum             |  1 +
 10 files changed, 263 insertions(+), 29 deletions(-)
 create mode 100644 api/phone_test.go

diff --git a/api/errors.go b/api/errors.go
index f839357def..be77a7eea5 100644
--- a/api/errors.go
+++ b/api/errors.go
@@ -14,6 +14,7 @@ import (
 // Common error messages during signup flow
 var (
 	DuplicateEmailMsg       = "A user with this email address has already been registered"
+	DuplicatePhoneMsg       = "A user with this phone number has already been registered"
 	UserExistsError   error = errors.New("User already exists")
 )
 
diff --git a/api/otp.go b/api/otp.go
index 6b8a86b812..ad2955b415 100644
--- a/api/otp.go
+++ b/api/otp.go
@@ -7,6 +7,7 @@ import (
 	"net/http"
 	"strings"
 
+	"github.com/netlify/gotrue/api/sms_provider"
 	"github.com/netlify/gotrue/models"
 	"github.com/netlify/gotrue/storage"
 	"github.com/sethvargo/go-password/password"
@@ -103,8 +104,11 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error {
 		if err := models.NewAuditLogEntry(tx, instanceID, user, models.UserRecoveryRequestedAction, nil); err != nil {
 			return err
 		}
-
-		if err := a.sendPhoneConfirmation(ctx, tx, user, params.Phone); err != nil {
+		smsProvider, err := sms_provider.GetSmsProvider(*config)
+		if err != nil {
+			return err
+		}
+		if err := a.sendPhoneConfirmation(ctx, tx, user, params.Phone, phoneConfirmationOtp, smsProvider); err != nil {
 			return badRequestError("Error sending sms otp: %v", err)
 		}
 		return nil
diff --git a/api/phone.go b/api/phone.go
index c363973319..6c107fcc11 100644
--- a/api/phone.go
+++ b/api/phone.go
@@ -17,6 +17,11 @@ import (
 const e164Format = `^[1-9]\d{1,14}$`
 const defaultSmsMessage = "Your code is %v"
 
+const (
+	phoneChangeOtp       = "phone_change"
+	phoneConfirmationOtp = "confirmation"
+)
+
 // validateE165Format checks if phone number follows the E.164 format
 func (a *API) validateE164Format(phone string) bool {
 	// match should never fail as long as regexp is valid
@@ -29,39 +34,51 @@ func (a *API) formatPhoneNumber(phone string) string {
 	return strings.ReplaceAll(strings.Trim(phone, "+"), " ", "")
 }
 
-func (a *API) sendPhoneConfirmation(ctx context.Context, tx *storage.Connection, user *models.User, phone string) error {
+// sendPhoneConfirmation sends an otp to the user's phone number
+func (a *API) sendPhoneConfirmation(ctx context.Context, tx *storage.Connection, user *models.User, phone, otpType string, smsProvider sms_provider.SmsProvider) error {
 	config := a.getConfig(ctx)
 
-	if user.ConfirmationSentAt != nil && !user.ConfirmationSentAt.Add(config.Sms.MaxFrequency).Before(time.Now()) {
+	var token *string
+	var sentAt *time.Time
+	var tokenDbField, sentAtDbField string
+
+	if otpType == phoneConfirmationOtp {
+		token = &user.ConfirmationToken
+		sentAt = user.ConfirmationSentAt
+		tokenDbField, sentAtDbField = "confirmation_token", "confirmation_sent_at"
+	} else if otpType == phoneChangeOtp {
+		token = &user.PhoneChangeToken
+		sentAt = user.PhoneChangeSentAt
+		tokenDbField, sentAtDbField = "phone_change_token", "phone_change_sent_at"
+	} else {
+		return internalServerError("invalid otp type")
+	}
+
+	if sentAt != nil && !sentAt.Add(config.Sms.MaxFrequency).Before(time.Now()) {
 		return MaxFrequencyLimitError
 	}
 
-	oldToken := user.ConfirmationToken
+	oldToken := *token
 	otp, err := crypto.GenerateOtp(config.Sms.OtpLength)
 	if err != nil {
 		return internalServerError("error generating otp").WithInternalError(err)
 	}
-	user.ConfirmationToken = otp
-
-	smsProvider, err := sms_provider.GetSmsProvider(*config)
-	if err != nil {
-		return err
-	}
+	*token = otp
 
 	var message string
 	if config.Sms.Template == "" {
-		message = fmt.Sprintf(defaultSmsMessage, user.ConfirmationToken)
+		message = fmt.Sprintf(defaultSmsMessage, *token)
 	} else {
-		message = strings.Replace(config.Sms.Template, "{{ .Code }}", user.ConfirmationToken, -1)
+		message = strings.Replace(config.Sms.Template, "{{ .Code }}", *token, -1)
 	}
 
 	if serr := smsProvider.SendSms(phone, message); serr != nil {
-		user.ConfirmationToken = oldToken
+		*token = oldToken
 		return serr
 	}
 
 	now := time.Now()
-	user.ConfirmationSentAt = &now
+	sentAt = &now
 
-	return errors.Wrap(tx.UpdateOnly(user, "confirmation_token", "confirmation_sent_at"), "Database error updating user for confirmation")
+	return errors.Wrap(tx.UpdateOnly(user, tokenDbField, sentAtDbField), "Database error updating user for confirmation")
 }
diff --git a/api/phone_test.go b/api/phone_test.go
new file mode 100644
index 0000000000..3e959b981c
--- /dev/null
+++ b/api/phone_test.go
@@ -0,0 +1,99 @@
+package api
+
+import (
+	"context"
+	"testing"
+
+	"github.com/gofrs/uuid"
+	"github.com/netlify/gotrue/conf"
+	"github.com/netlify/gotrue/models"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/mock"
+	"github.com/stretchr/testify/require"
+	"github.com/stretchr/testify/suite"
+)
+
+type PhoneTestSuite struct {
+	suite.Suite
+	API    *API
+	Config *conf.Configuration
+
+	instanceID uuid.UUID
+}
+
+type TestSmsProvider struct {
+	mock.Mock
+}
+
+func (t TestSmsProvider) SendSms(phone string, message string) error {
+	return nil
+}
+
+func TestPhone(t *testing.T) {
+	api, config, instanceID, err := setupAPIForTestForInstance()
+	require.NoError(t, err)
+
+	ts := &PhoneTestSuite{
+		API:        api,
+		Config:     config,
+		instanceID: instanceID,
+	}
+	defer api.db.Close()
+
+	suite.Run(t, ts)
+}
+
+func (ts *PhoneTestSuite) SetupTest() {
+	models.TruncateAll(ts.API.db)
+
+	// Create user
+	u, err := models.NewUser(ts.instanceID, "", "password", ts.Config.JWT.Aud, nil)
+	u.Phone = "123456789"
+	require.NoError(ts.T(), err, "Error creating test user model")
+	require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user")
+}
+
+func (ts *PhoneTestSuite) TestValidateE164Format() {
+	isValid := ts.API.validateE164Format("0123456789")
+	assert.Equal(ts.T(), false, isValid)
+}
+
+func (ts *PhoneTestSuite) TestFormatPhoneNumber() {
+	actual := ts.API.formatPhoneNumber("+1 23456789 ")
+	assert.Equal(ts.T(), "123456789", actual)
+}
+
+func (ts *PhoneTestSuite) TestSendPhoneConfirmation() {
+	u, err := models.FindUserByPhoneAndAudience(ts.API.db, ts.instanceID, "123456789", ts.Config.JWT.Aud)
+	require.NoError(ts.T(), err)
+	ctx, err := WithInstanceConfig(context.Background(), ts.Config, ts.instanceID)
+	require.NoError(ts.T(), err)
+	cases := []struct {
+		desc     string
+		otpType  string
+		expected error
+	}{
+		{
+			"send confirmation otp",
+			phoneConfirmationOtp,
+			nil,
+		},
+		{
+			"send phone_change otp",
+			phoneChangeOtp,
+			nil,
+		},
+		{
+			"send invalid otp type ",
+			"invalid otp type",
+			internalServerError("invalid otp type"),
+		},
+	}
+
+	for _, c := range cases {
+		ts.Run(c.desc, func() {
+			err = ts.API.sendPhoneConfirmation(ctx, ts.API.db, u, "123456789", c.otpType, TestSmsProvider{})
+			require.Equal(ts.T(), c.expected, err)
+		})
+	}
+}
diff --git a/api/signup.go b/api/signup.go
index a35d1b716f..270f6f9489 100644
--- a/api/signup.go
+++ b/api/signup.go
@@ -8,6 +8,7 @@ import (
 	"time"
 
 	"github.com/gofrs/uuid"
+	"github.com/netlify/gotrue/api/sms_provider"
 	"github.com/netlify/gotrue/metering"
 	"github.com/netlify/gotrue/models"
 	"github.com/netlify/gotrue/storage"
@@ -159,7 +160,11 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error {
 				}); terr != nil {
 					return terr
 				}
-				if terr = a.sendPhoneConfirmation(ctx, tx, user, params.Phone); terr != nil {
+				smsProvider, err := sms_provider.GetSmsProvider(*config)
+				if err != nil {
+					return err
+				}
+				if terr = a.sendPhoneConfirmation(ctx, tx, user, params.Phone, phoneConfirmationOtp, smsProvider); terr != nil {
 					return badRequestError("Error sending confirmation sms: %v", terr)
 				}
 			}
diff --git a/api/user.go b/api/user.go
index 56296378e5..53fe2527e5 100644
--- a/api/user.go
+++ b/api/user.go
@@ -5,6 +5,7 @@ import (
 	"net/http"
 
 	"github.com/gofrs/uuid"
+	"github.com/netlify/gotrue/api/sms_provider"
 	"github.com/netlify/gotrue/models"
 	"github.com/netlify/gotrue/storage"
 )
@@ -124,6 +125,30 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error {
 			}
 		}
 
+		if params.Phone != "" {
+			params.Phone = a.formatPhoneNumber(params.Phone)
+			if isValid := a.validateE164Format(params.Phone); !isValid {
+				return unprocessableEntityError("Invalid phone number format")
+			}
+			var exists bool
+			if exists, terr = models.IsDuplicatedPhone(tx, instanceID, params.Phone, user.Aud); terr != nil {
+				return internalServerError("Database error checking phone").WithInternalError(terr)
+			} else if exists {
+				return unprocessableEntityError(DuplicatePhoneMsg)
+			}
+			if config.Sms.Autoconfirm {
+				return user.UpdatePhone(tx, params.Phone)
+			} else {
+				smsProvider, terr := sms_provider.GetSmsProvider(*config)
+				if terr != nil {
+					return terr
+				}
+				if terr := a.sendPhoneConfirmation(ctx, tx, user, params.Phone, phoneChangeOtp, smsProvider); terr != nil {
+					return internalServerError("Error sending phone change otp").WithInternalError(terr)
+				}
+			}
+		}
+
 		if terr = models.NewAuditLogEntry(tx, instanceID, user, models.UserModifiedAction, nil); terr != nil {
 			return internalServerError("Error recording audit log entry").WithInternalError(terr)
 		}
diff --git a/api/user_test.go b/api/user_test.go
index 8d8566b1e7..b14d511544 100644
--- a/api/user_test.go
+++ b/api/user_test.go
@@ -43,6 +43,7 @@ func (ts *UserTestSuite) SetupTest() {
 
 	// Create user
 	u, err := models.NewUser(ts.instanceID, "test@example.com", "password", ts.Config.JWT.Aud, nil)
+	u.Phone = "123456789"
 	require.NoError(ts.T(), err, "Error creating test user model")
 	require.NoError(ts.T(), ts.API.db.Create(u), "Error saving new test user")
 }
@@ -72,7 +73,7 @@ func (ts *UserTestSuite) TestUserUpdateEmail() {
 			"User doesn't have an existing email",
 			map[string]string{
 				"email": "",
-				"phone": "123456789",
+				"phone": "",
 			},
 			false,
 			http.StatusOK,
@@ -110,9 +111,9 @@ func (ts *UserTestSuite) TestUserUpdateEmail() {
 		ts.Run(c.desc, func() {
 			u, err := models.NewUser(ts.instanceID, "", "", ts.Config.JWT.Aud, nil)
 			require.NoError(ts.T(), err, "Error creating test user model")
-			require.NoError(ts.T(), ts.API.db.Create(u), "Error saving test user")
 			require.NoError(ts.T(), u.SetEmail(ts.API.db, c.userData["email"]), "Error setting user email")
 			require.NoError(ts.T(), u.SetPhone(ts.API.db, c.userData["phone"]), "Error setting user phone")
+			require.NoError(ts.T(), ts.API.db.Create(u), "Error saving test user")
 
 			token, err := generateAccessToken(u, time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret)
 			require.NoError(ts.T(), err, "Error generating access token")
@@ -132,6 +133,53 @@ func (ts *UserTestSuite) TestUserUpdateEmail() {
 		})
 	}
 
+}
+func (ts *UserTestSuite) TestUserUpdatePhoneAutoconfirmEnabled() {
+	u, err := models.FindUserByEmailAndAudience(ts.API.db, ts.instanceID, "test@example.com", ts.Config.JWT.Aud)
+	require.NoError(ts.T(), err)
+
+	cases := []struct {
+		desc         string
+		userData     map[string]string
+		expectedCode int
+	}{
+		{
+			"New phone number is the same as current phone number",
+			map[string]string{
+				"phone": "123456789",
+			},
+			http.StatusUnprocessableEntity,
+		},
+		{
+			"New phone number is different from current phone number",
+			map[string]string{
+				"phone": "234567890",
+			},
+			http.StatusOK,
+		},
+	}
+
+	ts.Config.Sms.Autoconfirm = true
+
+	for _, c := range cases {
+		ts.Run(c.desc, func() {
+			token, err := generateAccessToken(u, time.Second*time.Duration(ts.Config.JWT.Exp), ts.Config.JWT.Secret)
+			require.NoError(ts.T(), err, "Error generating access token")
+
+			var buffer bytes.Buffer
+			require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
+				"phone": c.userData["phone"],
+			}))
+			req := httptest.NewRequest(http.MethodPut, "http://localhost/user", &buffer)
+			req.Header.Set("Content-Type", "application/json")
+			req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
+
+			w := httptest.NewRecorder()
+			ts.API.handler.ServeHTTP(w, req)
+			require.Equal(ts.T(), c.expectedCode, w.Code)
+		})
+	}
+
 }
 
 func (ts *UserTestSuite) TestUserUpdatePassword() {
diff --git a/api/verify.go b/api/verify.go
index 06bfe05a2a..5150cde953 100644
--- a/api/verify.go
+++ b/api/verify.go
@@ -26,6 +26,7 @@ const (
 	magicLinkVerification   = "magiclink"
 	emailChangeVerification = "email_change"
 	smsVerification         = "sms"
+	phoneChangeVerification = "phone_change"
 )
 
 const (
@@ -98,8 +99,8 @@ func (a *API) Verify(w http.ResponseWriter, r *http.Request) error {
 				http.Redirect(w, r, rurl, http.StatusSeeOther)
 				return nil
 			}
-		case smsVerification:
-			user, terr = a.smsVerify(ctx, tx, user)
+		case smsVerification, phoneChangeVerification:
+			user, terr = a.smsVerify(ctx, tx, user, params.Type)
 		default:
 			return unprocessableEntityError("Verify requires a verification type")
 		}
@@ -231,7 +232,7 @@ func (a *API) recoverVerify(ctx context.Context, conn *storage.Connection, user
 	return user, nil
 }
 
-func (a *API) smsVerify(ctx context.Context, conn *storage.Connection, user *models.User) (*models.User, error) {
+func (a *API) smsVerify(ctx context.Context, conn *storage.Connection, user *models.User, otpType string) (*models.User, error) {
 	instanceID := getInstanceID(ctx)
 	config := a.getConfig(ctx)
 
@@ -245,8 +246,14 @@ func (a *API) smsVerify(ctx context.Context, conn *storage.Connection, user *mod
 			return terr
 		}
 
-		if terr = user.ConfirmPhone(tx); terr != nil {
-			return internalServerError("Error confirming user").WithInternalError(terr)
+		if otpType == smsVerification {
+			if terr = user.ConfirmPhone(tx); terr != nil {
+				return internalServerError("Error confirming user").WithInternalError(terr)
+			}
+		} else if otpType == phoneChangeVerification {
+			if terr = user.ConfirmPhoneChange(tx); terr != nil {
+				return internalServerError("Error confirming user").WithInternalError(terr)
+			}
 		}
 		return nil
 	})
@@ -340,7 +347,7 @@ func (a *API) verifyUserAndToken(ctx context.Context, conn *storage.Connection,
 		case emailChangeVerification:
 			user, err = models.FindUserByEmailChangeToken(conn, params.Token)
 		}
-	} else if params.Type == smsVerification {
+	} else if params.Type == smsVerification || params.Type == phoneChangeVerification {
 		if params.Phone == "" {
 			return nil, unprocessableEntityError("Sms Verification requires a phone number")
 		}
@@ -383,6 +390,8 @@ func (a *API) verifyUserAndToken(ctx context.Context, conn *storage.Connection,
 			user.EmailChangeConfirmStatus = zeroConfirmation
 			err = conn.UpdateOnly(user, "email_change_confirm_status")
 		}
+	case phoneChangeVerification:
+		isValid = isOtpValid(params.Token, user.PhoneChangeToken, user.PhoneChangeSentAt.Add(smsOtpExpiresAt))
 	case smsVerification:
 		isValid = isOtpValid(params.Token, user.ConfirmationToken, user.ConfirmationSentAt.Add(smsOtpExpiresAt))
 	}
@@ -399,5 +408,6 @@ func isOtpValid(actual, expected string, expiresAt time.Time) bool {
 }
 
 func isUrlVerification(params *VerifyParams) bool {
-	return params.Type != smsVerification && params.Email == ""
+	isPhoneVerification := params.Type == smsVerification || params.Type == phoneChangeVerification
+	return !isPhoneVerification && params.Email == ""
 }
diff --git a/api/verify_test.go b/api/verify_test.go
index 65c4dc7fd9..597fbd3b6d 100644
--- a/api/verify_test.go
+++ b/api/verify_test.go
@@ -193,8 +193,10 @@ func (ts *VerifyTestSuite) TestInvalidOtp() {
 	u, err := models.FindUserByPhoneAndAudience(ts.API.db, ts.instanceID, "12345678", ts.Config.JWT.Aud)
 	require.NoError(ts.T(), err)
 	u.ConfirmationToken = "123456"
+	u.PhoneChangeToken = "123456"
 	sentTime := time.Now().Add(-48 * time.Hour)
 	u.ConfirmationSentAt = &sentTime
+	u.PhoneChangeSentAt = &sentTime
 	require.NoError(ts.T(), ts.API.db.Update(u))
 
 	type ResponseBody struct {
@@ -233,6 +235,16 @@ func (ts *VerifyTestSuite) TestInvalidOtp() {
 			},
 			expected: expectedResponse,
 		},
+		{
+			desc:     "Invalid Phone Change OTP",
+			sentTime: time.Now(),
+			body: map[string]interface{}{
+				"type":  phoneChangeVerification,
+				"token": "invalid_otp",
+				"phone": u.GetPhone(),
+			},
+			expected: expectedResponse,
+		},
 		{
 			desc:     "Invalid Email OTP",
 			sentTime: time.Now(),
@@ -588,6 +600,16 @@ func (ts *VerifyTestSuite) TestVerifyValidOtp() {
 				code: http.StatusSeeOther,
 			},
 		},
+		{
+			desc:     "Valid Phone Change OTP",
+			sentTime: time.Now(),
+			body: map[string]interface{}{
+				"type":  phoneChangeVerification,
+				"token": "123456",
+				"phone": "12345678",
+			},
+			expected: expectedResponse,
+		},
 	}
 
 	for _, c := range cases {
@@ -596,9 +618,11 @@ func (ts *VerifyTestSuite) TestVerifyValidOtp() {
 			u.ConfirmationSentAt = &c.sentTime
 			u.RecoverySentAt = &c.sentTime
 			u.EmailChangeSentAt = &c.sentTime
-			u.ConfirmationToken, _ = c.body["token"].(string)
-			u.RecoveryToken, _ = c.body["token"].(string)
-			u.EmailChangeTokenCurrent, _ = c.body["token"].(string)
+			u.PhoneChangeSentAt = &c.sentTime
+			u.ConfirmationToken = c.body["token"].(string)
+			u.RecoveryToken = c.body["token"].(string)
+			u.EmailChangeTokenCurrent = c.body["token"].(string)
+			u.PhoneChangeToken = c.body["token"].(string)
 			require.NoError(ts.T(), ts.API.db.Update(u))
 
 			var buffer bytes.Buffer
diff --git a/go.sum b/go.sum
index 2fd375e267..f248244695 100644
--- a/go.sum
+++ b/go.sum
@@ -539,6 +539,7 @@ github.com/spf13/viper v1.4.0/go.mod h1:PTJ7Z/lr49W6bUbkmS1V3by4uWynFiR9p7+dSq/y
 github.com/spf13/viper v1.7.0/go.mod h1:8WkrPz2fc9jxqZNCJI/76HCieCp4Q8HaLFoCha5qpdg=
 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
 github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/objx v0.2.0 h1:Hbg2NidpLE8veEBkEZTL3CvlkUIVzuU9jDplZO54c48=
 github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
 github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
 github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=