Skip to content

Commit

Permalink
fix: allow enabling sms hook without setting up sms provider (#1704)
Browse files Browse the repository at this point in the history
## What kind of change does this PR introduce?
* Resolves issue where the custom SMS hook cannot be used unless a SMS
provider is configured by moving `GetSmsProvider` into
`sendPhoneConfirmation` and only calls it if the hook is not enabled.
* Allows `channel` to be set for MFA (phone) if hook is enabled 

## What is the current behavior?
* It's not possible to set up a hook without adding config for a SMS
provider

## TODO
- [x] Fix broken tests
  • Loading branch information
kangmingtay authored Aug 4, 2024
1 parent 701a779 commit 575e88a
Show file tree
Hide file tree
Showing 11 changed files with 90 additions and 124 deletions.
2 changes: 1 addition & 1 deletion internal/api/hooks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ import (
"net/http"
"testing"

"errors"
"net/http/httptest"

"github.com/pkg/errors"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
Expand Down
25 changes: 10 additions & 15 deletions internal/api/mfa.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,15 +282,10 @@ func (a *API) challengePhoneFactor(w http.ResponseWriter, r *http.Request) error
return err
}
channel := params.Channel

if channel == "" {
channel = sms_provider.SMSProvider
}
smsProvider, err := sms_provider.GetSmsProvider(*config)
if err != nil {
return internalServerError("Failed to get SMS provider").WithInternalError(err)
}
if !sms_provider.IsValidMessageChannel(channel, config.Sms.Provider) {
if !sms_provider.IsValidMessageChannel(channel, config) {
return badRequestError(ErrorCodeValidationFailed, InvalidChannelError)
}
latestValidChallenge, err := factor.FindLatestUnexpiredChallenge(a.db, config.MFA.ChallengeExpiryDuration)
Expand All @@ -301,20 +296,18 @@ func (a *API) challengePhoneFactor(w http.ResponseWriter, r *http.Request) error
} else if latestValidChallenge != nil && !latestValidChallenge.SentAt.Add(config.MFA.Phone.MaxFrequency).Before(time.Now()) {
return tooManyRequestsError(ErrorCodeOverSMSSendRateLimit, generateFrequencyLimitErrorMessage(latestValidChallenge.SentAt, config.MFA.Phone.MaxFrequency))
}

otp, err := crypto.GenerateOtp(config.MFA.Phone.OtpLength)
if err != nil {
panic(err)
}
challenge, err := factor.CreatePhoneChallenge(ipAddress, otp, config.Security.DBEncryption.Encrypt, config.Security.DBEncryption.EncryptionKeyID, config.Security.DBEncryption.EncryptionKey)
if err != nil {
return internalServerError("error creating SMS Challenge")
}

message, err := generateSMSFromTemplate(config.MFA.Phone.SMSTemplate, otp)
if err != nil {
return internalServerError("error generating sms template").WithInternalError(err)
}
challenge, err := factor.CreatePhoneChallenge(ipAddress, otp, config.Security.DBEncryption.Encrypt, config.Security.DBEncryption.EncryptionKeyID, config.Security.DBEncryption.EncryptionKey)
if err != nil {
return internalServerError("error creating SMS Challenge")
}
if config.Hook.SendSMS.Enabled {
input := hooks.SendSMSInput{
User: user,
Expand All @@ -329,10 +322,12 @@ func (a *API) challengePhoneFactor(w http.ResponseWriter, r *http.Request) error
return internalServerError("error invoking hook")
}
} else {

// We omit messageID for now, can consider reinstating if there are requests.
_, err := smsProvider.SendMessage(string(factor.Phone), message, channel, otp)
smsProvider, err := sms_provider.GetSmsProvider(*config)
if err != nil {
return internalServerError("Failed to get SMS provider").WithInternalError(err)
}
// We omit messageID for now, can consider reinstating if there are requests.
if _, err = smsProvider.SendMessage(string(factor.Phone), message, channel, otp); err != nil {
return internalServerError("error sending message").WithInternalError(err)
}
}
Expand Down
9 changes: 0 additions & 9 deletions internal/api/mfa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,14 +265,6 @@ func (ts *MFATestSuite) TestChallengeSMSFactor() {
begin
return input;
end; $$ language plpgsql;`).Exec())
// We still need a mock provider for hooks to work right now for backward compatibility
// The WhatsApp channel is only valid when twilio or twilio verify is set.
ts.Config.Sms.Provider = "twilio"
ts.Config.Sms.Twilio = conf.TwilioProviderConfiguration{
AccountSid: "test_account_sid",
AuthToken: "test_auth_token",
MessageServiceSid: "test_message_service_id",
}

phone := "+1234567"
friendlyName := "testchallengesmsfactor"
Expand Down Expand Up @@ -491,7 +483,6 @@ func (ts *MFATestSuite) TestUnenrollVerifiedFactor() {
func (ts *MFATestSuite) TestUnenrollUnverifiedFactor() {
var buffer bytes.Buffer
f := ts.TestUser.Factors[0]
f.Secret = ts.TestOTPKey.Secret()

token := ts.generateAAL1Token(ts.TestUser, &ts.TestSession.ID)
require.NoError(ts.T(), json.NewEncoder(&buffer).Encode(map[string]interface{}{
Expand Down
21 changes: 8 additions & 13 deletions internal/api/otp.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/sethvargo/go-password/password"
"github.com/supabase/auth/internal/api/sms_provider"
"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/storage"
)
Expand Down Expand Up @@ -45,17 +46,15 @@ func (p *OtpParams) Validate() error {
return nil
}

func (p *SmsParams) Validate(smsProvider string) error {
if p.Phone != "" && !sms_provider.IsValidMessageChannel(p.Channel, smsProvider) {
return badRequestError(ErrorCodeValidationFailed, InvalidChannelError)
}

func (p *SmsParams) Validate(config *conf.GlobalConfiguration) error {
var err error
p.Phone, err = validatePhone(p.Phone)
if err != nil {
return err
}

if !sms_provider.IsValidMessageChannel(p.Channel, config) {
return badRequestError(ErrorCodeValidationFailed, InvalidChannelError)
}
return nil
}

Expand Down Expand Up @@ -119,7 +118,7 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error {
params.Channel = sms_provider.SMSProvider
}

if err := params.Validate(config.Sms.Provider); err != nil {
if err := params.Validate(config); err != nil {
return err
}

Expand Down Expand Up @@ -191,13 +190,9 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error {
}); err != nil {
return err
}
smsProvider, terr := sms_provider.GetSmsProvider(*config)
if terr != nil {
return internalServerError("Unable to get SMS provider").WithInternalError(err)
}
mID, serr := a.sendPhoneConfirmation(r, tx, user, params.Phone, phoneConfirmationOtp, smsProvider, params.Channel)
mID, serr := a.sendPhoneConfirmation(r, tx, user, params.Phone, phoneConfirmationOtp, params.Channel)
if serr != nil {
return badRequestError(ErrorCodeSMSSendFailed, "Error sending sms OTP: %v", serr).WithInternalError(serr)
return serr
}
messageID = mID
return nil
Expand Down
36 changes: 20 additions & 16 deletions internal/api/phone.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func formatPhoneNumber(phone string) string {
}

// sendPhoneConfirmation sends an otp to the user's phone number
func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, user *models.User, phone, otpType string, smsProvider sms_provider.SmsProvider, channel string) (string, error) {
func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, user *models.User, phone, otpType string, channel string) (string, error) {
config := a.config

var token *string
Expand Down Expand Up @@ -71,7 +71,7 @@ func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, use
// intentionally keeping this before the test OTP, so that the behavior
// of regular and test OTPs is similar
if sentAt != nil && !sentAt.Add(config.Sms.MaxFrequency).Before(time.Now()) {
return "", MaxFrequencyLimitError
return "", tooManyRequestsError(ErrorCodeOverSMSSendRateLimit, generateFrequencyLimitErrorMessage(sentAt, config.Sms.MaxFrequency))
}

now := time.Now()
Expand All @@ -89,14 +89,7 @@ func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, use
if err != nil {
return "", internalServerError("error generating otp").WithInternalError(err)
}

message, err := generateSMSFromTemplate(config.Sms.SMSTemplate, otp)
if err != nil {
return "", err
}

// Hook should only be called if SMS autoconfirm is disabled
if !config.Sms.Autoconfirm && config.Hook.SendSMS.Enabled {
if config.Hook.SendSMS.Enabled {
input := hooks.SendSMSInput{
User: user,
SMS: hooks.SMS{
Expand All @@ -109,9 +102,17 @@ func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, use
return "", err
}
} else {
messageID, err = smsProvider.SendMessage(phone, message, channel, otp)
smsProvider, err := sms_provider.GetSmsProvider(*config)
if err != nil {
return "", internalServerError("Unable to get SMS provider").WithInternalError(err)
}
message, err := generateSMSFromTemplate(config.Sms.SMSTemplate, otp)
if err != nil {
return "", internalServerError("error generating sms template").WithInternalError(err)
}
messageID, err := smsProvider.SendMessage(phone, message, channel, otp)
if err != nil {
return messageID, err
return messageID, unprocessableEntityError(ErrorCodeSMSSendFailed, "Error sending %s OTP to provider: %v", otpType, err)
}
}
}
Expand All @@ -131,21 +132,24 @@ func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, use
return messageID, errors.Wrap(err, "Database error updating user for phone")
}

var ottErr error
switch otpType {
case phoneConfirmationOtp:
if err := models.CreateOneTimeToken(tx, user.ID, user.GetPhone(), user.ConfirmationToken, models.ConfirmationToken); err != nil {
return messageID, errors.Wrap(err, "Database error creating confirmation token for phone")
ottErr = errors.Wrap(err, "Database error creating confirmation token for phone")
}
case phoneChangeVerification:
if err := models.CreateOneTimeToken(tx, user.ID, user.PhoneChange, user.PhoneChangeToken, models.PhoneChangeToken); err != nil {
return messageID, errors.Wrap(err, "Database error creating phone change token")
ottErr = errors.Wrap(err, "Database error creating phone change token")
}
case phoneReauthenticationOtp:
if err := models.CreateOneTimeToken(tx, user.ID, user.GetPhone(), user.ReauthenticationToken, models.ReauthenticationToken); err != nil {
return messageID, errors.Wrap(err, "Database error creating reauthentication token for phone")
ottErr = errors.Wrap(err, "Database error creating reauthentication token for phone")
}
}

if ottErr != nil {
return messageID, internalServerError("error creating one time token").WithInternalError(ottErr)
}
return messageID, nil
}

Expand Down
64 changes: 29 additions & 35 deletions internal/api/phone_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,9 @@ func doTestSendPhoneConfirmation(ts *PhoneTestSuite, useTestOTP bool) {
for _, c := range cases {
ts.Run(c.desc, func() {
provider := &TestSmsProvider{}
sms_provider.MockProvider = provider

_, err = ts.API.sendPhoneConfirmation(req, ts.API.db, u, "123456789", c.otpType, provider, sms_provider.SMSProvider)
_, err = ts.API.sendPhoneConfirmation(req, ts.API.db, u, "123456789", c.otpType, sms_provider.SMSProvider)
require.Equal(ts.T(), c.expected, err)
u, err = models.FindUserByPhoneAndAudience(ts.API.db, "123456789", ts.Config.JWT.Aud)
require.NoError(ts.T(), err)
Expand Down Expand Up @@ -306,13 +307,13 @@ func (ts *PhoneTestSuite) TestSendSMSHook() {
method: http.MethodPost,
uri: "pg-functions://postgres/auth/send_sms_signup",
hookFunctionSQL: `
create or replace function send_sms_signup(input jsonb)
returns json as $$
begin
insert into job_queue(job_type, payload)
values ('sms_signup', input);
return input;
end; $$ language plpgsql;`,
create or replace function send_sms_signup(input jsonb)
returns json as $$
begin
insert into job_queue(job_type, payload)
values ('sms_signup', input);
return input;
end; $$ language plpgsql;`,
header: "",
body: map[string]string{
"phone": "1234567890",
Expand All @@ -327,13 +328,13 @@ func (ts *PhoneTestSuite) TestSendSMSHook() {
method: http.MethodPost,
uri: "pg-functions://postgres/auth/send_sms_otp",
hookFunctionSQL: `
create or replace function send_sms_otp(input jsonb)
returns json as $$
begin
insert into job_queue(job_type, payload)
values ('sms_signup', input);
return input;
end; $$ language plpgsql;`,
create or replace function send_sms_otp(input jsonb)
returns json as $$
begin
insert into job_queue(job_type, payload)
values ('sms_signup', input);
return input;
end; $$ language plpgsql;`,
header: "",
body: map[string]string{
"phone": "123456789",
Expand All @@ -348,13 +349,13 @@ func (ts *PhoneTestSuite) TestSendSMSHook() {
method: http.MethodPut,
uri: "pg-functions://postgres/auth/send_sms_phone_change",
hookFunctionSQL: `
create or replace function send_sms_phone_change(input jsonb)
returns json as $$
begin
insert into job_queue(job_type, payload)
values ('phone_change', input);
return input;
end; $$ language plpgsql;`,
create or replace function send_sms_phone_change(input jsonb)
returns json as $$
begin
insert into job_queue(job_type, payload)
values ('phone_change', input);
return input;
end; $$ language plpgsql;`,
header: token,
body: map[string]string{
"phone": "111111111",
Expand All @@ -369,11 +370,11 @@ func (ts *PhoneTestSuite) TestSendSMSHook() {
method: http.MethodGet,
uri: "pg-functions://postgres/auth/reauthenticate",
hookFunctionSQL: `
create or replace function reauthenticate(input jsonb)
returns json as $$
begin
return input;
end; $$ language plpgsql;`,
create or replace function reauthenticate(input jsonb)
returns json as $$
begin
return input;
end; $$ language plpgsql;`,
header: "",
body: nil,
expectToken: true,
Expand All @@ -396,7 +397,7 @@ func (ts *PhoneTestSuite) TestSendSMSHook() {
"phone": "123456789",
},
expectToken: false,
expectedCode: http.StatusBadRequest,
expectedCode: http.StatusInternalServerError,
hookFunctionIdentifier: "send_sms_otp_failure(input jsonb)",
},
}
Expand All @@ -409,13 +410,6 @@ func (ts *PhoneTestSuite) TestSendSMSHook() {
ts.Config.Hook.SendSMS.URI = c.uri
// Disable FrequencyLimit to allow back to back sending
ts.Config.Sms.MaxFrequency = 0 * time.Second
// We still need a mock provider for hooks to work right now for backward compatibility
ts.Config.Sms.Provider = "twilio"
ts.Config.Sms.Twilio = conf.TwilioProviderConfiguration{
AccountSid: "test_account_sid",
AuthToken: "test_auth_token",
MessageServiceSid: "test_message_service_id",
}
require.NoError(ts.T(), ts.Config.Hook.SendSMS.PopulateExtensibilityPoint())

require.NoError(t, ts.API.db.RawQuery(c.hookFunctionSQL).Exec())
Expand Down
7 changes: 1 addition & 6 deletions internal/api/reauthenticate.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ const InvalidNonceMessage = "Nonce has expired or is invalid"
func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error {
ctx := r.Context()
db := a.db.WithContext(ctx)
config := a.config

user := getUser(ctx)
email, phone := user.GetEmail(), user.GetPhone()
Expand All @@ -44,11 +43,7 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error {
if email != "" {
return a.sendReauthenticationOtp(r, tx, user)
} else if phone != "" {
smsProvider, terr := sms_provider.GetSmsProvider(*config)
if terr != nil {
return internalServerError("Failed to get SMS provider").WithInternalError(terr)
}
mID, err := a.sendPhoneConfirmation(r, tx, user, phone, phoneReauthenticationOtp, smsProvider, sms_provider.SMSProvider)
mID, err := a.sendPhoneConfirmation(r, tx, user, phone, phoneReauthenticationOtp, sms_provider.SMSProvider)
if err != nil {
return err
}
Expand Down
12 changes: 2 additions & 10 deletions internal/api/resend.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,23 +127,15 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error {
if terr := models.NewAuditLogEntry(r, tx, user, models.UserRecoveryRequestedAction, "", nil); terr != nil {
return terr
}
smsProvider, terr := sms_provider.GetSmsProvider(*config)
if terr != nil {
return terr
}
mID, terr := a.sendPhoneConfirmation(r, tx, user, params.Phone, phoneConfirmationOtp, smsProvider, sms_provider.SMSProvider)
mID, terr := a.sendPhoneConfirmation(r, tx, user, params.Phone, phoneConfirmationOtp, sms_provider.SMSProvider)
if terr != nil {
return terr
}
messageID = mID
case mail.EmailChangeVerification:
return a.sendEmailChange(r, tx, user, user.EmailChange, models.ImplicitFlow)
case phoneChangeVerification:
smsProvider, terr := sms_provider.GetSmsProvider(*config)
if terr != nil {
return terr
}
mID, terr := a.sendPhoneConfirmation(r, tx, user, user.PhoneChange, phoneChangeVerification, smsProvider, sms_provider.SMSProvider)
mID, terr := a.sendPhoneConfirmation(r, tx, user, user.PhoneChange, phoneChangeVerification, sms_provider.SMSProvider)
if terr != nil {
return terr
}
Expand Down
Loading

0 comments on commit 575e88a

Please sign in to comment.