Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow for postgres and http functions on each extensibility point #1528

Merged
merged 3 commits into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 89 additions & 46 deletions internal/api/hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ import (
"errors"
"fmt"
"io"
"mime"
"net"
"net/http"
"net/url"
"strings"
"time"

Expand All @@ -31,7 +33,7 @@ const (
PayloadLimit = 200 * 1024 // 200KB
)

func (a *API) runPostgresHook(ctx context.Context, tx *storage.Connection, name string, input, output any) ([]byte, error) {
func (a *API) runPostgresHook(ctx context.Context, tx *storage.Connection, hookConfig conf.ExtensibilityPointConfiguration, input, output any) ([]byte, error) {
db := a.db.WithContext(ctx)

request, err := json.Marshal(input)
Expand All @@ -46,7 +48,7 @@ func (a *API) runPostgresHook(ctx context.Context, tx *storage.Connection, name
return terr
}

if terr := tx.RawQuery(fmt.Sprintf("select %s(?);", name), request).First(&response); terr != nil {
if terr := tx.RawQuery(fmt.Sprintf("select %s(?);", hookConfig.HookName), request).First(&response); terr != nil {
return terr
}

Expand Down Expand Up @@ -75,7 +77,8 @@ func (a *API) runPostgresHook(ctx context.Context, tx *storage.Connection, name
return response, nil
}

func (a *API) runHTTPHook(ctx context.Context, r *http.Request, hookConfig conf.ExtensibilityPointConfiguration, input, output any) ([]byte, error) {
func (a *API) runHTTPHook(r *http.Request, hookConfig conf.ExtensibilityPointConfiguration, input, output any) ([]byte, error) {
ctx := r.Context()
client := http.Client{
Timeout: DefaultHTTPHookTimeout,
}
Expand Down Expand Up @@ -135,6 +138,15 @@ func (a *API) runHTTPHook(ctx context.Context, r *http.Request, hookConfig conf.
}

defer rsp.Body.Close()
// Header.Get is case insensitive
contentType := rsp.Header.Get("Content-Type")
mediaType, _, err := mime.ParseMediaType(contentType)
if err != nil {
return nil, internalServerError("Invalid Content-Type header")
}
if mediaType != "application/json" {
return nil, internalServerError("Invalid JSON response. Received content-type: " + contentType)
}

switch rsp.StatusCode {
case http.StatusOK, http.StatusNoContent, http.StatusAccepted:
Expand Down Expand Up @@ -172,67 +184,80 @@ func (a *API) runHTTPHook(ctx context.Context, r *http.Request, hookConfig conf.
return nil, nil
}

func (a *API) invokeHTTPHook(ctx context.Context, r *http.Request, input, output any) error {
// invokePostgresHook invokes the hook code. conn can be nil, in which case a new
// transaction is opened. If calling invokeHook within a transaction, always
// pass the current transaction, as pool-exhaustion deadlocks are very easy to
// trigger.
func (a *API) invokeHook(conn *storage.Connection, r *http.Request, input, output any, uri string) error {
var err error
var response []byte
u, err := url.Parse(uri)
if err != nil {
return err
}

switch input.(type) {
case *hooks.SendSMSInput:
hookOutput, ok := output.(*hooks.SendSMSOutput)
if !ok {
panic("output should be *hooks.SendSMSOutput")
}
var response []byte
var err error

if response, err = a.runHTTPHook(ctx, r, a.config.Hook.SendSMS, input, output); err != nil {
return internalServerError("Error invoking Send SMS hook.").WithInternalError(err)
if response, err = a.runHook(r, conn, a.config.Hook.SendSMS, input, output, u.Scheme); err != nil {
return err
}

if err := json.Unmarshal(response, hookOutput); err != nil {
return internalServerError("Error unmarshaling Send SMS output.").WithInternalError(err)
}
if hookOutput.IsError() {
J0 marked this conversation as resolved.
Show resolved Hide resolved
httpCode := hookOutput.HookError.HTTPCode

if httpCode == 0 {
httpCode = http.StatusInternalServerError
}
httpError := &HTTPError{
HTTPStatus: httpCode,
Message: hookOutput.HookError.Message,
}
return httpError.WithInternalError(&hookOutput.HookError)
}
return nil
case *hooks.SendEmailInput:
hookOutput, ok := output.(*hooks.SendEmailOutput)
if !ok {
panic("output should be *hooks.SendEmailOutput")
}

var response []byte
var err error

if response, err = a.runHTTPHook(ctx, r, a.config.Hook.SendEmail, input, output); err != nil {
return internalServerError("Error invoking Send Email hook.").WithInternalError(err)
}
if err != nil {
if response, err = a.runHook(r, conn, a.config.Hook.SendEmail, input, output, u.Scheme); err != nil {
return err
}

if err := json.Unmarshal(response, hookOutput); err != nil {
return internalServerError("Error unmarshaling Send Email hook output.").WithInternalError(err)
return internalServerError("Error unmarshaling Send Email output.").WithInternalError(err)
}
if hookOutput.IsError() {
httpCode := hookOutput.HookError.HTTPCode

default:
panic("unknown HTTP hook type")
}
return nil
}
if httpCode == 0 {
httpCode = http.StatusInternalServerError
}

// invokePostgresHook invokes the hook code. tx can be nil, in which case a new
// transaction is opened. If calling invokeHook within a transaction, always
// pass the current transaction, as pool-exhaustion deadlocks are very easy to
// trigger.
func (a *API) invokePostgresHook(ctx context.Context, conn *storage.Connection, input, output any) error {
config := a.config
// Switch based on hook type
switch input.(type) {
httpError := &HTTPError{
HTTPStatus: httpCode,
Message: hookOutput.HookError.Message,
}

return httpError.WithInternalError(&hookOutput.HookError)
}
return nil
case *hooks.MFAVerificationAttemptInput:
hookOutput, ok := output.(*hooks.MFAVerificationAttemptOutput)
if !ok {
panic("output should be *hooks.MFAVerificationAttemptOutput")
}

if _, err := a.runPostgresHook(ctx, conn, config.Hook.MFAVerificationAttempt.HookName, input, output); err != nil {
return internalServerError("Error invoking MFA verification hook.").WithInternalError(err)
if response, err = a.runHook(r, conn, a.config.Hook.MFAVerificationAttempt, input, output, u.Scheme); err != nil {
return err
}
if err := json.Unmarshal(response, hookOutput); err != nil {
return internalServerError("Error unmarshaling MFA Verification Attempt output.").WithInternalError(err)
}

if hookOutput.IsError() {
httpCode := hookOutput.HookError.HTTPCode

Expand All @@ -247,18 +272,19 @@ func (a *API) invokePostgresHook(ctx context.Context, conn *storage.Connection,

return httpError.WithInternalError(&hookOutput.HookError)
}

return nil
case *hooks.PasswordVerificationAttemptInput:
hookOutput, ok := output.(*hooks.PasswordVerificationAttemptOutput)
if !ok {
panic("output should be *hooks.PasswordVerificationAttemptOutput")
}

if _, err := a.runPostgresHook(ctx, conn, config.Hook.PasswordVerificationAttempt.HookName, input, output); err != nil {
return internalServerError("Error invoking password verification hook.").WithInternalError(err)
if response, err = a.runHook(r, conn, a.config.Hook.PasswordVerificationAttempt, input, output, u.Scheme); err != nil {
return err
}
if err := json.Unmarshal(response, hookOutput); err != nil {
return internalServerError("Error unmarshaling Password Verification Attempt output.").WithInternalError(err)
}

if hookOutput.IsError() {
httpCode := hookOutput.HookError.HTTPCode

Expand All @@ -280,9 +306,11 @@ func (a *API) invokePostgresHook(ctx context.Context, conn *storage.Connection,
if !ok {
panic("output should be *hooks.CustomAccessTokenOutput")
}

if _, err := a.runPostgresHook(ctx, conn, config.Hook.CustomAccessToken.HookName, input, output); err != nil {
return internalServerError("Error invoking access token hook.").WithInternalError(err)
if response, err = a.runHook(r, conn, a.config.Hook.CustomAccessToken, input, output, u.Scheme); err != nil {
return err
}
if err := json.Unmarshal(response, hookOutput); err != nil {
return internalServerError("Error unmarshaling Custom Access Token output.").WithInternalError(err)
}

if hookOutput.IsError() {
Expand All @@ -305,7 +333,6 @@ func (a *API) invokePostgresHook(ctx context.Context, conn *storage.Connection,
if httpCode == 0 {
httpCode = http.StatusInternalServerError
}

httpError := &HTTPError{
HTTPStatus: httpCode,
Message: err.Error(),
Expand All @@ -314,8 +341,24 @@ func (a *API) invokePostgresHook(ctx context.Context, conn *storage.Connection,
return httpError
}
return nil
}
return nil
}

func (a *API) runHook(r *http.Request, conn *storage.Connection, hookConfig conf.ExtensibilityPointConfiguration, input, output any, scheme string) ([]byte, error) {
ctx := r.Context()
var response []byte
var err error
switch strings.ToLower(scheme) {
case "http", "https":
response, err = a.runHTTPHook(r, hookConfig, input, output)
case "pg-functions":
response, err = a.runPostgresHook(ctx, conn, hookConfig, input, output)
default:
panic("unknown Postgres hook input type")
return nil, fmt.Errorf("unsupported protocol: %v only postgres hooks and HTTPS functions are supported at the moment", scheme)
}
if err != nil {
return nil, internalServerError("Error running hook URI: %v", hookConfig.URI).WithInternalError(err)
}
return response, nil
}
Loading
Loading