Skip to content

Commit 57744e8

Browse files
author
Chris Stockton
committed
feat: use pkg hookserrors in hookshttp & hookspgfunc
This allows consistent error handling across both the hookshttp and hookspgfunc packages without the need to embed AuthHookError in every single output struct. In addition this required moving `validateTokenClaims` to the API package. This removes the need for a single special case in the dispatching of hooks in pkg `internal/hooks/v0hooks`.
1 parent 7e80afc commit 57744e8

File tree

9 files changed

+214
-373
lines changed

9 files changed

+214
-373
lines changed

internal/api/hooks_test.go

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package api
22

33
import (
4-
"encoding/json"
54
"net/http"
65
"testing"
76

@@ -12,6 +11,7 @@ import (
1211
"github.com/stretchr/testify/require"
1312
"github.com/stretchr/testify/suite"
1413
"github.com/supabase/auth/internal/conf"
14+
"github.com/supabase/auth/internal/hooks/hookserrors"
1515
"github.com/supabase/auth/internal/hooks/v0hooks"
1616
"github.com/supabase/auth/internal/models"
1717
"github.com/supabase/auth/internal/storage"
@@ -70,20 +70,20 @@ func (ts *HooksTestSuite) TestRunHTTPHook() {
7070
testURL := "http://localhost:54321/functions/v1/custom-sms-sender"
7171
ts.Config.Hook.SendSMS.URI = testURL
7272

73-
unsuccessfulResponse := v0hooks.AuthHookError{
73+
unsuccessfulResponse := hookserrors.Error{
7474
HTTPCode: http.StatusUnprocessableEntity,
7575
Message: "test error",
7676
}
7777

7878
testCases := []struct {
7979
description string
8080
expectError bool
81-
mockResponse v0hooks.AuthHookError
81+
mockResponse hookserrors.Error
8282
}{
8383
{
8484
description: "Hook returns success",
8585
expectError: false,
86-
mockResponse: v0hooks.AuthHookError{},
86+
mockResponse: hookserrors.Error{},
8787
},
8888
{
8989
description: "Hook returns error",
@@ -102,23 +102,22 @@ func (ts *HooksTestSuite) TestRunHTTPHook() {
102102
Post("/").
103103
MatchType("json").
104104
Reply(http.StatusUnprocessableEntity).
105-
JSON(v0hooks.SendSMSOutput{HookError: unsuccessfulResponse})
105+
JSON(struct {
106+
Error *hookserrors.Error `json:"error,omitempty"`
107+
}{Error: &unsuccessfulResponse})
106108

107109
for _, tc := range testCases {
108110
ts.Run(tc.description, func() {
109111
req, _ := http.NewRequest("POST", ts.Config.Hook.SendSMS.URI, nil)
110-
body, err := ts.API.hooksMgr.RunHTTPHook(req, ts.Config.Hook.SendSMS, &input)
112+
113+
var output v0hooks.SendSMSOutput
114+
err := ts.API.hooksMgr.InvokeHook(ts.API.db, req, &input, &output)
111115

112116
if !tc.expectError {
113117
require.NoError(ts.T(), err)
114118
} else {
115119
require.Error(ts.T(), err)
116-
if body != nil {
117-
var output v0hooks.SendSMSOutput
118-
require.NoError(ts.T(), json.Unmarshal(body, &output))
119-
require.Equal(ts.T(), unsuccessfulResponse.HTTPCode, output.HookError.HTTPCode)
120-
require.Equal(ts.T(), unsuccessfulResponse.Message, output.HookError.Message)
121-
}
120+
require.Equal(ts.T(), output, v0hooks.SendSMSOutput{})
122121
}
123122
})
124123
}
@@ -154,12 +153,9 @@ func (ts *HooksTestSuite) TestShouldRetryWithRetryAfterHeader() {
154153
req, err := http.NewRequest("POST", "http://localhost:9998/otp", nil)
155154
require.NoError(ts.T(), err)
156155

157-
body, err := ts.API.hooksMgr.RunHTTPHook(req, ts.Config.Hook.SendSMS, &input)
158-
require.NoError(ts.T(), err)
159-
160156
var output v0hooks.SendSMSOutput
161-
err = json.Unmarshal(body, &output)
162-
require.NoError(ts.T(), err, "Unmarshal should not fail")
157+
err = ts.API.hooksMgr.InvokeHook(ts.API.db, req, &input, &output)
158+
require.NoError(ts.T(), err)
163159

164160
// Ensure that all expected HTTP interactions (mocks) have been called
165161
require.True(ts.T(), gock.IsDone(), "Expected all mocks to have been called including retry")
@@ -186,10 +182,10 @@ func (ts *HooksTestSuite) TestShouldReturnErrorForNonJSONContentType() {
186182
req, err := http.NewRequest("POST", "http://localhost:9999/otp", nil)
187183
require.NoError(ts.T(), err)
188184

189-
_, err = ts.API.hooksMgr.RunHTTPHook(req, ts.Config.Hook.SendSMS, &input)
185+
var output v0hooks.SendSMSOutput
186+
err = ts.API.hooksMgr.InvokeHook(ts.API.db, req, &input, &output)
190187
require.Error(ts.T(), err, "Expected an error due to wrong content type")
191188
require.Contains(ts.T(), err.Error(), "Invalid JSON response.")
192-
193189
require.True(ts.T(), gock.IsDone(), "Expected all mocks to have been called")
194190
}
195191

internal/api/token.go

Lines changed: 88 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,15 @@ package api
22

33
import (
44
"context"
5+
"fmt"
56
"net/http"
67
"net/url"
78
"strconv"
89
"time"
910

1011
"github.com/gofrs/uuid"
1112
"github.com/golang-jwt/jwt/v5"
13+
"github.com/xeipuuv/gojsonschema"
1214

1315
"github.com/supabase/auth/internal/api/apierrors"
1416
"github.com/supabase/auth/internal/hooks/v0hooks"
@@ -369,14 +371,16 @@ func (a *API) generateAccessToken(r *http.Request, tx *storage.Connection, user
369371
if err != nil {
370372
return "", 0, err
371373
}
374+
if err := validateTokenClaims(output.Claims); err != nil {
375+
return "", 0, err
376+
}
372377
gotrueClaims = jwt.MapClaims(output.Claims)
373378
}
374379

375380
signed, err := signJwt(&config.JWT, gotrueClaims)
376381
if err != nil {
377382
return "", 0, err
378383
}
379-
380384
return signed, expiresAt.Unix(), nil
381385
}
382386

@@ -491,3 +495,86 @@ func (a *API) updateMFASessionAndClaims(r *http.Request, tx *storage.Connection,
491495
User: user,
492496
}, nil
493497
}
498+
499+
var schemaLoader = gojsonschema.NewStringLoader(MinimumViableTokenSchema)
500+
501+
func validateTokenClaims(outputClaims map[string]interface{}) error {
502+
documentLoader := gojsonschema.NewGoLoader(outputClaims)
503+
result, err := gojsonschema.Validate(schemaLoader, documentLoader)
504+
if err != nil {
505+
return err
506+
}
507+
508+
if !result.Valid() {
509+
var errorMessages string
510+
511+
for _, desc := range result.Errors() {
512+
errorMessages += fmt.Sprintf("- %s\n", desc)
513+
fmt.Printf("- %s\n", desc)
514+
}
515+
return fmt.Errorf(
516+
"output claims do not conform to the expected schema: \n%s", errorMessages)
517+
518+
}
519+
520+
return nil
521+
}
522+
523+
// #nosec
524+
const MinimumViableTokenSchema = `{
525+
"$schema": "http://json-schema.org/draft-07/schema#",
526+
"type": "object",
527+
"properties": {
528+
"aud": {
529+
"type": ["string", "array"]
530+
},
531+
"exp": {
532+
"type": "integer"
533+
},
534+
"jti": {
535+
"type": "string"
536+
},
537+
"iat": {
538+
"type": "integer"
539+
},
540+
"iss": {
541+
"type": "string"
542+
},
543+
"nbf": {
544+
"type": "integer"
545+
},
546+
"sub": {
547+
"type": "string"
548+
},
549+
"email": {
550+
"type": "string"
551+
},
552+
"phone": {
553+
"type": "string"
554+
},
555+
"app_metadata": {
556+
"type": "object",
557+
"additionalProperties": true
558+
},
559+
"user_metadata": {
560+
"type": "object",
561+
"additionalProperties": true
562+
},
563+
"role": {
564+
"type": "string"
565+
},
566+
"aal": {
567+
"type": "string"
568+
},
569+
"amr": {
570+
"type": "array",
571+
"items": {
572+
"type": "object"
573+
}
574+
},
575+
"session_id": {
576+
"type": "string"
577+
}
578+
},
579+
"required": ["aud", "exp", "iat", "sub", "email", "phone", "role", "aal", "session_id", "is_anonymous"]
580+
}`

internal/hooks/hookshttp/hookshttp.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,11 @@ func New(opts ...Option) *Dispatcher {
8282

8383
func (o *Dispatcher) Dispatch(
8484
ctx context.Context,
85-
cfg conf.ExtensibilityPointConfiguration,
85+
cfg *conf.ExtensibilityPointConfiguration,
8686
req any,
8787
res any,
8888
) error {
89-
data, err := o.RunHTTPHook(ctx, cfg, req)
89+
data, err := o.runHTTPHook(ctx, cfg, req)
9090
if err != nil {
9191
return err
9292
}
@@ -99,9 +99,9 @@ func (o *Dispatcher) Dispatch(
9999
return nil
100100
}
101101

102-
func (o *Dispatcher) RunHTTPHook(
102+
func (o *Dispatcher) runHTTPHook(
103103
ctx context.Context,
104-
hookConfig conf.ExtensibilityPointConfiguration,
104+
hookConfig *conf.ExtensibilityPointConfiguration,
105105
input any,
106106
) ([]byte, error) {
107107
client := http.Client{

internal/hooks/hookshttp/hookshttp_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ func TestDispatch(t *testing.T) {
267267
}
268268

269269
res := M{}
270-
err := dr.Dispatch(testCtx, cfg, tc.req, &res)
270+
err := dr.Dispatch(testCtx, &cfg, tc.req, &res)
271271
if tc.err != nil {
272272
require.Error(t, err)
273273
require.Equal(t, tc.err, err)

internal/hooks/hookspgfunc/hookspgfunc.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88

99
"github.com/supabase/auth/internal/api/apierrors"
1010
"github.com/supabase/auth/internal/conf"
11+
"github.com/supabase/auth/internal/hooks/hookserrors"
1112
"github.com/supabase/auth/internal/storage"
1213
)
1314

@@ -47,12 +48,11 @@ func New(db *storage.Connection, opts ...Option) *Dispatcher {
4748

4849
func (o *Dispatcher) Dispatch(
4950
ctx context.Context,
50-
cfg conf.ExtensibilityPointConfiguration,
51+
cfg *conf.ExtensibilityPointConfiguration,
5152
tx *storage.Connection,
52-
req any,
53-
res any,
53+
req, res any,
5454
) error {
55-
data, err := o.RunPostgresHook(ctx, cfg, tx, req)
55+
data, err := o.runPostgresHook(ctx, *cfg, tx, req)
5656
if err != nil {
5757
return err
5858
}
@@ -65,7 +65,7 @@ func (o *Dispatcher) Dispatch(
6565
return nil
6666
}
6767

68-
func (o *Dispatcher) RunPostgresHook(
68+
func (o *Dispatcher) runPostgresHook(
6969
ctx context.Context,
7070
hookConfig conf.ExtensibilityPointConfiguration,
7171
tx *storage.Connection,
@@ -108,5 +108,8 @@ func (o *Dispatcher) RunPostgresHook(
108108
return nil, err
109109
}
110110
}
111+
if err := hookserrors.Check(response); err != nil {
112+
return nil, err
113+
}
111114
return response, nil
112115
}

internal/hooks/hookspgfunc/hookspgfunc_test.go

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,38 @@ func TestDispatch(t *testing.T) {
182182
end; $$ language plpgsql;`,
183183
errStr: "500: Error unmarshaling JSON output.",
184184
},
185+
186+
{
187+
desc: "fail - returned error",
188+
cfg: conf.ExtensibilityPointConfiguration{
189+
URI: `pg-functions://postgres/auth/v0pgfunc_test_return_input`,
190+
HookName: `"auth"."v0pgfunc_test_return_input"`,
191+
},
192+
req: M{"error": M{"message": "failed"}},
193+
sql: `
194+
create or replace function v0pgfunc_test_return_input(input jsonb)
195+
returns json as $$
196+
begin
197+
return input;
198+
end; $$ language plpgsql;`,
199+
errStr: "500: failed",
200+
},
201+
202+
{
203+
desc: "fail - returned error with status",
204+
cfg: conf.ExtensibilityPointConfiguration{
205+
URI: `pg-functions://postgres/auth/v0pgfunc_test_return_input`,
206+
HookName: `"auth"."v0pgfunc_test_return_input"`,
207+
},
208+
req: M{"error": M{"message": "failed", "http_code": 403}},
209+
sql: `
210+
create or replace function v0pgfunc_test_return_input(input jsonb)
211+
returns json as $$
212+
begin
213+
return input;
214+
end; $$ language plpgsql;`,
215+
errStr: "403: failed",
216+
},
185217
}
186218

187219
for idx, tc := range cases {
@@ -210,7 +242,7 @@ func TestDispatch(t *testing.T) {
210242
tx := tc.tx
211243
cfg := tc.cfg
212244
res := M{}
213-
err := dr.Dispatch(testCtx, cfg, tx, tc.req, &res)
245+
err := dr.Dispatch(testCtx, &cfg, tx, tc.req, &res)
214246
if tc.err != nil {
215247
require.Error(t, err)
216248
require.Equal(t, tc.err, err)

0 commit comments

Comments
 (0)