Skip to content

Commit 903e623

Browse files
author
Chris Stockton
committed
feat: add pkg internal/e2e/e2ehooks
This is to prepare for adding hook calls to the API package. In addition the test coverage of all e2e was increased to 100%.
1 parent bd37fe2 commit 903e623

File tree

6 files changed

+484
-6
lines changed

6 files changed

+484
-6
lines changed

internal/e2e/e2e.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,14 @@ var (
1616
configPath string
1717
)
1818

19+
var isTesting func() bool = testing.Testing
20+
1921
func init() {
20-
if testing.Testing() {
22+
initPackage()
23+
}
24+
25+
func initPackage() {
26+
if isTesting() {
2127
_, thisFile, _, _ := runtime.Caller(0)
2228
projectRoot = filepath.Join(filepath.Dir(thisFile), "../..")
2329
configPath = filepath.Join(GetProjectRoot(), "hack", "test.env")

internal/e2e/e2e_test.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,4 +94,27 @@ func TestUtils(t *testing.T) {
9494
t.Fatal("exp non-nil err")
9595
}
9696
}()
97+
98+
// block init from main()
99+
func() {
100+
restore := isTesting
101+
defer func() {
102+
isTesting = restore
103+
}()
104+
isTesting = func() bool { return false }
105+
106+
var errStr string
107+
func() {
108+
defer func() {
109+
errStr = recover().(string)
110+
}()
111+
112+
initPackage()
113+
}()
114+
115+
exp := "package e2e may not be used in a main package"
116+
if errStr != exp {
117+
t.Fatalf("exp %v; got %v", exp, errStr)
118+
}
119+
}()
97120
}

internal/e2e/e2eapi/e2eapi.go

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,18 @@ func Do(
9090
if err != nil {
9191
return err
9292
}
93-
if err := json.Unmarshal(data, res); err != nil {
94-
return err
93+
if len(data) > 0 {
94+
if err := json.Unmarshal(data, res); err != nil {
95+
return err
96+
}
9597
}
9698
return nil
9799
}
98100

101+
const responseLimit = 1e6
102+
103+
var defaultClient = http.DefaultClient
104+
99105
func do(
100106
ctx context.Context,
101107
method string,
@@ -113,7 +119,7 @@ func do(
113119
h.Add("Content-Type", "application/json")
114120
h.Add("Accept", "application/json")
115121

116-
httpRes, err := http.DefaultClient.Do(httpReq)
122+
httpRes, err := defaultClient.Do(httpReq)
117123
if err != nil {
118124
return nil, err
119125
}
@@ -124,7 +130,7 @@ func do(
124130
return nil, nil
125131

126132
case sc >= 400:
127-
data, err := io.ReadAll(io.LimitReader(httpRes.Body, 1e8))
133+
data, err := io.ReadAll(io.LimitReader(httpRes.Body, responseLimit))
128134
if err != nil {
129135
return nil, err
130136
}
@@ -142,7 +148,7 @@ func do(
142148
return nil, err
143149

144150
default:
145-
data, err := io.ReadAll(io.LimitReader(httpRes.Body, 1e8))
151+
data, err := io.ReadAll(io.LimitReader(httpRes.Body, responseLimit))
146152
if err != nil {
147153
return nil, err
148154
}

internal/e2e/e2eapi/e2eapi_test.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,12 @@ package e2eapi
22

33
import (
44
"context"
5+
"errors"
6+
"io"
57
"net/http"
8+
"net/http/httptest"
69
"testing"
10+
"testing/iotest"
711
"time"
812

913
"github.com/gofrs/uuid"
@@ -116,4 +120,56 @@ func TestDo(t *testing.T) {
116120
}
117121
require.ErrorContains(t, err, "unsupported protocol")
118122
}
123+
124+
func() {
125+
hr := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
126+
w.WriteHeader(http.StatusNoContent)
127+
})
128+
129+
ts := httptest.NewServer(hr)
130+
defer ts.Close()
131+
132+
err := Do(ctx, http.MethodPost, ts.URL, nil, nil)
133+
if err != nil {
134+
t.Fatalf("exp nil err; got %v", err)
135+
}
136+
}()
137+
138+
for _, statusCode := range []int{http.StatusBadRequest, http.StatusOK} {
139+
func() {
140+
sentinel := errors.New("sentinel")
141+
rtFn := roundTripperFunc(func(req *http.Request) (*http.Response, error) {
142+
res, err := http.DefaultClient.Do(req)
143+
if err != nil {
144+
return nil, err
145+
}
146+
res.Body = io.NopCloser(iotest.ErrReader(sentinel))
147+
return res, nil
148+
})
149+
150+
prev := defaultClient
151+
defer func() {
152+
defaultClient = prev
153+
}()
154+
defaultClient = new(http.Client)
155+
defaultClient.Transport = rtFn
156+
157+
hr := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
158+
w.WriteHeader(statusCode)
159+
})
160+
161+
ts := httptest.NewServer(hr)
162+
defer ts.Close()
163+
164+
err := Do(ctx, http.MethodPost, ts.URL, nil, nil)
165+
require.Error(t, err)
166+
require.Equal(t, sentinel, err)
167+
}()
168+
}
169+
}
170+
171+
type roundTripperFunc func(*http.Request) (*http.Response, error)
172+
173+
func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) {
174+
return f(req)
119175
}

