Skip to content

Commit d4f5556

Browse files
cstocktonChris Stockton
authored andcommitted
feat: hooks round 2 - remove indirection and simplify error handling (#2025)
## Hooks Round 2 This PR will contain a series of commits preparing for the implementation of before & after user created hooks. It takes the feedback from #2012 into consideration. ### Summary Remove indirection and simplify error handling: * update pkg `internal/api` to: * uses `internal/hooks/v0hooks.Manager` instead of `internal/hooks/hooks.Manager` [aec5995](aec5995) * remove pkg `internal/hooks/hooks.Manager` [062da5d](062da5d) * add pkg `internal/hooks/hookserrors` [7e80afc](7e80afc) * use pkg `internal/hooks/hookserrors` in `internal/hooks/v0hooks` [57744e8](57744e8) * update pkg `internal/hooks/v0hooks` with an `Enabled` method [16cc4c9](16cc4c9) ### Depends on [feat: hooks round 1](#2023) - prepare package structure * Renamed pkg `internal/hooks/v0hooks/v0http` -> `internal/hooks/hookshttp` [8a398ab](8a398ab) * Renamed pkg `internal/hooks/v0hooks/v0pgfunc` -> `internal/hooks/hookspgfunc` [8a398ab](8a398ab) * Use pkg `internal/e2e` for test setup in: * pkg `internal/hooks/hookspgfunc` [4d60288](4d60288) * pkg `internal/hooks/v0hooks` [4a7432b](4a7432b) --------- Co-authored-by: Chris Stockton <chris.stockton@supabase.io>
1 parent 940522d commit d4f5556

File tree

22 files changed

+1487
-625
lines changed

22 files changed

+1487
-625
lines changed

internal/api/api.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@ import (
1010
"github.com/sirupsen/logrus"
1111
"github.com/supabase/auth/internal/api/apierrors"
1212
"github.com/supabase/auth/internal/conf"
13-
"github.com/supabase/auth/internal/hooks"
13+
"github.com/supabase/auth/internal/hooks/hookshttp"
14+
"github.com/supabase/auth/internal/hooks/hookspgfunc"
15+
"github.com/supabase/auth/internal/hooks/v0hooks"
1416
"github.com/supabase/auth/internal/mailer"
1517
"github.com/supabase/auth/internal/models"
1618
"github.com/supabase/auth/internal/observability"
@@ -33,7 +35,7 @@ type API struct {
3335
config *conf.GlobalConfiguration
3436
version string
3537

36-
hooksMgr *hooks.Manager
38+
hooksMgr *v0hooks.Manager
3739
hibpClient *hibp.PwnedClient
3840

3941
// overrideTime can be used to override the clock used by handlers. Should only be used in tests!
@@ -87,7 +89,9 @@ func NewAPIWithVersion(globalConfig *conf.GlobalConfiguration, db *storage.Conne
8789
api.limiterOpts = NewLimiterOptions(globalConfig)
8890
}
8991
if api.hooksMgr == nil {
90-
api.hooksMgr = hooks.NewManager(db, globalConfig)
92+
httpDr := hookshttp.New()
93+
pgfuncDr := hookspgfunc.New(db)
94+
api.hooksMgr = v0hooks.NewManager(globalConfig, httpDr, pgfuncDr)
9195
}
9296
if api.config.Password.HIBP.Enabled {
9397
httpClient := &http.Client{

internal/api/hooks_test.go

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package api
22

33
import (
4-
"encoding/json"
54
"net/http"
65
"testing"
76

@@ -12,6 +11,7 @@ import (
1211
"github.com/stretchr/testify/require"
1312
"github.com/stretchr/testify/suite"
1413
"github.com/supabase/auth/internal/conf"
14+
"github.com/supabase/auth/internal/hooks/hookserrors"
1515
"github.com/supabase/auth/internal/hooks/v0hooks"
1616
"github.com/supabase/auth/internal/models"
1717
"github.com/supabase/auth/internal/storage"
@@ -70,20 +70,20 @@ func (ts *HooksTestSuite) TestRunHTTPHook() {
7070
testURL := "http://localhost:54321/functions/v1/custom-sms-sender"
7171
ts.Config.Hook.SendSMS.URI = testURL
7272

73-
unsuccessfulResponse := v0hooks.AuthHookError{
73+
unsuccessfulResponse := hookserrors.Error{
7474
HTTPCode: http.StatusUnprocessableEntity,
7575
Message: "test error",
7676
}
7777

7878
testCases := []struct {
7979
description string
8080
expectError bool
81-
mockResponse v0hooks.AuthHookError
81+
mockResponse hookserrors.Error
8282
}{
8383
{
8484
description: "Hook returns success",
8585
expectError: false,
86-
mockResponse: v0hooks.AuthHookError{},
86+
mockResponse: hookserrors.Error{},
8787
},
8888
{
8989
description: "Hook returns error",
@@ -102,23 +102,22 @@ func (ts *HooksTestSuite) TestRunHTTPHook() {
102102
Post("/").
103103
MatchType("json").
104104
Reply(http.StatusUnprocessableEntity).
105-
JSON(v0hooks.SendSMSOutput{HookError: unsuccessfulResponse})
105+
JSON(struct {
106+
Error *hookserrors.Error `json:"error,omitempty"`
107+
}{Error: &unsuccessfulResponse})
106108

107109
for _, tc := range testCases {
108110
ts.Run(tc.description, func() {
109111
req, _ := http.NewRequest("POST", ts.Config.Hook.SendSMS.URI, nil)
110-
body, err := ts.API.hooksMgr.RunHTTPHook(req, ts.Config.Hook.SendSMS, &input)
112+
113+
var output v0hooks.SendSMSOutput
114+
err := ts.API.hooksMgr.InvokeHook(ts.API.db, req, &input, &output)
111115

112116
if !tc.expectError {
113117
require.NoError(ts.T(), err)
114118
} else {
115119
require.Error(ts.T(), err)
116-
if body != nil {
117-
var output v0hooks.SendSMSOutput
118-
require.NoError(ts.T(), json.Unmarshal(body, &output))
119-
require.Equal(ts.T(), unsuccessfulResponse.HTTPCode, output.HookError.HTTPCode)
120-
require.Equal(ts.T(), unsuccessfulResponse.Message, output.HookError.Message)
121-
}
120+
require.Equal(ts.T(), output, v0hooks.SendSMSOutput{})
122121
}
123122
})
124123
}
@@ -154,12 +153,9 @@ func (ts *HooksTestSuite) TestShouldRetryWithRetryAfterHeader() {
154153
req, err := http.NewRequest("POST", "http://localhost:9998/otp", nil)
155154
require.NoError(ts.T(), err)
156155

157-
body, err := ts.API.hooksMgr.RunHTTPHook(req, ts.Config.Hook.SendSMS, &input)
158-
require.NoError(ts.T(), err)
159-
160156
var output v0hooks.SendSMSOutput
161-
err = json.Unmarshal(body, &output)
162-
require.NoError(ts.T(), err, "Unmarshal should not fail")
157+
err = ts.API.hooksMgr.InvokeHook(ts.API.db, req, &input, &output)
158+
require.NoError(ts.T(), err)
163159

164160
// Ensure that all expected HTTP interactions (mocks) have been called
165161
require.True(ts.T(), gock.IsDone(), "Expected all mocks to have been called including retry")
@@ -186,10 +182,10 @@ func (ts *HooksTestSuite) TestShouldReturnErrorForNonJSONContentType() {
186182
req, err := http.NewRequest("POST", "http://localhost:9999/otp", nil)
187183
require.NoError(ts.T(), err)
188184

189-
_, err = ts.API.hooksMgr.RunHTTPHook(req, ts.Config.Hook.SendSMS, &input)
185+
var output v0hooks.SendSMSOutput
186+
err = ts.API.hooksMgr.InvokeHook(ts.API.db, req, &input, &output)
190187
require.Error(ts.T(), err, "Expected an error due to wrong content type")
191188
require.Contains(ts.T(), err.Error(), "Invalid JSON response.")
192-
193189
require.True(ts.T(), gock.IsDone(), "Expected all mocks to have been called")
194190
}
195191

internal/api/token.go

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@ package api
22

33
import (
44
"context"
5+
"fmt"
56
"net/http"
67
"net/url"
78
"strconv"
89
"time"
910

1011
"github.com/gofrs/uuid"
1112
"github.com/golang-jwt/jwt/v5"
13+
"github.com/xeipuuv/gojsonschema"
1214

1315
"github.com/supabase/auth/internal/api/apierrors"
1416
"github.com/supabase/auth/internal/hooks/v0hooks"
@@ -369,14 +371,16 @@ func (a *API) generateAccessToken(r *http.Request, tx *storage.Connection, user
369371
if err != nil {
370372
return "", 0, err
371373
}
374+
if err := validateTokenClaims(output.Claims); err != nil {
375+
return "", 0, err
376+
}
372377
gotrueClaims = jwt.MapClaims(output.Claims)
373378
}
374379

375380
signed, err := signJwt(&config.JWT, gotrueClaims)
376381
if err != nil {
377382
return "", 0, err
378383
}
379-
380384
return signed, expiresAt.Unix(), nil
381385
}
382386

@@ -491,3 +495,86 @@ func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection,
491495
User: user,
492496
}, nil
493497
}
498+
499+
var schemaLoader = gojsonschema.NewStringLoader(MinimumViableTokenSchema)
500+
501+
func validateTokenClaims(outputClaims map[string]interface{}) error {
502+
documentLoader := gojsonschema.NewGoLoader(outputClaims)
503+
result, err := gojsonschema.Validate(schemaLoader, documentLoader)
504+
if err != nil {
505+
return err
506+
}
507+
508+
if !result.Valid() {
509+
var errorMessages string
510+
511+
for _, desc := range result.Errors() {
512+
errorMessages += fmt.Sprintf("- %s\n", desc)
513+
fmt.Printf("- %s\n", desc)
514+
}
515+
return fmt.Errorf(
516+
"output claims do not conform to the expected schema: \n%s", errorMessages)
517+
518+
}
519+
520+
return nil
521+
}
522+
523+
// #nosec
524+
const MinimumViableTokenSchema = `{
525+
"$schema": "http://json-schema.org/draft-07/schema#",
526+
"type": "object",
527+
"properties": {
528+
"aud": {
529+
"type": ["string", "array"]
530+
},
531+
"exp": {
532+
"type": "integer"
533+
},
534+
"jti": {
535+
"type": "string"
536+
},
537+
"iat": {
538+
"type": "integer"
539+
},
540+
"iss": {
541+
"type": "string"
542+
},
543+
"nbf": {
544+
"type": "integer"
545+
},
546+
"sub": {
547+
"type": "string"
548+
},
549+
"email": {
550+
"type": "string"
551+
},
552+
"phone": {
553+
"type": "string"
554+
},
555+
"app_metadata": {
556+
"type": "object",
557+
"additionalProperties": true
558+
},
559+
"user_metadata": {
560+
"type": "object",
561+
"additionalProperties": true
562+
},
563+
"role": {
564+
"type": "string"
565+
},
566+
"aal": {
567+
"type": "string"
568+
},
569+
"amr": {
570+
"type": "array",
571+
"items": {
572+
"type": "object"
573+
}
574+
},
575+
"session_id": {
576+
"type": "string"
577+
}
578+
},
579+
"required": ["aud", "exp", "iat", "sub", "email", "phone", "role", "aal", "session_id", "is_anonymous"]
580+
}`

internal/conf/configuration.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,9 @@ type HookConfiguration struct {
640640
CustomAccessToken ExtensibilityPointConfiguration `json:"custom_access_token" split_words:"true"`
641641
SendEmail ExtensibilityPointConfiguration `json:"send_email" split_words:"true"`
642642
SendSMS ExtensibilityPointConfiguration `json:"send_sms" split_words:"true"`
643+
644+
BeforeUserCreated ExtensibilityPointConfiguration `json:"before_user_created" split_words:"true"`
645+
AfterUserCreated ExtensibilityPointConfiguration `json:"after_user_created" split_words:"true"`
643646
}
644647

645648
type HTTPHookSecrets []string
@@ -671,6 +674,8 @@ func (h *HookConfiguration) Validate() error {
671674
h.CustomAccessToken,
672675
h.SendSMS,
673676
h.SendEmail,
677+
h.BeforeUserCreated,
678+
h.AfterUserCreated,
674679
}
675680
for _, point := range points {
676681
if err := point.ValidateExtensibilityPoint(); err != nil {
@@ -888,6 +893,18 @@ func populateGlobal(config *GlobalConfiguration) error {
888893
}
889894
}
890895

896+
if config.Hook.BeforeUserCreated.Enabled {
897+
if err := config.Hook.BeforeUserCreated.PopulateExtensibilityPoint(); err != nil {
898+
return err
899+
}
900+
}
901+
902+
if config.Hook.AfterUserCreated.Enabled {
903+
if err := config.Hook.AfterUserCreated.PopulateExtensibilityPoint(); err != nil {
904+
return err
905+
}
906+
}
907+
891908
if config.SAML.Enabled {
892909
if err := config.SAML.PopulateFields(config.API.ExternalURL); err != nil {
893910
return err

internal/conf/configuration_test.go

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,36 @@ func TestGlobal(t *testing.T) {
176176
os.Setenv("API_EXTERNAL_URL", "http://localhost:9999")
177177
}
178178

179+
{
180+
os.Setenv("API_EXTERNAL_URL", "")
181+
cfg := new(GlobalConfiguration)
182+
cfg.Hook = HookConfiguration{
183+
BeforeUserCreated: ExtensibilityPointConfiguration{
184+
Enabled: true,
185+
URI: "\n",
186+
},
187+
}
188+
189+
err := populateGlobal(cfg)
190+
require.Error(t, err)
191+
os.Setenv("API_EXTERNAL_URL", "http://localhost:9999")
192+
}
193+
194+
{
195+
os.Setenv("API_EXTERNAL_URL", "")
196+
cfg := new(GlobalConfiguration)
197+
cfg.Hook = HookConfiguration{
198+
AfterUserCreated: ExtensibilityPointConfiguration{
199+
Enabled: true,
200+
URI: "\n",
201+
},
202+
}
203+
204+
err := populateGlobal(cfg)
205+
require.Error(t, err)
206+
os.Setenv("API_EXTERNAL_URL", "http://localhost:9999")
207+
}
208+
179209
{
180210
os.Setenv("API_EXTERNAL_URL", "")
181211
cfg := new(GlobalConfiguration)
@@ -490,6 +520,11 @@ func TestValidate(t *testing.T) {
490520
err: `conf: session timebox duration must` +
491521
` be positive when set, was -1`,
492522
},
523+
{
524+
val: &SessionsConfiguration{InactivityTimeout: toPtr(time.Duration(-1))},
525+
err: `conf: session inactivity timeout duration must` +
526+
` be positive when set, was -1ns`,
527+
},
493528
{
494529
val: &SessionsConfiguration{AllowLowAAL: nil},
495530
},
@@ -532,6 +567,17 @@ func TestValidate(t *testing.T) {
532567
err: `conf: mailer validation headers not a map[string][]string format:` +
533568
` invalid character 'i' looking for beginning of value`,
534569
},
570+
{
571+
val: &MailerConfiguration{EmailValidationBlockedMX: "invalid"},
572+
err: `conf: email_validation_blocked_mx`,
573+
},
574+
{
575+
val: &MailerConfiguration{EmailValidationBlockedMX: `["foo.com"]`},
576+
check: func(t *testing.T, v any) {
577+
got := (v.(*MailerConfiguration)).GetEmailValidationBlockedMXRecords()
578+
require.True(t, got["foo.com"])
579+
},
580+
},
535581

536582
{
537583
val: &CaptchaConfiguration{Enabled: false},

internal/e2e/e2e.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,14 @@ var (
1616
configPath string
1717
)
1818

19+
var isTesting func() bool = testing.Testing
20+
1921
func init() {
20-
if testing.Testing() {
22+
initPackage()
23+
}
24+
25+
func initPackage() {
26+
if isTesting() {
2127
_, thisFile, _, _ := runtime.Caller(0)
2228
projectRoot = filepath.Join(filepath.Dir(thisFile), "../..")
2329
configPath = filepath.Join(GetProjectRoot(), "hack", "test.env")

internal/e2e/e2e_test.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,4 +94,27 @@ func TestUtils(t *testing.T) {
9494
t.Fatal("exp non-nil err")
9595
}
9696
}()
97+
98+
// block init from main()
99+
func() {
100+
restore := isTesting
101+
defer func() {
102+
isTesting = restore
103+
}()
104+
isTesting = func() bool { return false }
105+
106+
var errStr string
107+
func() {
108+
defer func() {
109+
errStr = recover().(string)
110+
}()
111+
112+
initPackage()
113+
}()
114+
115+
exp := "package e2e may not be used in a main package"
116+
if errStr != exp {
117+
t.Fatalf("exp %v; got %v", exp, errStr)
118+
}
119+
}()
97120
}

0 commit comments

Comments
 (0)