Skip to content

Commit

Permalink
feat: support invocation of http hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
J0 committed Mar 17, 2024
1 parent 948e3b4 commit a156a5b
Show file tree
Hide file tree
Showing 19 changed files with 435 additions and 29 deletions.
5 changes: 4 additions & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ require (
github.com/fatih/structs v1.1.0
github.com/gobuffalo/pop/v6 v6.1.1
github.com/jackc/pgx/v4 v4.18.2
github.com/standard-webhooks/standard-webhooks/libraries v0.0.0-20240303152453-e0e82adf1721
github.com/supabase/hibp v0.0.0-20231124125943-d225752ae869
github.com/supabase/mailme v0.0.0-20230628061017-01f68480c747
github.com/xeipuuv/gojsonschema v1.2.0
Expand Down Expand Up @@ -146,4 +147,6 @@ require (
gopkg.in/yaml.v3 v3.0.1 // indirect
)

go 1.21
go 1.21.0

toolchain go1.21.6
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,8 @@ github.com/spf13/cobra v1.6.1/go.mod h1:IOw/AERYS7UzyrGinqmz6HLUo219MORXGxhbaJUq
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0=
github.com/standard-webhooks/standard-webhooks/libraries v0.0.0-20240303152453-e0e82adf1721 h1:HTsFo0buahHfjuVUTPDdJRBkfjExkRM1LUBy6crQ7lc=
github.com/standard-webhooks/standard-webhooks/libraries v0.0.0-20240303152453-e0e82adf1721/go.mod h1:L1MQhA6x4dn9r007T033lsaZMv9EmBAdXyU/+EF40fo=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
Expand Down
1 change: 1 addition & 0 deletions internal/api/errorcodes.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,5 @@ const (
ErrorCodeOverSMSSendRateLimit ErrorCode = "over_sms_send_rate_limit"
ErrorBadCodeVerifier ErrorCode = "bad_code_verifier"
ErrorCodeAnonymousProviderDisabled ErrorCode = "anonymous_provider_disabled"
ErrorHookTimeout ErrorCode = "hook_timeout"
)
4 changes: 4 additions & 0 deletions internal/api/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ func conflictError(fmtString string, args ...interface{}) *HTTPError {
return httpError(http.StatusConflict, ErrorCodeConflict, fmtString, args...)
}

func gatewayTimeoutError(errorCode ErrorCode, fmtString string, args ...interface{}) *HTTPError {
return httpError(http.StatusGatewayTimeout, errorCode, fmtString, args...)
}

// HTTPError is an error with a message and an HTTP status code.
type HTTPError struct {
HTTPStatus int `json:"code"` // do not rename the JSON tags!
Expand Down
191 changes: 183 additions & 8 deletions internal/api/hooks.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,36 @@
package api

import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"net/http/httptrace"
"strings"
"time"

"github.com/gofrs/uuid"
"github.com/supabase/auth/internal/observability"

"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/crypto"

"github.com/sirupsen/logrus"
"github.com/supabase/auth/internal/hooks"

"github.com/supabase/auth/internal/storage"
)

func (a *API) runHook(ctx context.Context, tx *storage.Connection, name string, input, output any) ([]byte, error) {
const (
DefaultHTTPHookTimeout = 5 * time.Second
DefaultHTTPHookRetries = 3
HTTPHookBackoffDuration = 2 * time.Second
)

func (a *API) runPostgresHook(ctx context.Context, tx *storage.Connection, name string, input, output any) ([]byte, error) {
db := a.db.WithContext(ctx)

request, err := json.Marshal(input)
Expand Down Expand Up @@ -55,20 +74,176 @@ func (a *API) runHook(ctx context.Context, tx *storage.Connection, name string,
return response, nil
}

// invokeHook invokes the hook code. tx can be nil, in which case a new
func readBodyWithLimit(rsp *http.Response) ([]byte, error) {
defer rsp.Body.Close()

const limit = 20 * 1024 // 20KB
limitedReader := io.LimitedReader{R: rsp.Body, N: limit}

body, err := io.ReadAll(&limitedReader)
if err != nil {
return nil, err
}

if limitedReader.N <= 0 {
// Attempt to read one more byte to check if we're exactly at the limit or over
_, err := rsp.Body.Read(make([]byte, 1))
if err == nil {
// If we could read more, then the payload was too large
return nil, fmt.Errorf("payload too large")
}
}

return body, nil
}

func (a *API) runHTTPHook(r *http.Request, hookConfig conf.ExtensibilityPointConfiguration, input, output any) ([]byte, error) {
client := http.Client{
Timeout: DefaultHTTPHookTimeout,
}
log := observability.GetLogEntry(r)
requestURL := hookConfig.URI
hookLog := log.WithFields(logrus.Fields{
"component": "auth_hook",
"url": requestURL,
})

inputPayload, err := json.Marshal(input)
if err != nil {
return nil, err
}
start := time.Now()
for i := 0; i < DefaultHTTPHookRetries; i++ {
hookLog.Infof("invocation attempt: %d", i)
if time.Since(start) > time.Duration(i+1)*DefaultHTTPHookTimeout {
return []byte{}, gatewayTimeoutError(ErrorHookTimeout, "failed to reach hook within timeout")
}
msgID := uuid.Must(uuid.NewV4())
currentTime := time.Now()
signatureList, err := crypto.GenerateSignatures(hookConfig.HTTPHookSecrets, msgID, currentTime, inputPayload)
if err != nil {
return nil, err
}

req, err := http.NewRequest(http.MethodPost, requestURL, bytes.NewBuffer(inputPayload))
if err != nil {
return nil, internalServerError("Failed to make request object").WithInternalError(err)
}

req.Header.Set("Content-Type", "application/json")
req.Header.Set("webhook-id", msgID.String())
req.Header.Set("webhook-timestamp", fmt.Sprintf("%d", currentTime.Unix()))
req.Header.Set("webhook-signature", strings.Join(signatureList, ", "))

watcher, req := watchForConnection(req)
rsp, err := client.Do(req)

if err != nil {
if terr, ok := err.(net.Error); ok && terr.Timeout() {
hookLog.Errorf("Request timed out for attempt %d with err %s", i, err)
time.Sleep(HTTPHookBackoffDuration)
continue
} else if !watcher.gotConn && i < DefaultHTTPHookRetries-1 {
hookLog.Errorf("Failed to establish a connection on attempt %d with err %s", i, err)
time.Sleep(HTTPHookBackoffDuration)
continue
} else if i == DefaultHTTPHookRetries-1 {
return nil, gatewayTimeoutError(ErrorHookTimeout, "Failed to reach hook within allotted interval")

} else {
return nil, internalServerError("Failed to trigger auth hook, error making HTTP request").WithInternalError(err)
}
}

switch rsp.StatusCode {
case http.StatusOK, http.StatusNoContent, http.StatusAccepted:
if rsp.Body == nil {
return nil, nil
}
body, err := readBodyWithLimit(rsp)
if err != nil {
return nil, err
}
return body, nil
case http.StatusTooManyRequests, http.StatusServiceUnavailable:
retryAfterHeader := rsp.Header.Get("retry-after")
// Check for truthy values to allow for flexibility to swtich to time duration
if retryAfterHeader != "" {
continue
}
return []byte{}, internalServerError("Service currently unavailable")
case http.StatusBadRequest:
return nil, badRequestError(ErrorCodeValidationFailed, "Invalid payload sent to hook")
case http.StatusUnauthorized:
return []byte{}, httpError(http.StatusUnauthorized, ErrorCodeNoAuthorization, "Hook requires authorizaition token")
default:
return []byte{}, internalServerError("Error executing Hook")
}
}
return nil, internalServerError("error executing hook")
}

func watchForConnection(req *http.Request) (*connectionWatcher, *http.Request) {
w := new(connectionWatcher)
t := &httptrace.ClientTrace{
GotConn: w.GotConn,
}

req = req.WithContext(httptrace.WithClientTrace(req.Context(), t))
return w, req
}

type connectionWatcher struct {
gotConn bool
}

func (c *connectionWatcher) GotConn(_ httptrace.GotConnInfo) {
c.gotConn = true
}

func (a *API) invokeHTTPHook(r *http.Request, input, output any, hookURI string) error {
switch input.(type) {
case *hooks.CustomSMSProviderInput:
hookOutput, ok := output.(*hooks.CustomSMSProviderOutput)
if !ok {
panic("output should be *hooks.CustomSMSProviderOutput")
}
var response []byte
var err error

if response, err = a.runHTTPHook(r, a.config.Hook.CustomSMSProvider, input, output); err != nil {
return internalServerError("Error invoking custom SMS provider hook.").WithInternalError(err)
}
if err != nil {
return err
}

if err := json.Unmarshal(response, hookOutput); err != nil {
return internalServerError("Error unmarshaling custom SMS provider hook output.").WithInternalError(err)
}
fmt.Printf("%v", hookOutput)

default:
panic("unknown HTTP hook type")
}
return nil
}

// invokePostgresHook invokes the hook code. tx can be nil, in which case a new
// transaction is opened. If calling invokeHook within a transaction, always
// pass the current transaciton, as pool-exhaustion deadlocks are very easy to
// pass the current transaction, as pool-exhaustion deadlocks are very easy to
// trigger.
func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, output any) error {
func (a *API) invokePostgresHook(ctx context.Context, conn *storage.Connection, input, output any, hookURI string) error {
config := a.config
// Switch based on hook type
switch input.(type) {
case *hooks.MFAVerificationAttemptInput:
hookOutput, ok := output.(*hooks.MFAVerificationAttemptOutput)
if !ok {
panic("output should be *hooks.MFAVerificationAttemptOutput")
}

if _, err := a.runHook(ctx, tx, config.Hook.MFAVerificationAttempt.HookName, input, output); err != nil {
if _, err := a.runPostgresHook(ctx, conn, config.Hook.MFAVerificationAttempt.HookName, input, output); err != nil {
return internalServerError("Error invoking MFA verification hook.").WithInternalError(err)
}

Expand All @@ -94,7 +269,7 @@ func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, out
panic("output should be *hooks.PasswordVerificationAttemptOutput")
}

if _, err := a.runHook(ctx, tx, config.Hook.PasswordVerificationAttempt.HookName, input, output); err != nil {
if _, err := a.runPostgresHook(ctx, conn, config.Hook.PasswordVerificationAttempt.HookName, input, output); err != nil {
return internalServerError("Error invoking password verification hook.").WithInternalError(err)
}

Expand All @@ -120,7 +295,7 @@ func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, out
panic("output should be *hooks.CustomAccessTokenOutput")
}

if _, err := a.runHook(ctx, tx, config.Hook.CustomAccessToken.HookName, input, output); err != nil {
if _, err := a.runPostgresHook(ctx, conn, config.Hook.CustomAccessToken.HookName, input, output); err != nil {
return internalServerError("Error invoking access token hook.").WithInternalError(err)
}

Expand Down Expand Up @@ -155,6 +330,6 @@ func (a *API) invokeHook(ctx context.Context, tx *storage.Connection, input, out
return nil

default:
panic("unknown hook input type")
panic("unknown Postgres hook input type")
}
}
Loading

0 comments on commit a156a5b

Please sign in to comment.