internal/e2e/e2ehooks/e2ehooks.go

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
// Package e2ehooks provides utilities for end-to-end testing of hooks.
2+
package e2ehooks
3+
4+
import (
5+
"bytes"
6+
"encoding/json"
7+
"io"
8+
"net/http"
9+
"net/http/httptest"
10+
"net/http/httputil"
11+
"slices"
12+
"sync"
13+
14+
"github.com/supabase/auth/internal/conf"
15+
"github.com/supabase/auth/internal/e2e/e2eapi"
16+
"github.com/supabase/auth/internal/hooks/v0hooks"
17+
)
18+
19+
type Instance struct {
20+
*e2eapi.Instance
21+
22+
HookServer *httptest.Server
23+
HookRecorder *HookRecorder
24+
}
25+
26+
func (o *Instance) Close() error {
27+
defer o.Instance.Close()
28+
defer o.HookServer.Close()
29+
return nil
30+
}
31+
32+
func New(globalCfg *conf.GlobalConfiguration) (*Instance, error) {
33+
hookRec := NewHookRecorder()
34+
hookSrv := httptest.NewServer(hookRec)
35+
hookRec.Register(&globalCfg.Hook, hookSrv.URL)
36+
37+
test, err := e2eapi.New(globalCfg)
38+
if err != nil {
39+
defer hookSrv.Close()
40+
41+
return nil, err
42+
}
43+
44+
o := &Instance{
45+
Instance: test,
46+
HookServer: hookSrv,
47+
HookRecorder: hookRec,
48+
}
49+
return o, nil
50+
}
51+
52+
func HandleSuccess() http.Handler {
53+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
54+
w.Header().Add("content-type", "application/json")
55+
_, _ = io.WriteString(w, "{}")
56+
})
57+
}
58+
59+
type Hook struct {
60+
mu sync.Mutex
61+
name v0hooks.Name
62+
calls []*HookCall
63+
64+
hr http.Handler
65+
}
66+
67+
func NewHook(name v0hooks.Name) *Hook {
68+
o := &Hook{
69+
name: name,
70+
}
71+
o.SetHandler(HandleSuccess())
72+
return o
73+
}
74+
75+
func (o *Hook) ClearCalls() {
76+
o.mu.Lock()
77+
defer o.mu.Unlock()
78+
o.calls = nil
79+
}
80+
81+
func (o *Hook) GetCalls() []*HookCall {
82+
o.mu.Lock()
83+
defer o.mu.Unlock()
84+
return slices.Clone(o.calls)
85+
}
86+
87+
func (o *Hook) SetHandler(hr http.Handler) {
88+
o.mu.Lock()
89+
defer o.mu.Unlock()
90+
o.hr = hr
91+
}
92+
93+
func (o *Hook) ServeHTTP(w http.ResponseWriter, r *http.Request) {
94+
o.mu.Lock()
95+
defer o.mu.Unlock()
96+
97+
dump, _ := httputil.DumpRequest(r, true)
98+
body, err := io.ReadAll(r.Body)
99+
if err != nil {
100+
code := http.StatusInternalServerError
101+
http.Error(w, http.StatusText(code), code)
102+
return
103+
}
104+
r.Body = io.NopCloser(bytes.NewReader(body))
105+
106+
hc := &HookCall{
107+
Dump: string(dump),
108+
Body: string(body),
109+
Header: r.Header.Clone(),
110+
}
111+
o.calls = append(o.calls, hc)
112+
113+
o.hr.ServeHTTP(w, r)
114+
}
115+
116+
type HookCall struct {
117+
Header http.Header
118+
Body string
119+
Dump string
120+
}
121+
122+
func (o *HookCall) Unmarshal(v any) error {
123+
return json.Unmarshal([]byte(o.Body), v)
124+
}
125+
126+
type HookRecorder struct {
127+
mux *http.ServeMux
128+
BeforeUserCreated *Hook
129+
AfterUserCreated *Hook
130+
CustomizeAccessToken *Hook
131+
MFAVerification *Hook
132+
PasswordVerification *Hook
133+
SendEmail *Hook
134+
SendSMS *Hook
135+
}
136+
137+
func NewHookRecorder() *HookRecorder {
138+
o := &HookRecorder{
139+
mux: http.NewServeMux(),
140+
BeforeUserCreated: NewHook(v0hooks.BeforeUserCreated),
141+
AfterUserCreated: NewHook(v0hooks.AfterUserCreated),
142+
CustomizeAccessToken: NewHook(v0hooks.CustomizeAccessToken),
143+
MFAVerification: NewHook(v0hooks.MFAVerification),
144+
PasswordVerification: NewHook(v0hooks.PasswordVerification),
145+
SendEmail: NewHook(v0hooks.SendEmail),
146+
SendSMS: NewHook(v0hooks.SendSMS),
147+
}
148+
149+
o.mux.HandleFunc("POST /hooks/{hook}", func(w http.ResponseWriter, r *http.Request) {
150+
//exhaustive:ignore
151+
switch v0hooks.Name(r.PathValue("hook")) {
152+
case v0hooks.BeforeUserCreated:
153+
o.BeforeUserCreated.ServeHTTP(w, r)
154+
155+
case v0hooks.AfterUserCreated:
156+
o.AfterUserCreated.ServeHTTP(w, r)
157+
158+
case v0hooks.CustomizeAccessToken:
159+
o.CustomizeAccessToken.ServeHTTP(w, r)
160+
161+
case v0hooks.MFAVerification:
162+
o.MFAVerification.ServeHTTP(w, r)
163+
164+
case v0hooks.PasswordVerification:
165+
o.PasswordVerification.ServeHTTP(w, r)
166+
167+
case v0hooks.SendEmail:
168+
o.SendEmail.ServeHTTP(w, r)
169+
170+
case v0hooks.SendSMS:
171+
o.SendSMS.ServeHTTP(w, r)
172+
173+
default:
174+
http.NotFound(w, r)
175+
}
176+
})
177+
return o
178+
}
179+
180+
func (o *HookRecorder) Register(
181+
hookCfg *conf.HookConfiguration,
182+
baseURL string,
183+
) {
184+
set := func(cfg *conf.ExtensibilityPointConfiguration, name v0hooks.Name) {
185+
*cfg = conf.ExtensibilityPointConfiguration{
186+
Enabled: true,
187+
URI: baseURL + "/hooks/" + string(name),
188+
}
189+
}
190+
set(&hookCfg.BeforeUserCreated, v0hooks.BeforeUserCreated)
191+
set(&hookCfg.AfterUserCreated, v0hooks.AfterUserCreated)
192+
set(&hookCfg.CustomAccessToken, v0hooks.CustomizeAccessToken)
193+
set(&hookCfg.MFAVerificationAttempt, v0hooks.MFAVerification)
194+
set(&hookCfg.PasswordVerificationAttempt, v0hooks.PasswordVerification)
195+
set(&hookCfg.SendEmail, v0hooks.SendEmail)
196+
set(&hookCfg.SendSMS, v0hooks.SendSMS)
197+
}
198+
199+
func (o *HookRecorder) ServeHTTP(w http.ResponseWriter, r *http.Request) {
200+
o.mux.ServeHTTP(w, r)
201+
}

0 commit comments

Comments
 (0)