diff --git a/internal/api/anonymous.go b/internal/api/anonymous.go index 33aa04aee..69679b9bc 100644 --- a/internal/api/anonymous.go +++ b/internal/api/anonymous.go @@ -30,6 +30,9 @@ func (a *API) SignupAnonymously(w http.ResponseWriter, r *http.Request) error { if err != nil { return err } + if err := a.triggerBeforeUserCreated(r, db, newUser); err != nil { + return err + } var grantParams models.GrantParams grantParams.FillGrantParams(r) diff --git a/internal/api/e2e_test.go b/internal/api/e2e_test.go new file mode 100644 index 000000000..8902a7c06 --- /dev/null +++ b/internal/api/e2e_test.go @@ -0,0 +1,664 @@ +package api_test + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "math/rand" + "net/http" + "slices" + "strings" + "testing" + "time" + + "github.com/gofrs/uuid" + jwt "github.com/golang-jwt/jwt/v5" + "github.com/pquerna/otp/totp" + "github.com/stretchr/testify/require" + "github.com/supabase/auth/internal/api" + "github.com/supabase/auth/internal/conf" + "github.com/supabase/auth/internal/e2e" + "github.com/supabase/auth/internal/e2e/e2eapi" + "github.com/supabase/auth/internal/e2e/e2ehooks" + "github.com/supabase/auth/internal/hooks/v0hooks" + "github.com/supabase/auth/internal/models" +) + +type M = map[string]any + +func genEmail() string { + return "e2etesthooks_" + uuid.Must(uuid.NewV4()).String() + "@localhost" +} + +func genPhone() string { + var sb strings.Builder + sb.WriteString("1") + for i := 0; i < 9; i++ { + // #nosec G404 + sb.WriteString(fmt.Sprintf("%d", rand.Intn(9))) + } + phone := sb.String() + return phone +} + +func TestE2EHooks(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + globalCfg := e2e.Must(e2e.Config()) + globalCfg.External.AnonymousUsers.Enabled = true + globalCfg.External.Phone.Enabled = true + globalCfg.MFA.Phone.EnrollEnabled = true + globalCfg.MFA.TOTP.EnrollEnabled = true + globalCfg.MFA.Phone.VerifyEnabled = true + + inst, err := e2ehooks.New(globalCfg) + require.NoError(t, err) + defer inst.Close() + + apiSrv := inst.APIServer + hookRec := inst.HookRecorder + + runBeforeUserCreated := func(t *testing.T, expUser *models.User) *models.User { + var latest *models.User + t.Run("BeforeUserCreated", func(t *testing.T) { + defer hookRec.BeforeUserCreated.ClearCalls() + + calls := hookRec.BeforeUserCreated.GetCalls() + require.Equal(t, 1, len(calls)) + call := calls[0] + + hookReq := &v0hooks.BeforeUserCreatedInput{} + err := call.Unmarshal(hookReq) + require.NoError(t, err) + require.Equal(t, v0hooks.BeforeUserCreated, hookReq.Metadata.Name) + + u := hookReq.User + require.Equal(t, expUser.ID, u.ID) + require.Equal(t, expUser.Aud, u.Aud) + require.Equal(t, expUser.Email, u.Email) + require.Equal(t, expUser.AppMetaData, u.AppMetaData) + + require.True(t, u.CreatedAt.IsZero()) + require.True(t, u.UpdatedAt.IsZero()) + + err = expUser.Confirm(inst.Conn) + require.NoError(t, err) + + latest, err = models.FindUserByID(inst.Conn, expUser.ID) + require.NoError(t, err) + require.NotNil(t, latest) + }) + return latest + } + + getAccessToken := func( + t *testing.T, + email, pass string, + ) *api.AccessTokenResponse { + req := &api.PasswordGrantParams{ + Email: email, + Password: pass, + } + + res := new(api.AccessTokenResponse) + err := e2eapi.Do(ctx, http.MethodPost, apiSrv.URL+"/token?grant_type=password", req, res) + require.NoError(t, err) + return res + } + + // Basic tests for user hooks + t.Run("UserHooks", func(t *testing.T) { + + t.Run("SignupEmail", func(t *testing.T) { + defer hookRec.BeforeUserCreated.ClearCalls() + + email := genEmail() + req := &api.SignupParams{ + Email: email, + Password: "password", + } + res := new(models.User) + err := e2eapi.Do(ctx, http.MethodPost, apiSrv.URL+"/signup", req, res) + require.NoError(t, err) + require.Equal(t, email, res.Email.String()) + + runBeforeUserCreated(t, res) + }) + + t.Run("SignupPhone", func(t *testing.T) { + defer hookRec.BeforeUserCreated.ClearCalls() + + phone := genPhone() + req := &api.SignupParams{ + Phone: phone, + Password: "password", + } + res := new(models.User) + err := e2eapi.Do(ctx, http.MethodPost, apiSrv.URL+"/signup", req, res) + require.NoError(t, err) + require.Equal(t, phone, res.Phone.String()) + + runBeforeUserCreated(t, res) + }) + + t.Run("SignupAnonymously", func(t *testing.T) { + defer hookRec.BeforeUserCreated.ClearCalls() + + req := &api.SignupParams{} + res := new(api.AccessTokenResponse) + err := e2eapi.Do(ctx, http.MethodPost, apiSrv.URL+"/signup", req, res) + require.NoError(t, err) + + runBeforeUserCreated(t, res.User) + }) + + t.Run("ExternalCallback", func(t *testing.T) { + defer hookRec.BeforeUserCreated.ClearCalls() + + req := &api.SignupParams{} + res := new(api.AccessTokenResponse) + err := e2eapi.Do(ctx, http.MethodPost, apiSrv.URL+"/signup", req, res) + require.NoError(t, err) + + runBeforeUserCreated(t, res.User) + }) + + t.Run("AdminEndpoints", func(t *testing.T) { + t.Run("Invite", func(t *testing.T) { + defer hookRec.BeforeUserCreated.ClearCalls() + + email := genEmail() + req := &api.InviteParams{ + Email: email, + } + res := new(models.User) + + body := new(bytes.Buffer) + err := json.NewEncoder(body).Encode(req) + require.NoError(t, err) + + httpReq, err := http.NewRequestWithContext( + ctx, "POST", "/invite", body) + require.NoError(t, err) + + httpRes, err := inst.DoAdmin(httpReq) + require.NoError(t, err) + + err = json.NewDecoder(httpRes.Body).Decode(res) + require.NoError(t, err) + + runBeforeUserCreated(t, res) + }) + + t.Run("AdminGenerateLink", func(t *testing.T) { + + t.Run("SignupVerification", func(t *testing.T) { + defer hookRec.BeforeUserCreated.ClearCalls() + + email := genEmail() + req := &api.GenerateLinkParams{ + Type: "signup", + Email: email, + Password: "pass1234", + } + res := new(api.GenerateLinkResponse) + + body := new(bytes.Buffer) + err := json.NewEncoder(body).Encode(req) + require.NoError(t, err) + + httpReq, err := http.NewRequestWithContext( + ctx, "POST", "/admin/generate_link", body) + require.NoError(t, err) + + httpRes, err := inst.DoAdmin(httpReq) + require.NoError(t, err) + require.Equal(t, 200, httpRes.StatusCode) + + err = json.NewDecoder(httpRes.Body).Decode(res) + require.NoError(t, err) + + runBeforeUserCreated(t, &res.User) + }) + + t.Run("InviteVerification", func(t *testing.T) { + defer hookRec.BeforeUserCreated.ClearCalls() + + email := genEmail() + req := &api.GenerateLinkParams{ + Type: "invite", + Email: email, + } + res := new(api.GenerateLinkResponse) + + body := new(bytes.Buffer) + err := json.NewEncoder(body).Encode(req) + require.NoError(t, err) + + httpReq, err := http.NewRequestWithContext( + ctx, "POST", "/admin/generate_link", body) + require.NoError(t, err) + + httpRes, err := inst.DoAdmin(httpReq) + require.NoError(t, err) + require.Equal(t, 200, httpRes.StatusCode) + + err = json.NewDecoder(httpRes.Body).Decode(res) + require.NoError(t, err) + + runBeforeUserCreated(t, &res.User) + }) + }) + }) + }) + + t.Run("MFAVerificationAttempt", func(t *testing.T) { + defer hookRec.MFAVerification.ClearCalls() + + type flowResult struct { + factorRes *api.EnrollFactorResponse + challengeRes *api.ChallengeFactorResponse + mfaUser *models.User + mfaUserAccessToken *api.AccessTokenResponse + } + + runMFAFlow := func(t *testing.T) *flowResult { + factorRes := new(api.EnrollFactorResponse) + challengeRes := new(api.ChallengeFactorResponse) + mfaUser := new(models.User) + mfaUserAccessToken := new(api.AccessTokenResponse) + + t.Run("MFAFlow", func(t *testing.T) { + t.Run("Signup", func(t *testing.T) { + email := genEmail() + const password = "password" + req := &api.SignupParams{ + Email: email, + Password: password, + } + err := e2eapi.Do(ctx, http.MethodPost, apiSrv.URL+"/signup", req, mfaUser) + require.NoError(t, err) + require.Equal(t, email, mfaUser.Email.String()) + + mfaUser = runBeforeUserCreated(t, mfaUser) + mfaUserAccessToken = getAccessToken(t, string(mfaUser.Email), password) + + phone := genPhone() + domain := strings.Split(email, "@")[1] + + // enroll factor + t.Run("MFAEnroll", func(t *testing.T) { + req := &api.EnrollFactorParams{ + FriendlyName: "totp_" + email, + Phone: phone, + Issuer: domain, + FactorType: models.TOTP, + } + + body := new(bytes.Buffer) + err = json.NewEncoder(body).Encode(req) + require.NoError(t, err) + + httpReq, err := http.NewRequestWithContext( + ctx, "POST", "/factors/", body) + require.NoError(t, err) + + httpRes, err := inst.DoAuth(httpReq, mfaUserAccessToken.Token) + require.NoError(t, err) + require.Equal(t, 200, httpRes.StatusCode) + + err = json.NewDecoder(httpRes.Body).Decode(factorRes) + require.NoError(t, err) + }) + + // challenge factor + t.Run("MFAChallenge", func(t *testing.T) { + req := &models.Factor{ + ID: factorRes.ID, + } + + body := new(bytes.Buffer) + err = json.NewEncoder(body).Encode(req) + require.NoError(t, err) + + url := fmt.Sprintf("/factors/%v/challenge", factorRes.ID) + httpReq, err := http.NewRequestWithContext( + ctx, "POST", url, body) + require.NoError(t, err) + + httpRes, err := inst.DoAuth(httpReq, mfaUserAccessToken.Token) + require.NoError(t, err) + require.Equal(t, 200, httpRes.StatusCode) + + err = json.NewDecoder(httpRes.Body).Decode(challengeRes) + require.NoError(t, err) + }) + }) + }) + return &flowResult{ + factorRes: factorRes, + challengeRes: challengeRes, + mfaUser: mfaUser, + mfaUserAccessToken: mfaUserAccessToken, + } + } + + t.Run("MFAVerifySuccess", func(t *testing.T) { + defer hookRec.MFAVerification.ClearCalls() + + flowRes := runMFAFlow(t) + verifyRes := new(api.AccessTokenResponse) + + mfaCode, err := totp.GenerateCode(flowRes.factorRes.TOTP.Secret, time.Now().UTC()) + require.NoError(t, err) + + req := &api.VerifyFactorParams{ + ChallengeID: flowRes.challengeRes.ID, + Code: mfaCode, + } + + body := new(bytes.Buffer) + err = json.NewEncoder(body).Encode(req) + require.NoError(t, err) + + url := fmt.Sprintf("/factors/%v/verify", flowRes.factorRes.ID) + httpReq, err := http.NewRequestWithContext( + ctx, "POST", url, body) + require.NoError(t, err) + + httpRes, err := inst.DoAuth(httpReq, flowRes.mfaUserAccessToken.Token) + require.NoError(t, err) + require.Equal(t, 200, httpRes.StatusCode) + + // verify the mfa was accepted + err = json.NewDecoder(httpRes.Body).Decode(verifyRes) + require.NoError(t, err) + require.NotEmpty(t, verifyRes.Token) + + calls := hookRec.MFAVerification.GetCalls() + require.Equal(t, 1, len(calls)) + call := calls[0] + + hookReq := M{} + err = call.Unmarshal(&hookReq) + require.NoError(t, err) + + // verify hook request + require.Equal(t, flowRes.factorRes.ID.String(), hookReq["factor_id"]) + require.Equal(t, flowRes.factorRes.Type, hookReq["factor_type"]) + require.Equal(t, flowRes.mfaUser.ID.String(), hookReq["user_id"]) + require.Equal(t, true, hookReq["valid"]) + }) + + t.Run("MFAVerifyFailure", func(t *testing.T) { + defer hookRec.MFAVerification.ClearCalls() + + const errorMsg = "sentinel error message" + { + hr := e2ehooks.HandleJSON(M{ + "decision": "reject", + "message": errorMsg, + }) + hookRec.MFAVerification.SetHandler(hr) + } + + flowRes := runMFAFlow(t) + errorRes := new(api.HTTPError) + + mfaCode, err := totp.GenerateCode(flowRes.factorRes.TOTP.Secret, time.Now().UTC()) + require.NoError(t, err) + + req := &api.VerifyFactorParams{ + ChallengeID: flowRes.challengeRes.ID, + Code: mfaCode, + } + + body := new(bytes.Buffer) + err = json.NewEncoder(body).Encode(req) + require.NoError(t, err) + + url := fmt.Sprintf("/factors/%v/verify", flowRes.factorRes.ID) + httpReq, err := http.NewRequestWithContext( + ctx, "POST", url, body) + require.NoError(t, err) + + httpRes, err := inst.DoAuth(httpReq, flowRes.mfaUserAccessToken.Token) + require.NoError(t, err) + require.Equal(t, 403, httpRes.StatusCode) + + // verify the mfa rejection + err = json.NewDecoder(httpRes.Body).Decode(errorRes) + require.NoError(t, err) + require.Equal(t, 403, errorRes.HTTPStatus) + require.Equal(t, "mfa_verification_rejected", errorRes.ErrorCode) + require.Equal(t, errorMsg, errorRes.Message) + + calls := hookRec.MFAVerification.GetCalls() + require.Equal(t, 1, len(calls)) + call := calls[0] + + hookReq := M{} + err = call.Unmarshal(&hookReq) + require.NoError(t, err) + + // verify hook request + require.Equal(t, flowRes.factorRes.ID.String(), hookReq["factor_id"]) + require.Equal(t, flowRes.factorRes.Type, hookReq["factor_type"]) + require.Equal(t, flowRes.mfaUser.ID.String(), hookReq["user_id"]) + require.Equal(t, true, hookReq["valid"]) + }) + }) + // Basic tests for CustomizeAccessToken + t.Run("CustomizeAccessToken", func(t *testing.T) { + defer hookRec.CustomizeAccessToken.ClearCalls() + + // setup user to test with + var currentUser *models.User + { + email := genEmail() + req := &api.SignupParams{ + Email: email, + Password: "password", + } + res := new(models.User) + err := e2eapi.Do(ctx, http.MethodPost, apiSrv.URL+"/signup", req, res) + require.NoError(t, err) + require.Equal(t, email, res.Email.String()) + + currentUser = runBeforeUserCreated(t, res) + require.NotNil(t, currentUser) + hookRec.CustomizeAccessToken.ClearCalls() + } + + copyMap := func(t *testing.T, m M) (out M) { + b, err := json.Marshal(m) + require.NoError(t, err) + err = json.Unmarshal(b, &out) + require.NoError(t, err) + return out + } + checkClaims := func(t *testing.T, in, out M, exclude ...string) { + if aud, ok := in["aud"].([]any); ok && len(aud) > 0 { + require.Equal(t, aud[0].(string), out["aud"]) + } + if aud, ok := in["aud"].(string); ok { + require.Equal(t, aud, out["aud"]) + } + + for _, k := range []string{ + "iss", + "sub", + "exp", + "iat", + "aal", + "role", + "amr", + "session_id", + "is_anonymous", + "app_metadata", + "user_metadata", + "phone", + "email", + } { + if !slices.Contains(exclude, k) { + require.Equal(t, in[k], out[k]) + } + } + } + + cases := []struct { + desc string + from func(claimsIn M) (claimsOut M) + errStr string + check func( + t *testing.T, + claimsIn, claimsOut M, + ) + }{ + { + desc: `claims field missing`, + from: func(in M) M { return M{} }, + errStr: "500: output claims field is missing", + }, + + { + desc: `claims field missing with top level keys`, + from: func(in M) M { + return M{ + "myclaim": "aaa", + "other_claim": "bbb", + } + }, + errStr: "500: output claims field is missing", + }, + + { + desc: `claims field nil`, + from: func(in M) M { return M{"claims": nil} }, + errStr: "500: output claims do not conform to the expected schema", + }, + + { + desc: `claims field empty`, + from: func(in M) M { return M{"claims": M{}} }, + errStr: "500: output claims do not conform to the expected schema", + }, + + { + desc: `add app_metadata claims`, + from: func(in M) M { + out := copyMap(t, in) + out["claims"].(M)["app_metadata"].(M)["bool_true"] = true + out["claims"].(M)["app_metadata"].(M)["string_hello"] = "hello" + return out + }, + check: func( + t *testing.T, + in, out M, + ) { + checkClaims(t, in, out, "app_metadata") + + for k := range in { + if k == "app_metadata" { + require.Equal(t, + out["app_metadata"].(M)["bool_true"], + true, + ) + require.Equal(t, + out["app_metadata"].(M)["string_hello"], + "hello", + ) + continue + } + } + }, + }, + } + + for _, tc := range cases { + t.Run(string(tc.desc), func(t *testing.T) { + defer hookRec.CustomizeAccessToken.ClearCalls() + + var claimsIn, claimsOut M + hr := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("content-type", "application/json") + w.WriteHeader(http.StatusOK) + + err := json.NewDecoder(r.Body).Decode(&claimsIn) + require.NoError(t, err) + + claimsOut = tc.from(copyMap(t, claimsIn)) + err = json.NewEncoder(w).Encode(claimsOut) + require.NoError(t, err) + }) + + hookRec.CustomizeAccessToken.ClearCalls() + hookRec.CustomizeAccessToken.SetHandler(hr) + req := &api.PasswordGrantParams{ + Email: string(currentUser.Email), + Password: "password", + } + + res := new(api.AccessTokenResponse) + err := e2eapi.Do(ctx, http.MethodPost, apiSrv.URL+"/token?grant_type=password", req, res) + + // always verify the hook request before checking response + { + calls := hookRec.CustomizeAccessToken.GetCalls() + require.Equal(t, 1, len(calls)) + call := calls[0] + + hookReq := &v0hooks.CustomAccessTokenInput{} + err := call.Unmarshal(hookReq) + require.NoError(t, err) + require.Equal(t, currentUser.ID, hookReq.UserID) + require.Equal(t, currentUser.ID.String(), hookReq.Claims.Subject) + } + + // check if we expected an error + if tc.errStr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tc.errStr) + return + } + require.True(t, len(res.Token) > 0) + + // parse the token we got back + p := jwt.NewParser(jwt.WithValidMethods(globalCfg.JWT.ValidMethods)) + token, err := p.ParseWithClaims( + res.Token, + &api.AccessTokenClaims{}, + func(token *jwt.Token, + ) (any, error) { + if kid, ok := token.Header["kid"]; ok { + if kidStr, ok := kid.(string); ok { + return conf.FindPublicKeyByKid(kidStr, &globalCfg.JWT) + } + } + if alg, ok := token.Header["alg"]; ok { + if alg == jwt.SigningMethodHS256.Name { + // preserve backward compatibility for cases where the kid is not set + return []byte(globalCfg.JWT.Secret), nil + } + } + return nil, fmt.Errorf("missing kid") + }) + require.NoError(t, err) + + tokenClaims := M{} + { + b, err := json.Marshal(token.Claims) + require.NoError(t, err) + err = json.Unmarshal(b, &tokenClaims) + require.NoError(t, err) + } + + if tc.check != nil { + tc.check(t, claimsIn["claims"].(M), tokenClaims) + } + }) + } + }) +} diff --git a/internal/api/external.go b/internal/api/external.go index dfaa86a03..7e0ce527e 100644 --- a/internal/api/external.go +++ b/internal/api/external.go @@ -203,15 +203,24 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re } + targetUser := getTargetUser(ctx) + inviteToken := getInviteToken(ctx) + if targetUser == nil && inviteToken == "" { + if err := a.triggerBeforeUserCreatedExternal( + r, db, userData, providerType); err != nil { + return err + } + } + var user *models.User var token *AccessTokenResponse err = db.Transaction(func(tx *storage.Connection) error { var terr error - if targetUser := getTargetUser(ctx); targetUser != nil { + if targetUser != nil { if user, terr = a.linkIdentityToUser(r, ctx, tx, userData, providerType); terr != nil { return terr } - } else if inviteToken := getInviteToken(ctx); inviteToken != "" { + } else if inviteToken != "" { if user, terr = a.processInvite(r, tx, userData, inviteToken, providerType); terr != nil { return terr } @@ -334,6 +343,7 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http. return nil, terr } user.Identities = append(user.Identities, *identity) + case models.AccountExists: user = decision.User identity = decision.Identities[0] diff --git a/internal/api/hooks.go b/internal/api/hooks.go new file mode 100644 index 000000000..f850ab1ad --- /dev/null +++ b/internal/api/hooks.go @@ -0,0 +1,105 @@ +package api + +import ( + "net/http" + "strings" + + "github.com/fatih/structs" + "github.com/supabase/auth/internal/api/apierrors" + "github.com/supabase/auth/internal/api/provider" + "github.com/supabase/auth/internal/hooks/v0hooks" + "github.com/supabase/auth/internal/models" + "github.com/supabase/auth/internal/storage" +) + +func (a *API) triggerBeforeUserCreated( + r *http.Request, + conn *storage.Connection, + user *models.User, +) error { + if !a.hooksMgr.Enabled(v0hooks.BeforeUserCreated) { + return nil + } + if err := checkTX(conn); err != nil { + return err + } + + req := v0hooks.NewBeforeUserCreatedInput(r, user) + res := new(v0hooks.BeforeUserCreatedOutput) + return a.hooksMgr.InvokeHook(conn, r, req, res) +} + +func (a *API) triggerBeforeUserCreatedExternal( + r *http.Request, + conn *storage.Connection, + userData *provider.UserProvidedData, + providerType string, +) error { + if !a.hooksMgr.Enabled(v0hooks.BeforeUserCreated) { + return nil + } + if err := checkTX(conn); err != nil { + return err + } + + ctx := r.Context() + aud := a.requestAud(ctx, r) + config := a.config + + var identityData map[string]interface{} + if userData.Metadata != nil { + identityData = structs.Map(userData.Metadata) + } + + var ( + err error + decision models.AccountLinkingResult + ) + err = a.db.Transaction(func(tx *storage.Connection) error { + decision, err = models.DetermineAccountLinking( + tx, config, userData.Emails, aud, + providerType, userData.Metadata.Subject) + if err != nil { + return err + } + return nil + }) + if err != nil { + return err + } + + if decision.Decision != models.CreateAccount { + return nil + } + if config.DisableSignup { + return apierrors.NewUnprocessableEntityError( + apierrors.ErrorCodeSignupDisabled, + "Signups not allowed for this instance") + } + + params := &SignupParams{ + Provider: providerType, + Email: decision.CandidateEmail.Email, + Aud: aud, + Data: identityData, + } + + isSSOUser := false + if strings.HasPrefix(decision.LinkingDomain, "sso:") { + isSSOUser = true + } + + user, err := params.ToUserModel(isSSOUser) + if err != nil { + return err + } + return a.triggerBeforeUserCreated(r, conn, user) +} + +func checkTX(conn *storage.Connection) error { + if conn.TX != nil { + return apierrors.NewInternalServerError( + "unable to trigger hooks during transaction") + } + return nil +} diff --git a/internal/api/hooks_test.go b/internal/api/hooks_test.go index 8195f1104..9670e431c 100644 --- a/internal/api/hooks_test.go +++ b/internal/api/hooks_test.go @@ -47,6 +47,19 @@ func TestHooks(t *testing.T) { defer api.db.Close() suite.Run(t, ts) + + t.Run("CheckTX", func(t *testing.T) { + require.NoError(t, checkTX(api.db)) + + err := api.db.Transaction(func(tx *storage.Connection) error { + require.Error(t, checkTX(tx)) + + err := checkTX(tx) + require.Error(t, err) + return nil + }) + require.NoError(t, err) + }) } func (ts *HooksTestSuite) SetupTest() { diff --git a/internal/api/invite.go b/internal/api/invite.go index 2c3f8cd7f..797931004 100644 --- a/internal/api/invite.go +++ b/internal/api/invite.go @@ -37,41 +37,38 @@ func (a *API) Invite(w http.ResponseWriter, r *http.Request) error { if err != nil && !models.IsNotFoundError(err) { return apierrors.NewInternalServerError("Database error finding user").WithInternalError(err) } + if user != nil && user.IsConfirmed() { + return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeEmailExists, DuplicateEmailMsg) + } - err = db.Transaction(func(tx *storage.Connection) error { - if user != nil { - if user.IsConfirmed() { - return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeEmailExists, DuplicateEmailMsg) - } - } else { - signupParams := SignupParams{ - Email: params.Email, - Data: params.Data, - Aud: aud, - Provider: "email", - } + signupParams := SignupParams{ + Email: params.Email, + Data: params.Data, + Aud: aud, + Provider: "email", + } - // because params above sets no password, this method - // is not computationally hard so it can be used within - // a database transaction - user, err = signupParams.ToUserModel(false /* <- isSSOUser */) - if err != nil { - return err - } + user, err = signupParams.ToUserModel(false /* <- isSSOUser */) + if err != nil { + return err + } + if err := a.triggerBeforeUserCreated(r, db, user); err != nil { + return err + } - user, err = a.signupNewUser(tx, user) - if err != nil { - return err - } - identity, err := a.createNewIdentity(tx, user, "email", structs.Map(provider.Claims{ - Subject: user.ID.String(), - Email: user.GetEmail(), - })) - if err != nil { - return err - } - user.Identities = []models.Identity{*identity} + err = db.Transaction(func(tx *storage.Connection) error { + user, err = a.signupNewUser(tx, user) + if err != nil { + return err + } + identity, err := a.createNewIdentity(tx, user, "email", structs.Map(provider.Claims{ + Subject: user.ID.String(), + Email: user.GetEmail(), + })) + if err != nil { + return err } + user.Identities = []models.Identity{*identity} if terr := models.NewAuditLogEntry(r, tx, adminUser, models.UserInvitedAction, "", map[string]interface{}{ "user_id": user.ID, diff --git a/internal/api/mail.go b/internal/api/mail.go index f90d9a74c..f06811db0 100644 --- a/internal/api/mail.go +++ b/internal/api/mail.go @@ -92,8 +92,12 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { hashedToken := crypto.GenerateTokenHash(params.Email, otp) - var signupUser *models.User - if params.Type == mail.SignupVerification && user == nil { + var ( + signupUser *models.User + inviteUser *models.User + ) + switch { + case params.Type == mail.SignupVerification && user == nil: signupParams := &SignupParams{ Email: params.Email, Password: params.Password, @@ -110,6 +114,25 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { if err != nil { return err } + if err := a.triggerBeforeUserCreated(r, db, signupUser); err != nil { + return err + } + + case params.Type == mail.InviteVerification && user == nil: + signupParams := &SignupParams{ + Email: params.Email, + Data: params.Data, + Provider: "email", + Aud: aud, + } + + inviteUser, err = signupParams.ToUserModel(false /* <- isSSOUser */) + if err != nil { + return err + } + if err := a.triggerBeforeUserCreated(r, db, inviteUser); err != nil { + return err + } } err = db.Transaction(func(tx *storage.Connection) error { @@ -138,22 +161,7 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { return apierrors.NewUnprocessableEntityError(apierrors.ErrorCodeEmailExists, DuplicateEmailMsg) } } else { - signupParams := &SignupParams{ - Email: params.Email, - Data: params.Data, - Provider: "email", - Aud: aud, - } - - // because params above sets no password, this - // method is not computationally hard so it can - // be used within a database transaction - user, terr = signupParams.ToUserModel(false /* <- isSSOUser */) - if terr != nil { - return terr - } - - user, terr = a.signupNewUser(tx, user) + user, terr = a.signupNewUser(tx, inviteUser) if terr != nil { return terr } diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 8ce34e52d..83422be18 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -631,9 +631,10 @@ func (a *API) verifyTOTPFactor(w http.ResponseWriter, r *http.Request, params *V if config.Hook.MFAVerificationAttempt.Enabled { input := v0hooks.MFAVerificationAttemptInput{ - UserID: user.ID, - FactorID: factor.ID, - Valid: valid, + UserID: user.ID, + FactorID: factor.ID, + FactorType: factor.FactorType, + Valid: valid, } output := v0hooks.MFAVerificationAttemptOutput{} diff --git a/internal/api/samlacs.go b/internal/api/samlacs.go index 3f64fc235..fb56e4cb5 100644 --- a/internal/api/samlacs.go +++ b/internal/api/samlacs.go @@ -281,12 +281,18 @@ func (a *API) handleSamlAcs(w http.ResponseWriter, r *http.Request) error { } } + providerType := "sso:" + ssoProvider.ID.String() + if err := a.triggerBeforeUserCreatedExternal( + r, db, &userProvidedData, providerType); err != nil { + return err + } + if err := db.Transaction(func(tx *storage.Connection) error { var terr error var user *models.User // accounts potentially created via SAML can contain non-unique email addresses in the auth.users table - if user, terr = a.createAccountFromExternalIdentity(tx, r, &userProvidedData, "sso:"+ssoProvider.ID.String()); terr != nil { + if user, terr = a.createAccountFromExternalIdentity(tx, r, &userProvidedData, providerType); terr != nil { return terr } if flowState != nil { diff --git a/internal/api/signup.go b/internal/api/signup.go index 09ac43524..7901f0caf 100644 --- a/internal/api/signup.go +++ b/internal/api/signup.go @@ -183,6 +183,9 @@ func (a *API) Signup(w http.ResponseWriter, r *http.Request) error { if err != nil { return err } + if err := a.triggerBeforeUserCreated(r, db, signupUser); err != nil { + return err + } } err = db.Transaction(func(tx *storage.Connection) error { diff --git a/internal/api/token.go b/internal/api/token.go index 0f041faf5..8a86b3016 100644 --- a/internal/api/token.go +++ b/internal/api/token.go @@ -510,13 +510,17 @@ func validateTokenClaims(outputClaims map[string]interface{}) error { for _, desc := range result.Errors() { errorMessages += fmt.Sprintf("- %s\n", desc) - fmt.Printf("- %s\n", desc) } - return fmt.Errorf( + err = fmt.Errorf( "output claims do not conform to the expected schema: \n%s", errorMessages) - } - + if err != nil { + httpError := &apierrors.HTTPError{ + HTTPStatus: http.StatusInternalServerError, + Message: err.Error(), + } + return httpError + } return nil } diff --git a/internal/api/token_oidc.go b/internal/api/token_oidc.go index fbc4243ba..8a04bd78f 100644 --- a/internal/api/token_oidc.go +++ b/internal/api/token_oidc.go @@ -224,6 +224,10 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R grantParams.FillGrantParams(r) + if err := a.triggerBeforeUserCreatedExternal(r, db, userData, providerType); err != nil { + return err + } + if err := db.Transaction(func(tx *storage.Connection) error { var user *models.User var terr error diff --git a/internal/api/web3.go b/internal/api/web3.go index c9beeea1e..19d1d6841 100644 --- a/internal/api/web3.go +++ b/internal/api/web3.go @@ -102,8 +102,9 @@ func (a *API) web3GrantSolana(ctx context.Context, w http.ResponseWriter, r *htt return apierrors.NewOAuthError("invalid_grant", "Solana message was issued too far in the future") } + const providerType = "web3" providerId := strings.Join([]string{ - "web3", + providerType, params.Chain, parsedMessage.Address, }, ":") @@ -126,14 +127,18 @@ func (a *API) web3GrantSolana(ctx context.Context, w http.ResponseWriter, r *htt var grantParams models.GrantParams grantParams.FillGrantParams(r) + if err := a.triggerBeforeUserCreatedExternal(r, db, &userData, providerType); err != nil { + return err + } + err = db.Transaction(func(tx *storage.Connection) error { - user, terr := a.createAccountFromExternalIdentity(tx, r, &userData, "web3") + user, terr := a.createAccountFromExternalIdentity(tx, r, &userData, providerType) if terr != nil { return terr } if terr := models.NewAuditLogEntry(r, tx, user, models.LoginAction, "", map[string]interface{}{ - "provider": "web3", + "provider": providerType, "chain": params.Chain, "network": parsedMessage.ChainID, "address": parsedMessage.Address, diff --git a/internal/e2e/e2eapi/e2eapi.go b/internal/e2e/e2eapi/e2eapi.go index f28e89e51..551035432 100644 --- a/internal/e2e/e2eapi/e2eapi.go +++ b/internal/e2e/e2eapi/e2eapi.go @@ -9,7 +9,9 @@ import ( "io" "net/http" "net/http/httptest" + "net/url" + jwt "github.com/golang-jwt/jwt/v5" "github.com/supabase/auth/internal/api" "github.com/supabase/auth/internal/api/apierrors" "github.com/supabase/auth/internal/conf" @@ -22,6 +24,8 @@ type Instance struct { Config *conf.GlobalConfiguration Conn *storage.Connection APIServer *httptest.Server + APIClient *http.Client + apiURL *url.URL closers []func() } @@ -30,9 +34,16 @@ func New(globalCfg *conf.GlobalConfiguration) (*Instance, error) { o := new(Instance) o.Config = globalCfg - conn, err := test.SetupDBConnection(globalCfg) + if err := o.init(); err != nil { + return nil, err + } + return o, nil +} + +func (o *Instance) init() error { + conn, err := test.SetupDBConnection(o.Config) if err != nil { - return nil, fmt.Errorf("error setting up db connection: %w", err) + return fmt.Errorf("error setting up db connection: %w", err) } o.addCloser(func() { if conn.Store != nil { @@ -46,12 +57,24 @@ func New(globalCfg *conf.GlobalConfiguration) (*Instance, error) { apiVer = "1" } - a := api.NewAPIWithVersion(globalCfg, conn, apiVer) + a := api.NewAPIWithVersion(o.Config, conn, apiVer) apiSrv := httptest.NewServer(a) o.addCloser(apiSrv) o.APIServer = apiSrv + o.APIClient = apiSrv.Client() - return o, nil + return o.initURL() +} + +func (o *Instance) initURL() error { + u, err := url.Parse(o.APIServer.URL) + if err != nil { + defer o.Close() + + return err + } + o.apiURL = u + return nil } func (o *Instance) Close() error { @@ -70,6 +93,48 @@ func (o *Instance) addCloser(v any) { } } +func (o *Instance) Do( + req *http.Request, +) (*http.Response, error) { + req.URL = o.apiURL.ResolveReference(req.URL) + + return o.APIClient.Do(req) +} + +func (o *Instance) DoAuth( + req *http.Request, + jwt string, +) (*http.Response, error) { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", jwt)) + return o.Do(req) +} + +func (o *Instance) DoAdmin( + req *http.Request, +) (*http.Response, error) { + return o.doAdmin(req, []byte(o.Config.JWT.Secret)) +} + +func (o *Instance) doAdmin( + req *http.Request, + key any, +) (*http.Response, error) { + tok := jwt.NewWithClaims( + jwt.SigningMethodHS256, + &api.AccessTokenClaims{ + Role: "supabase_admin", + }, + ) + + adminJwt, err := tok.SignedString(key) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", adminJwt)) + + return o.Do(req) +} + func Do( ctx context.Context, method string, diff --git a/internal/e2e/e2eapi/e2eapi_test.go b/internal/e2e/e2eapi/e2eapi_test.go index 391de9514..098d4c831 100644 --- a/internal/e2e/e2eapi/e2eapi_test.go +++ b/internal/e2e/e2eapi/e2eapi_test.go @@ -1,7 +1,9 @@ package e2eapi import ( + "bytes" "context" + "encoding/json" "errors" "fmt" "io" @@ -29,7 +31,7 @@ func TestInstance(t *testing.T) { require.NoError(t, err) defer inst.Close() - email := "e2etesthooks_" + uuid.Must(uuid.NewV4()).String() + "@localhost" + email := "e2eapitest_" + uuid.Must(uuid.NewV4()).String() + "@localhost" req := &api.SignupParams{ Email: email, Password: "password", @@ -40,6 +42,50 @@ func TestInstance(t *testing.T) { require.Equal(t, email, res.Email.String()) }) + t.Run("DoAdmin", func(t *testing.T) { + globalCfg := e2e.Must(e2e.Config()) + inst, err := New(globalCfg) + require.NoError(t, err) + defer inst.Close() + + email := "e2eapitest_" + uuid.Must(uuid.NewV4()).String() + "@localhost" + req := &api.InviteParams{ + Email: email, + } + res := new(models.User) + + body := new(bytes.Buffer) + err = json.NewEncoder(body).Encode(req) + require.NoError(t, err) + + httpReq, err := http.NewRequestWithContext( + ctx, "POST", "/invite", body) + require.NoError(t, err) + + httpRes, err := inst.DoAdmin(httpReq) + require.NoError(t, err) + + err = json.NewDecoder(httpRes.Body).Decode(res) + require.NoError(t, err) + require.Equal(t, email, res.Email.String()) + }) + + t.Run("DoAdminFailure", func(t *testing.T) { + globalCfg := e2e.Must(e2e.Config()) + inst, err := New(globalCfg) + require.NoError(t, err) + defer inst.Close() + + httpReq, err := http.NewRequestWithContext( + ctx, "POST", "/invite", nil) + require.NoError(t, err) + + httpRes, err := inst.doAdmin(httpReq, new(int)) + require.Error(t, err) + require.Nil(t, httpRes) + + }) + t.Run("Failure", func(t *testing.T) { globalCfg := e2e.Must(e2e.Config()) globalCfg.DB.Driver = "" @@ -49,6 +95,17 @@ func TestInstance(t *testing.T) { require.Error(t, err) require.Nil(t, inst) }) + + t.Run("InitURLFailure", func(t *testing.T) { + globalCfg := e2e.Must(e2e.Config()) + inst, err := New(globalCfg) + require.NoError(t, err) + defer inst.Close() + + inst.APIServer.URL = "\x01" + err = inst.initURL() + require.Error(t, err) + }) }) } @@ -172,7 +229,6 @@ func TestDo(t *testing.T) { err := Do(ctx, http.MethodPost, ts.URL, nil, nil) require.Error(t, err) require.Equal(t, sentinel, err) - }) } }) diff --git a/internal/e2e/e2ehooks/e2ehooks.go b/internal/e2e/e2ehooks/e2ehooks.go index 0bd0bd522..637f4f090 100644 --- a/internal/e2e/e2ehooks/e2ehooks.go +++ b/internal/e2e/e2ehooks/e2ehooks.go @@ -50,9 +50,16 @@ func New(globalCfg *conf.GlobalConfiguration) (*Instance, error) { } func HandleSuccess() http.Handler { + return HandleJSON(map[string]any{}) +} + +func HandleJSON(m map[string]any) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Add("content-type", "application/json") - _, _ = io.WriteString(w, "{}") + + if err := json.NewEncoder(w).Encode(&m); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } }) } @@ -68,7 +75,41 @@ func NewHook(name v0hooks.Name) *Hook { o := &Hook{ name: name, } - o.SetHandler(HandleSuccess()) + + //exhaustive:ignore + switch name { + case v0hooks.CustomizeAccessToken: + // This hooks returns the exact claims given. + hr := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("content-type", "application/json") + w.WriteHeader(http.StatusOK) + + var v any + if err := json.NewDecoder(r.Body).Decode(&v); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + if err := json.NewEncoder(w).Encode(&v); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + }) + o.SetHandler(hr) + + case v0hooks.MFAVerification: + hr := HandleJSON(map[string]any{ + "decision": "continue", + }) + o.SetHandler(hr) + + case v0hooks.PasswordVerification: + hr := HandleJSON(map[string]any{ + "decision": "continue", + }) + o.SetHandler(hr) + + default: + o.SetHandler(HandleSuccess()) + } + return o } diff --git a/internal/hooks/hookshttp/hookshttp.go b/internal/hooks/hookshttp/hookshttp.go index e8c39397b..e11897e85 100644 --- a/internal/hooks/hookshttp/hookshttp.go +++ b/internal/hooks/hookshttp/hookshttp.go @@ -92,6 +92,10 @@ func (o *Dispatcher) Dispatch( } if data != nil { if err := json.Unmarshal(data, res); err != nil { + e := new(apierrors.HTTPError) + if errors.As(err, &e) { + return e + } return apierrors.NewInternalServerError( "Error unmarshaling JSON output.").WithInternalError(err) } diff --git a/internal/hooks/hookspgfunc/hookspgfunc.go b/internal/hooks/hookspgfunc/hookspgfunc.go index 5d26a006d..ce8ee2a32 100644 --- a/internal/hooks/hookspgfunc/hookspgfunc.go +++ b/internal/hooks/hookspgfunc/hookspgfunc.go @@ -3,6 +3,7 @@ package hookspgfunc import ( "context" "encoding/json" + "errors" "fmt" "time" @@ -58,6 +59,10 @@ func (o *Dispatcher) Dispatch( } if data != nil { if err := json.Unmarshal(data, res); err != nil { + e := new(apierrors.HTTPError) + if errors.As(err, &e) { + return e + } return apierrors.NewInternalServerError( "Error unmarshaling JSON output.").WithInternalError(err) } diff --git a/internal/hooks/v0hooks/v0hooks.go b/internal/hooks/v0hooks/v0hooks.go index 33a054ec1..b0acb4821 100644 --- a/internal/hooks/v0hooks/v0hooks.go +++ b/internal/hooks/v0hooks/v0hooks.go @@ -1,11 +1,13 @@ package v0hooks import ( + "encoding/json" "net/http" "time" "github.com/gofrs/uuid" "github.com/golang-jwt/jwt/v5" + "github.com/supabase/auth/internal/api/apierrors" "github.com/supabase/auth/internal/mailer" "github.com/supabase/auth/internal/models" "github.com/supabase/auth/internal/utilities" @@ -139,7 +141,37 @@ type CustomAccessTokenInput struct { } type CustomAccessTokenOutput struct { - Claims map[string]interface{} `json:"claims"` + Claims map[string]any `json:"claims"` +} + +func (o *CustomAccessTokenOutput) UnmarshalJSON(b []byte) error { + var m map[string]any + if err := json.Unmarshal(b, &m); err != nil { + return err + } + + // First check if the claims field is missing + if _, ok := m["claims"]; !ok { + httpError := &apierrors.HTTPError{ + HTTPStatus: http.StatusInternalServerError, + Message: "output claims field is missing", + } + return httpError + } + + // This check allows us to skip an additional unmarshal for valid inputs + if v, ok := m["claims"].(map[string]any); ok { + o.Claims = v + return nil + } + + // The Claims field is not a map[string]any so we unmarshal again just + // to get the correct error type. + type raw CustomAccessTokenOutput + if err := json.Unmarshal(b, (*raw)(o)); err != nil { + return err + } + return nil } type SendSMSInput struct {