From a156a5b6d4848cd45cbd442a737cda19363e42ca Mon Sep 17 00:00:00 2001 From: Joel Lee Date: Sun, 17 Mar 2024 11:53:44 +0800 Subject: [PATCH] feat: support invocation of http hooks --- go.mod | 5 +- go.sum | 2 + internal/api/errorcodes.go | 1 + internal/api/errors.go | 4 + internal/api/hooks.go | 191 ++++++++++++++++++++++++++-- internal/api/hooks_test.go | 144 +++++++++++++++++++++ internal/api/mfa.go | 2 +- internal/api/otp.go | 2 +- internal/api/phone.go | 25 +++- internal/api/phone_test.go | 4 +- internal/api/reauthenticate.go | 2 +- internal/api/resend.go | 4 +- internal/api/signup.go | 2 +- internal/api/token.go | 4 +- internal/api/user.go | 2 +- internal/conf/configuration.go | 13 +- internal/conf/configuration_test.go | 2 + internal/crypto/crypto.go | 28 ++++ internal/hooks/auth_hooks.go | 27 ++++ 19 files changed, 435 insertions(+), 29 deletions(-) create mode 100644 internal/api/hooks_test.go diff --git a/go.mod b/go.mod index 0baa8ffc4..9a7ef3f74 100644 --- a/go.mod +++ b/go.mod @@ -70,6 +70,7 @@ require ( github.com/fatih/structs v1.1.0 github.com/gobuffalo/pop/v6 v6.1.1 github.com/jackc/pgx/v4 v4.18.2 + github.com/standard-webhooks/standard-webhooks/libraries v0.0.0-20240303152453-e0e82adf1721 github.com/supabase/hibp v0.0.0-20231124125943-d225752ae869 github.com/supabase/mailme v0.0.0-20230628061017-01f68480c747 github.com/xeipuuv/gojsonschema v1.2.0 @@ -146,4 +147,6 @@ require ( gopkg.in/yaml.v3 v3.0.1 // indirect ) -go 1.21 +go 1.21.0 + +toolchain go1.21.6 diff --git a/go.sum b/go.sum index 8783b5991..6b5c469ea 100644 --- a/go.sum +++ b/go.sum @@ -469,6 +469,8 @@ github.com/spf13/cobra v1.6.1/go.mod h1:IOw/AERYS7UzyrGinqmz6HLUo219MORXGxhbaJUq github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0= +github.com/standard-webhooks/standard-webhooks/libraries v0.0.0-20240303152453-e0e82adf1721 h1:HTsFo0buahHfjuVUTPDdJRBkfjExkRM1LUBy6crQ7lc= +github.com/standard-webhooks/standard-webhooks/libraries v0.0.0-20240303152453-e0e82adf1721/go.mod h1:L1MQhA6x4dn9r007T033lsaZMv9EmBAdXyU/+EF40fo= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= diff --git a/internal/api/errorcodes.go b/internal/api/errorcodes.go index 45dec0dd7..8af111bde 100644 --- a/internal/api/errorcodes.go +++ b/internal/api/errorcodes.go @@ -74,4 +74,5 @@ const ( ErrorCodeOverSMSSendRateLimit ErrorCode = "over_sms_send_rate_limit" ErrorBadCodeVerifier ErrorCode = "bad_code_verifier" ErrorCodeAnonymousProviderDisabled ErrorCode = "anonymous_provider_disabled" + ErrorHookTimeout ErrorCode = "hook_timeout" ) diff --git a/internal/api/errors.go b/internal/api/errors.go index cc6ba877b..216c9944e 100644 --- a/internal/api/errors.go +++ b/internal/api/errors.go @@ -96,6 +96,10 @@ func conflictError(fmtString string, args ...interface{}) *HTTPError { return httpError(http.StatusConflict, ErrorCodeConflict, fmtString, args...) } +func gatewayTimeoutError(errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError { + return httpError(http.StatusGatewayTimeout, errorCode, fmtString, args...) +} + // HTTPError is an error with a message and an HTTP status code. type HTTPError struct { HTTPStatus int `json:"code"` // do not rename the JSON tags! diff --git a/internal/api/hooks.go b/internal/api/hooks.go index 5368339d8..256667db0 100644 --- a/internal/api/hooks.go +++ b/internal/api/hooks.go @@ -1,17 +1,36 @@ package api import ( + "bytes" "context" "encoding/json" "fmt" + "io" + "net" "net/http" + "net/http/httptrace" + "strings" + "time" + "github.com/gofrs/uuid" + "github.com/supabase/auth/internal/observability" + + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/crypto" + + "github.com/sirupsen/logrus" "github.com/supabase/auth/internal/hooks" "github.com/supabase/auth/internal/storage" ) -func (a *API) runHook(ctx context.Context, tx *storage.Connection, name string, input, output any) ([]byte, error) { +const ( + DefaultHTTPHookTimeout = 5 * time.Second + DefaultHTTPHookRetries = 3 + HTTPHookBackoffDuration = 2 * time.Second +) + +func (a *API) runPostgresHook(ctx context.Context, tx *storage.Connection, name string, input, output any) ([]byte, error) { db := a.db.WithContext(ctx) request, err := json.Marshal(input) @@ -55,12 +74,168 @@ func (a *API) runHook(ctx context.Context, tx *storage.Connection, name string, return response, nil } -// invokeHook invokes the hook code. tx can be nil, in which case a new +func readBodyWithLimit(rsp *http.Response) ([]byte, error) { + defer rsp.Body.Close() + + const limit = 20 * 1024 // 20KB + limitedReader := io.LimitedReader{R: rsp.Body, N: limit} + + body, err := io.ReadAll(&limitedReader) + if err != nil { + return nil, err + } + + if limitedReader.N <= 0 { + // Attempt to read one more byte to check if we're exactly at the limit or over + _, err := rsp.Body.Read(make([]byte, 1)) + if err == nil { + // If we could read more, then the payload was too large + return nil, fmt.Errorf("payload too large") + } + } + + return body, nil +} + +func (a *API) runHTTPHook(r *http.Request, hookConfig conf.ExtensibilityPointConfiguration, input, output any) ([]byte, error) { + client := http.Client{ + Timeout: DefaultHTTPHookTimeout, + } + log := observability.GetLogEntry(r) + requestURL := hookConfig.URI + hookLog := log.WithFields(logrus.Fields{ + "component": "auth_hook", + "url": requestURL, + }) + + inputPayload, err := json.Marshal(input) + if err != nil { + return nil, err + } + start := time.Now() + for i := 0; i < DefaultHTTPHookRetries; i++ { + hookLog.Infof("invocation attempt: %d", i) + if time.Since(start) > time.Duration(i+1)*DefaultHTTPHookTimeout { + return []byte{}, gatewayTimeoutError(ErrorHookTimeout, "failed to reach hook within timeout") + } + msgID := uuid.Must(uuid.NewV4()) + currentTime := time.Now() + signatureList, err := crypto.GenerateSignatures(hookConfig.HTTPHookSecrets, msgID, currentTime, inputPayload) + if err != nil { + return nil, err + } + + req, err := http.NewRequest(http.MethodPost, requestURL, bytes.NewBuffer(inputPayload)) + if err != nil { + return nil, internalServerError("Failed to make request object").WithInternalError(err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("webhook-id", msgID.String()) + req.Header.Set("webhook-timestamp", fmt.Sprintf("%d", currentTime.Unix())) + req.Header.Set("webhook-signature", strings.Join(signatureList, ", ")) + + watcher, req := watchForConnection(req) + rsp, err := client.Do(req) + + if err != nil { + if terr, ok := err.(net.Error); ok && terr.Timeout() { + hookLog.Errorf("Request timed out for attempt %d with err %s", i, err) + time.Sleep(HTTPHookBackoffDuration) + continue + } else if !watcher.gotConn && i < DefaultHTTPHookRetries-1 { + hookLog.Errorf("Failed to establish a connection on attempt %d with err %s", i, err) + time.Sleep(HTTPHookBackoffDuration) + continue + } else if i == DefaultHTTPHookRetries-1 { + return nil, gatewayTimeoutError(ErrorHookTimeout, "Failed to reach hook within allotted interval") + + } else { + return nil, internalServerError("Failed to trigger auth hook, error making HTTP request").WithInternalError(err) + } + } + + switch rsp.StatusCode { + case http.StatusOK, http.StatusNoContent, http.StatusAccepted: + if rsp.Body == nil { + return nil, nil + } + body, err := readBodyWithLimit(rsp) + if err != nil { + return nil, err + } + return body, nil + case http.StatusTooManyRequests, http.StatusServiceUnavailable: + retryAfterHeader := rsp.Header.Get("retry-after") + // Check for truthy values to allow for flexibility to swtich to time duration + if retryAfterHeader != "" { + continue + } + return []byte{}, internalServerError("Service currently unavailable") + case http.StatusBadRequest: + return nil, badRequestError(ErrorCodeValidationFailed, "Invalid payload sent to hook") + case http.StatusUnauthorized: + return []byte{}, httpError(http.StatusUnauthorized, ErrorCodeNoAuthorization, "Hook requires authorizaition token") + default: + return []byte{}, internalServerError("Error executing Hook") + } + } + return nil, internalServerError("error executing hook") +} + +func watchForConnection(req *http.Request) (*connectionWatcher, *http.Request) { + w := new(connectionWatcher) + t := &httptrace.ClientTrace{ + GotConn: w.GotConn, + } + + req = req.WithContext(httptrace.WithClientTrace(req.Context(), t)) + return w, req +} + +type connectionWatcher struct { + gotConn bool +} + +func (c *connectionWatcher) GotConn(_ httptrace.GotConnInfo) { + c.gotConn = true +} + +func (a *API) invokeHTTPHook(r *http.Request, input, output any, hookURI string) error { + switch input.(type) { + case *hooks.CustomSMSProviderInput: + hookOutput, ok := output.(*hooks.CustomSMSProviderOutput) + if !ok { + panic("output should be *hooks.CustomSMSProviderOutput") + } + var response []byte + var err error + + if response, err = a.runHTTPHook(r, a.config.Hook.CustomSMSProvider, input, output); err != nil { + return internalServerError("Error invoking custom SMS provider hook.").WithInternalError(err) + } + if err != nil { + return err + } + + if err := json.Unmarshal(response, hookOutput); err != nil { + return internalServerError("Error unmarshaling custom SMS provider hook output.").WithInternalError(err) + } + fmt.Printf("%v", hookOutput) + + default: + panic("unknown HTTP hook type") + } + return nil +} + +// invokePostgresHook invokes the hook code. tx can be nil, in which case a new // transaction is opened. If calling invokeHook within a transaction, always -// pass the current transaciton, as pool-exhaustion deadlocks are very easy to +// pass the current transaction, as pool-exhaustion deadlocks are very easy to // trigger. -func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, output any) error { +func (a *API) invokePostgresHook(ctx context.Context, conn *storage.Connection, input, output any, hookURI string) error { config := a.config + // Switch based on hook type switch input.(type) { case *hooks.MFAVerificationAttemptInput: hookOutput, ok := output.(*hooks.MFAVerificationAttemptOutput) @@ -68,7 +243,7 @@ func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, out panic("output should be *hooks.MFAVerificationAttemptOutput") } - if _, err := a.runHook(ctx, tx, config.Hook.MFAVerificationAttempt.HookName, input, output); err != nil { + if _, err := a.runPostgresHook(ctx, conn, config.Hook.MFAVerificationAttempt.HookName, input, output); err != nil { return internalServerError("Error invoking MFA verification hook.").WithInternalError(err) } @@ -94,7 +269,7 @@ func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, out panic("output should be *hooks.PasswordVerificationAttemptOutput") } - if _, err := a.runHook(ctx, tx, config.Hook.PasswordVerificationAttempt.HookName, input, output); err != nil { + if _, err := a.runPostgresHook(ctx, conn, config.Hook.PasswordVerificationAttempt.HookName, input, output); err != nil { return internalServerError("Error invoking password verification hook.").WithInternalError(err) } @@ -120,7 +295,7 @@ func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, out panic("output should be *hooks.CustomAccessTokenOutput") } - if _, err := a.runHook(ctx, tx, config.Hook.CustomAccessToken.HookName, input, output); err != nil { + if _, err := a.runPostgresHook(ctx, conn, config.Hook.CustomAccessToken.HookName, input, output); err != nil { return internalServerError("Error invoking access token hook.").WithInternalError(err) } @@ -155,6 +330,6 @@ func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, out return nil default: - panic("unknown hook input type") + panic("unknown Postgres hook input type") } } diff --git a/internal/api/hooks_test.go b/internal/api/hooks_test.go new file mode 100644 index 000000000..38ed6d49d --- /dev/null +++ b/internal/api/hooks_test.go @@ -0,0 +1,144 @@ +package api + +import ( + "encoding/json" + "net/http" + "testing" + + "github.com/gofrs/uuid" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/hooks" + + "gopkg.in/h2non/gock.v1" +) + +var handleApiRequest func(*http.Request) (*http.Response, error) + +type HooksTestSuite struct { + suite.Suite + API *API + Config *conf.GlobalConfiguration +} + +type MockHttpClient struct { + mock.Mock +} + +func (m *MockHttpClient) Do(req *http.Request) (*http.Response, error) { + return handleApiRequest(req) +} + +func TestHooks(t *testing.T) { + api, config, err := setupAPIForTest() + require.NoError(t, err) + + ts := &HooksTestSuite{ + API: api, + Config: config, + } + defer api.db.Close() + + suite.Run(t, ts) +} + +func (ts *HooksTestSuite) TestRunHTTPHook() { + defer gock.OffAll() + + input := hooks.CustomSMSProviderInput{ + UserID: uuid.Must(uuid.NewV4()), + Phone: "1234567890", + OTP: "123456", + } + successOutput := hooks.CustomSMSProviderOutput{Success: true} + testURL := "http://localhost:54321/functions/v1/custom-sms-sender" + ts.Config.Hook.CustomSMSProvider.URI = testURL + + testCases := []struct { + description string + mockResponse interface{} + status int + expectError bool + }{ + { + description: "Successful Post request with delay", + mockResponse: successOutput, + status: http.StatusOK, + expectError: false, + }, + { + description: "Too many requests without retry header should not retry", + status: http.StatusGatewayTimeout, + expectError: true, + }, + } + + for _, tc := range testCases { + ts.Run(tc.description, func() { + gock.New(ts.Config.Hook.CustomSMSProvider.URI). + Post("/"). + MatchType("json"). + Reply(tc.status). + JSON(tc.mockResponse) + + var output hooks.CustomSMSProviderOutput + req, _ := http.NewRequest("POST", ts.Config.Hook.CustomSMSProvider.URI, nil) + body, err := ts.API.runHTTPHook(req, ts.Config.Hook.CustomSMSProvider, &input, &output) + + if !tc.expectError { + require.NoError(ts.T(), err) + if body != nil { + require.NoError(ts.T(), json.Unmarshal(body, &output)) + require.True(ts.T(), output.Success) + } + } else { + require.Error(ts.T(), err) + } + require.True(ts.T(), gock.IsDone()) + }) + } +} + +func (ts *HooksTestSuite) TestShouldRetryWithRetryAfterHeader() { + defer gock.OffAll() + + input := hooks.CustomSMSProviderInput{ + UserID: uuid.Must(uuid.NewV4()), + Phone: "1234567890", + OTP: "123456", + } + successOutput := hooks.CustomSMSProviderOutput{Success: true} + testURL := "http://localhost:54321/functions/v1/custom-sms-sender" + ts.Config.Hook.CustomSMSProvider.URI = testURL + + gock.New(testURL). + Post("/"). + MatchType("json"). + Reply(http.StatusTooManyRequests). + SetHeader("retry-after", "true") + + // Simulate an additional response for the retry attempt + gock.New(testURL). + Post("/"). + MatchType("json"). + Reply(http.StatusOK). + JSON(successOutput) + + var output hooks.CustomSMSProviderOutput + + // Simulate the original HTTP request which triggered the hook + req, err := http.NewRequest("POST", "http://localhost:9998/otp", nil) + require.NoError(ts.T(), err) + + body, err := ts.API.runHTTPHook(req, ts.Config.Hook.CustomSMSProvider, &input, &output) + require.NoError(ts.T(), err) + + err = json.Unmarshal(body, &output) + require.NoError(ts.T(), err, "Unmarshal should not fail") + require.True(ts.T(), output.Success, "Expected success on retry") + + // Ensure that all expected HTTP interactions (mocks) have been called + require.True(ts.T(), gock.IsDone(), "Expected all mocks to have been called including retry") +} diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 3919cb781..b2d429dce 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -243,7 +243,7 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { output := hooks.MFAVerificationAttemptOutput{} - err := a.invokeHook(ctx, nil, &input, &output) + err := a.invokePostgresHook(ctx, nil, &input, &output, config.Hook.MFAVerificationAttempt.URI) if err != nil { return err } diff --git a/internal/api/otp.go b/internal/api/otp.go index 99b7bae32..d0e3d6f18 100644 --- a/internal/api/otp.go +++ b/internal/api/otp.go @@ -195,7 +195,7 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { if terr != nil { return internalServerError("Unable to get SMS provider").WithInternalError(err) } - mID, serr := a.sendPhoneConfirmation(tx, user, params.Phone, phoneConfirmationOtp, smsProvider, params.Channel) + mID, serr := a.sendPhoneConfirmation(r, tx, user, params.Phone, phoneConfirmationOtp, smsProvider, params.Channel) if serr != nil { return badRequestError(ErrorCodeSMSSendFailed, "Error sending sms OTP: %v", serr).WithInternalError(serr) } diff --git a/internal/api/phone.go b/internal/api/phone.go index f85caa6fd..4690ed5a3 100644 --- a/internal/api/phone.go +++ b/internal/api/phone.go @@ -2,6 +2,8 @@ package api import ( "bytes" + "github.com/supabase/auth/internal/hooks" + "net/http" "regexp" "strings" "text/template" @@ -40,7 +42,7 @@ func formatPhoneNumber(phone string) string { } // sendPhoneConfirmation sends an otp to the user's phone number -func (a *API) sendPhoneConfirmation(tx *storage.Connection, user *models.User, phone, otpType string, smsProvider sms_provider.SmsProvider, channel string) (string, error) { +func (a *API) sendPhoneConfirmation(r *http.Request, tx *storage.Connection, user *models.User, phone, otpType string, smsProvider sms_provider.SmsProvider, channel string) (string, error) { config := a.config var token *string @@ -91,10 +93,23 @@ func (a *API) sendPhoneConfirmation(tx *storage.Connection, user *models.User, p if err != nil { return "", err } - - messageID, err = smsProvider.SendMessage(phone, message, channel, otp) - if err != nil { - return messageID, err + if config.Hook.CustomSMSProvider.Enabled { + input := hooks.CustomSMSProviderInput{ + UserID: user.ID, + Phone: user.Phone.String(), + OTP: otp, + } + output := hooks.CustomSMSProviderOutput{} + err := a.invokeHTTPHook(r, &input, &output, config.Hook.CustomSMSProvider.URI) + if err != nil { + return "", err + } + } else { + + messageID, err = smsProvider.SendMessage(phone, message, channel, otp) + if err != nil { + return messageID, err + } } } diff --git a/internal/api/phone_test.go b/internal/api/phone_test.go index 09810e288..d2b9bb9e6 100644 --- a/internal/api/phone_test.go +++ b/internal/api/phone_test.go @@ -72,6 +72,8 @@ func (ts *PhoneTestSuite) TestFormatPhoneNumber() { func doTestSendPhoneConfirmation(ts *PhoneTestSuite, useTestOTP bool) { u, err := models.FindUserByPhoneAndAudience(ts.API.db, "123456789", ts.Config.JWT.Aud) require.NoError(ts.T(), err) + req, err := http.NewRequest("POST", "http://localhost:9998/otp", nil) + require.NoError(ts.T(), err) cases := []struct { desc string otpType string @@ -111,7 +113,7 @@ func doTestSendPhoneConfirmation(ts *PhoneTestSuite, useTestOTP bool) { ts.Run(c.desc, func() { provider := &TestSmsProvider{} - _, err = ts.API.sendPhoneConfirmation(ts.API.db, u, "123456789", c.otpType, provider, sms_provider.SMSProvider) + _, err = ts.API.sendPhoneConfirmation(req, ts.API.db, u, "123456789", c.otpType, provider, sms_provider.SMSProvider) require.Equal(ts.T(), c.expected, err) u, err = models.FindUserByPhoneAndAudience(ts.API.db, "123456789", ts.Config.JWT.Aud) require.NoError(ts.T(), err) diff --git a/internal/api/reauthenticate.go b/internal/api/reauthenticate.go index 84b080070..8143320aa 100644 --- a/internal/api/reauthenticate.go +++ b/internal/api/reauthenticate.go @@ -49,7 +49,7 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error { if terr != nil { return internalServerError("Failed to get SMS provider").WithInternalError(terr) } - mID, err := a.sendPhoneConfirmation(tx, user, phone, phoneReauthenticationOtp, smsProvider, sms_provider.SMSProvider) + mID, err := a.sendPhoneConfirmation(r, tx, user, phone, phoneReauthenticationOtp, smsProvider, sms_provider.SMSProvider) if err != nil { return err } diff --git a/internal/api/resend.go b/internal/api/resend.go index fdad38c43..0ef80008d 100644 --- a/internal/api/resend.go +++ b/internal/api/resend.go @@ -134,7 +134,7 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error { if terr != nil { return terr } - mID, terr := a.sendPhoneConfirmation(tx, user, params.Phone, phoneConfirmationOtp, smsProvider, sms_provider.SMSProvider) + mID, terr := a.sendPhoneConfirmation(r, tx, user, params.Phone, phoneConfirmationOtp, smsProvider, sms_provider.SMSProvider) if terr != nil { return terr } @@ -146,7 +146,7 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error { if terr != nil { return terr } - mID, terr := a.sendPhoneConfirmation(tx, user, user.PhoneChange, phoneChangeVerification, smsProvider, sms_provider.SMSProvider) + mID, terr := a.sendPhoneConfirmation(r, tx, user, user.PhoneChange, phoneChangeVerification, smsProvider, sms_provider.SMSProvider) if terr != nil { return terr } diff --git a/internal/api/signup.go b/internal/api/signup.go index 5c7e588b8..e0517f6bc 100644 --- a/internal/api/signup.go +++ b/internal/api/signup.go @@ -278,7 +278,7 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { if terr != nil { return internalServerError("Unable to get SMS provider").WithInternalError(terr) } - if _, terr := a.sendPhoneConfirmation(tx, user, params.Phone, phoneConfirmationOtp, smsProvider, params.Channel); terr != nil { + if _, terr := a.sendPhoneConfirmation(r, tx, user, params.Phone, phoneConfirmationOtp, smsProvider, params.Channel); terr != nil { return unprocessableEntityError(ErrorCodeSMSSendFailed, "Error sending confirmation sms: %v", terr).WithInternalError(terr) } } diff --git a/internal/api/token.go b/internal/api/token.go index df0292711..cee79be11 100644 --- a/internal/api/token.go +++ b/internal/api/token.go @@ -164,7 +164,7 @@ func (a *API) ResourceOwnerPasswordGrant(ctx context.Context, w http.ResponseWri Valid: isValidPassword, } output := hooks.PasswordVerificationAttemptOutput{} - err := a.invokeHook(ctx, nil, &input, &output) + err := a.invokePostgresHook(ctx, nil, &input, &output, config.Hook.PasswordVerificationAttempt.URI) if err != nil { return err } @@ -339,7 +339,7 @@ func (a *API) generateAccessToken(ctx context.Context, tx *storage.Connection, u output := hooks.CustomAccessTokenOutput{} - err := a.invokeHook(ctx, tx, &input, &output) + err := a.invokePostgresHook(ctx, tx, &input, &output, config.Hook.CustomAccessToken.URI) if err != nil { return "", 0, err } diff --git a/internal/api/user.go b/internal/api/user.go index 9fe0dcef8..a9b3a2b89 100644 --- a/internal/api/user.go +++ b/internal/api/user.go @@ -227,7 +227,7 @@ func (a *API) UserUpdate(w http.ResponseWriter, r *http.Request) error { if terr != nil { return internalServerError("Error finding SMS provider").WithInternalError(terr) } - if _, terr := a.sendPhoneConfirmation(tx, user, params.Phone, phoneChangeVerification, smsProvider, params.Channel); terr != nil { + if _, terr := a.sendPhoneConfirmation(r, tx, user, params.Phone, phoneChangeVerification, smsProvider, params.Channel); terr != nil { return internalServerError("Error sending phone change otp").WithInternalError(terr) } } diff --git a/internal/conf/configuration.go b/internal/conf/configuration.go index 31fb8b22d..74167dea2 100644 --- a/internal/conf/configuration.go +++ b/internal/conf/configuration.go @@ -501,8 +501,14 @@ func (e *ExtensibilityPointConfiguration) ValidateExtensibilityPoint() error { switch strings.ToLower(u.Scheme) { case "pg-functions": return validatePostgresPath(u) + case "http": + hostname := u.Hostname() + if hostname == "localhost" || hostname == "127.0.0.1" || hostname == "::1" { + return validateHTTPHookSecrets(e.HTTPHookSecrets) + } + return fmt.Errorf("only localhost, 127.0.0.1, and ::1 are supported with http") case "https": - return validateHTTPSHookSecrets(e.HTTPHookSecrets) + return validateHTTPHookSecrets(e.HTTPHookSecrets) default: return fmt.Errorf("only postgres hooks and HTTPS functions are supported at the moment") } @@ -530,7 +536,7 @@ func isValidSecretFormat(secret string) bool { return symmetricSecretFormat.MatchString(secret) || asymmetricSecretFormat.MatchString(secret) } -func validateHTTPSHookSecrets(secrets []string) error { +func validateHTTPHookSecrets(secrets []string) error { for _, secret := range secrets { if !isValidSecretFormat(secret) { return fmt.Errorf("invalid secret format") @@ -540,9 +546,6 @@ func validateHTTPSHookSecrets(secrets []string) error { } func (e *ExtensibilityPointConfiguration) PopulateExtensibilityPoint() error { - if err := e.ValidateExtensibilityPoint(); err != nil { - return err - } u, err := url.Parse(e.URI) if err != nil { return err diff --git a/internal/conf/configuration_test.go b/internal/conf/configuration_test.go index f881857bf..4d6eab003 100644 --- a/internal/conf/configuration_test.go +++ b/internal/conf/configuration_test.go @@ -161,8 +161,10 @@ func TestValidateExtensibilityPointURI(t *testing.T) { {desc: "Valid Postgres URI", uri: "pg-functions://postgres/auth/verification_hook_reject", expectError: false}, {desc: "Another Valid URI", uri: "pg-functions://postgres/user_management/add_user", expectError: false}, {desc: "Another Valid URI", uri: "pg-functions://postgres/MySpeCial/FUNCTION_THAT_YELLS_AT_YOU", expectError: false}, + {desc: "Valid HTTP URI", uri: "http://localhost/functions/v1/custom-sms-sender", expectError: false}, // Negative test cases + {desc: "Invalid HTTP URI", uri: "http://asdfgggg.website.co/functions/v1/custom-sms-sender", expectError: true}, {desc: "Invalid HTTPS URI (HTTP)", uri: "http://asdfgggqqwwerty.supabase.co/functions/v1/custom-sms-sender", expectError: true}, {desc: "Invalid Schema Name", uri: "pg-functions://postgres/123auth/verification_hook_reject", expectError: true}, {desc: "Invalid Function Name", uri: "pg-functions://postgres/auth/123verification_hook_reject", expectError: true}, diff --git a/internal/crypto/crypto.go b/internal/crypto/crypto.go index e8063ab1e..590d1ba4d 100644 --- a/internal/crypto/crypto.go +++ b/internal/crypto/crypto.go @@ -9,6 +9,11 @@ import ( "math" "math/big" "strconv" + "strings" + "time" + + "github.com/gofrs/uuid" + standardwebhooks "github.com/standard-webhooks/standard-webhooks/libraries/go" "github.com/pkg/errors" ) @@ -41,3 +46,26 @@ func GenerateOtp(digits int) (string, error) { func GenerateTokenHash(emailOrPhone, otp string) string { return fmt.Sprintf("%x", sha256.Sum224([]byte(emailOrPhone+otp))) } + +func GenerateSignatures(secrets []string, msgID uuid.UUID, currentTime time.Time, inputPayload []byte) ([]string, error) { + SymmetricSignaturePrefix := "v1," + // TODO(joel): Handle asymmetric case once library has been upgraded + var signatureList []string + for _, secret := range secrets { + if strings.HasPrefix(secret, SymmetricSignaturePrefix) { + trimmedSecret := strings.TrimPrefix(secret, SymmetricSignaturePrefix) + wh, err := standardwebhooks.NewWebhook(trimmedSecret) + if err != nil { + return nil, err + } + signature, err := wh.Sign(msgID.String(), currentTime, inputPayload) + if err != nil { + return nil, err + } + signatureList = append(signatureList, signature) + } else { + return nil, errors.New("invalid signature format") + } + } + return signatureList, nil +} diff --git a/internal/hooks/auth_hooks.go b/internal/hooks/auth_hooks.go index bd3163085..39a226ce7 100644 --- a/internal/hooks/auth_hooks.go +++ b/internal/hooks/auth_hooks.go @@ -22,6 +22,10 @@ const ( HookRejection = "reject" ) +type HTTPHookInput interface { + IsHTTPHook() +} + type HookOutput interface { IsError() bool Error() string @@ -135,6 +139,17 @@ type CustomAccessTokenOutput struct { HookError AuthHookError `json:"error,omitempty"` } +type CustomSMSProviderInput struct { + UserID uuid.UUID `json:"user_id"` + Phone string `json:"phone"` + OTP string `json:"otp"` +} + +type CustomSMSProviderOutput struct { + Success bool `json:"success"` + HookError AuthHookError `json:"error,omitempty"` +} + func (mf *MFAVerificationAttemptOutput) IsError() bool { return mf.HookError.Message != "" } @@ -159,6 +174,18 @@ func (ca *CustomAccessTokenOutput) Error() string { return ca.HookError.Message } +func (cs *CustomSMSProviderOutput) IsError() bool { + return cs.HookError.Message != "" +} + +func (cs *CustomSMSProviderOutput) Error() string { + return cs.HookError.Message +} + +func (cs *CustomSMSProviderOutput) IsHTTPHook() bool { + return true +} + type AuthHookError struct { HTTPCode int `json:"http_code,omitempty"` Message string `json:"message,omitempty"`