From 4fdd5ba4def93f5926c628dcd04d7496e5f6811d Mon Sep 17 00:00:00 2001 From: Stojan Dimitrovski Date: Wed, 17 Jan 2024 08:42:15 +0100 Subject: [PATCH] feat: add error codes --- internal/api/admin.go | 36 +++++----- internal/api/audit.go | 4 +- internal/api/auth.go | 20 +++--- internal/api/errors.go | 16 ++--- internal/api/external.go | 52 +++++++------- internal/api/external_oauth.go | 11 +-- internal/api/hooks.go | 2 +- internal/api/identity.go | 16 ++--- internal/api/invite.go | 6 +- internal/api/logout.go | 2 +- internal/api/magic_link.go | 17 +++-- internal/api/mail.go | 123 +++++++++++++++++++++++---------- internal/api/mfa.go | 24 +++---- internal/api/middleware.go | 11 ++- internal/api/otp.go | 28 ++++---- internal/api/phone.go | 2 +- internal/api/pkce.go | 8 +-- internal/api/reauthenticate.go | 23 +++--- internal/api/recover.go | 8 +-- internal/api/resend.go | 25 ++++--- internal/models/flow_state.go | 4 +- 21 files changed, 254 insertions(+), 184 deletions(-) diff --git a/internal/api/admin.go b/internal/api/admin.go index d6fd17dac7..c113a0c9ad 100644 --- a/internal/api/admin.go +++ b/internal/api/admin.go @@ -50,7 +50,7 @@ func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context, userID, err := uuid.FromString(chi.URLParam(r, "user_id")) if err != nil { - return nil, badRequestError("user_id must be an UUID") + return nil, badRequestError("validation_failed", "user_id must be an UUID") } observability.LogEntrySetField(r, "user_id", userID) @@ -58,7 +58,7 @@ func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context, u, err := models.FindUserByID(db, userID) if err != nil { if models.IsNotFoundError(err) { - return nil, notFoundError("User not found") + return nil, notFoundError("not_found", "User not found") } return nil, internalServerError("Database error loading user").WithInternalError(err) } @@ -69,7 +69,7 @@ func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context, func (a *API) loadFactor(w http.ResponseWriter, r *http.Request) (context.Context, error) { factorID, err := uuid.FromString(chi.URLParam(r, "factor_id")) if err != nil { - return nil, badRequestError("factor_id must be an UUID") + return nil, badRequestError("validation_failed", "factor_id must be an UUID") } observability.LogEntrySetField(r, "factor_id", factorID) @@ -77,7 +77,7 @@ func (a *API) loadFactor(w http.ResponseWriter, r *http.Request) (context.Contex f, err := models.FindFactorByFactorID(a.db, factorID) if err != nil { if models.IsNotFoundError(err) { - return nil, notFoundError("Factor not found") + return nil, notFoundError("not_found", "Factor not found") } return nil, internalServerError("Database error loading factor").WithInternalError(err) } @@ -89,11 +89,11 @@ func (a *API) getAdminParams(r *http.Request) (*AdminUserParams, error) { body, err := getBodyBytes(r) if err != nil { - return nil, badRequestError("Could not read body").WithInternalError(err) + return nil, internalServerError("Could not read body").WithInternalError(err) } if err := json.Unmarshal(body, ¶ms); err != nil { - return nil, badRequestError("Could not decode admin user params: %v", err) + return nil, badRequestError("bad_json", "Could not decode admin user params").WithInternalError(err) } return ¶ms, nil @@ -107,12 +107,12 @@ func (a *API) adminUsers(w http.ResponseWriter, r *http.Request) error { pageParams, err := paginate(r) if err != nil { - return badRequestError("Bad Pagination Parameters: %v", err) + return badRequestError("validation_failed", "Bad Pagination Parameters: %v", err).WithInternalError(err) } sortParams, err := sort(r, map[string]bool{models.CreatedAt: true}, []models.SortField{{Name: models.CreatedAt, Dir: models.Descending}}) if err != nil { - return badRequestError("Bad Sort Parameters: %v", err) + return badRequestError("validation_failed", "Bad Sort Parameters: %v", err) } filter := r.URL.Query().Get("filter") @@ -166,7 +166,7 @@ func (a *API) adminUserUpdate(w http.ResponseWriter, r *http.Request) error { if params.BanDuration != "none" { duration, err = time.ParseDuration(params.BanDuration) if err != nil { - return badRequestError("invalid format for ban duration: %v", err) + return badRequestError("validation_failed", "invalid format for ban duration: %v", err) } } if terr := user.Ban(a.db, duration); terr != nil { @@ -314,7 +314,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { } if params.Email == "" && params.Phone == "" { - return unprocessableEntityError("Cannot create a user without either an email or phone") + return badRequestError("validation_failed", "Cannot create a user without either an email or phone") } var providers []string @@ -326,7 +326,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { if user, err := models.IsDuplicatedEmail(db, params.Email, aud, nil); err != nil { return internalServerError("Database error checking email").WithInternalError(err) } else if user != nil { - return unprocessableEntityError(DuplicateEmailMsg) + return unprocessableEntityError("email_exists", DuplicateEmailMsg) } providers = append(providers, "email") } @@ -339,7 +339,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { if exists, err := models.IsDuplicatedPhone(db, params.Phone, aud); err != nil { return internalServerError("Database error checking phone").WithInternalError(err) } else if exists { - return unprocessableEntityError("Phone number already registered by another user") + return unprocessableEntityError("phone_exists", "Phone number already registered by another user") } providers = append(providers, "phone") } @@ -435,7 +435,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error { if params.BanDuration != "none" { duration, err = time.ParseDuration(params.BanDuration) if err != nil { - return badRequestError("invalid format for ban duration: %v", err) + return badRequestError("validation_failed", "invalid format for ban duration: %v", err) } } if terr := user.Ban(a.db, duration); terr != nil { @@ -466,11 +466,11 @@ func (a *API) adminUserDelete(w http.ResponseWriter, r *http.Request) error { params := &adminUserDeleteParams{} body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if len(body) > 0 { if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read params: %v", err) + return badRequestError("bad_json", "Could not read params: %v", err) } } else { params.ShouldSoftDelete = false @@ -567,11 +567,11 @@ func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) erro params := &adminUserUpdateFactorParams{} body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read factor update params: %v", err) + return badRequestError("bad_json", "Could not read factor update params: %v", err).WithInternalError(err) } err = a.db.Transaction(func(tx *storage.Connection) error { @@ -582,7 +582,7 @@ func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) erro } if params.FactorType != "" { if params.FactorType != models.TOTP { - return badRequestError("Factor Type not valid") + return badRequestError("validation_failed", "Factor Type not valid") } if terr := factor.UpdateFactorType(tx, params.FactorType); terr != nil { return terr diff --git a/internal/api/audit.go b/internal/api/audit.go index 2cb99c6e75..af305d71f4 100644 --- a/internal/api/audit.go +++ b/internal/api/audit.go @@ -20,7 +20,7 @@ func (a *API) adminAuditLog(w http.ResponseWriter, r *http.Request) error { // aud := a.requestAud(ctx, r) pageParams, err := paginate(r) if err != nil { - return badRequestError("Bad Pagination Parameters: %v", err) + return badRequestError("validation_failed", "Bad Pagination Parameters: %v", err) } var col []string @@ -31,7 +31,7 @@ func (a *API) adminAuditLog(w http.ResponseWriter, r *http.Request) error { qparts := strings.SplitN(q, ":", 2) col, exists = filterColumnMap[qparts[0]] if !exists || len(qparts) < 2 { - return badRequestError("Invalid query scope: %s", q) + return badRequestError("validation_failed", "Invalid query scope: %s", q) } qval = qparts[1] } diff --git a/internal/api/auth.go b/internal/api/auth.go index c1f43d5114..04fbf3be1b 100644 --- a/internal/api/auth.go +++ b/internal/api/auth.go @@ -40,7 +40,7 @@ func (a *API) requireAdmin(ctx context.Context, w http.ResponseWriter, r *http.R claims := getClaims(ctx) if claims == nil { fmt.Printf("[%s] %s %s %d %s\n", time.Now().Format("2006-01-02 15:04:05"), r.Method, r.RequestURI, http.StatusForbidden, "Invalid token") - return nil, unauthorizedError("Invalid token") + return nil, unauthorizedError("bad_jwt", "Invalid token") } adminRoles := a.config.JWT.AdminRoles @@ -51,14 +51,14 @@ func (a *API) requireAdmin(ctx context.Context, w http.ResponseWriter, r *http.R } fmt.Printf("[%s] %s %s %d %s\n", time.Now().Format("2006-01-02 15:04:05"), r.Method, r.RequestURI, http.StatusForbidden, "this token needs role 'supabase_admin' or 'service_role'") - return nil, unauthorizedError("User not allowed") + return nil, unauthorizedError("not_admin", "User not allowed") } func (a *API) extractBearerToken(r *http.Request) (string, error) { authHeader := r.Header.Get("Authorization") matches := bearerRegexp.FindStringSubmatch(authHeader) if len(matches) != 2 { - return "", unauthorizedError("This endpoint requires a Bearer token") + return "", unauthorizedError("no_authorization", "This endpoint requires a Bearer token") } return matches[1], nil @@ -73,7 +73,7 @@ func (a *API) parseJWTClaims(bearer string, r *http.Request) (context.Context, e return []byte(config.JWT.Secret), nil }) if err != nil { - return nil, unauthorizedError("invalid JWT: unable to parse or verify signature, %v", err) + return nil, unauthorizedError("bad_jwt", "invalid JWT: unable to parse or verify signature, %v", err).WithInternalError(err) } return withToken(ctx, token), nil @@ -84,23 +84,23 @@ func (a *API) maybeLoadUserOrSession(ctx context.Context) (context.Context, erro claims := getClaims(ctx) if claims == nil { - return ctx, unauthorizedError("invalid token: missing claims") + return ctx, unauthorizedError("bad_jwt", "invalid token: missing claims") } if claims.Subject == "" { - return nil, unauthorizedError("invalid claim: missing sub claim") + return nil, unauthorizedError("bad_jwt", "invalid claim: missing sub claim") } var user *models.User if claims.Subject != "" { userId, err := uuid.FromString(claims.Subject) if err != nil { - return ctx, badRequestError("invalid claim: sub claim must be a UUID").WithInternalError(err) + return ctx, badRequestError("bad_jwt", "invalid claim: sub claim must be a UUID").WithInternalError(err) } user, err = models.FindUserByID(db, userId) if err != nil { if models.IsNotFoundError(err) { - return ctx, notFoundError(err.Error()) + return ctx, unauthorizedError("user_not_found", "User from sub claim in JWT does not exist") } return ctx, err } @@ -111,11 +111,11 @@ func (a *API) maybeLoadUserOrSession(ctx context.Context) (context.Context, erro if claims.SessionId != "" && claims.SessionId != uuid.Nil.String() { sessionId, err := uuid.FromString(claims.SessionId) if err != nil { - return ctx, err + return ctx, badRequestError("bad_jwt", "invalid claim: session_id claim must be a UUID").WithInternalError(err) } session, err = models.FindSessionByID(db, sessionId, false) if err != nil && !models.IsNotFoundError(err) { - return ctx, err + return ctx, unauthorizedError("session_not_found", "Session from session_id claim in JWT does not exist") } ctx = withSession(ctx, session) } diff --git a/internal/api/errors.go b/internal/api/errors.go index 56f404e3c9..bbce0ebf2c 100644 --- a/internal/api/errors.go +++ b/internal/api/errors.go @@ -84,7 +84,7 @@ func oauthError(err string, description string) *OAuthError { return &OAuthError{Err: err, Description: description} } -func badRequestError(fmtString string, args ...interface{}) *HTTPError { +func badRequestError(reason, fmtString string, args ...interface{}) *HTTPError { return httpError(http.StatusBadRequest, fmtString, args...) } @@ -92,31 +92,31 @@ func internalServerError(fmtString string, args ...interface{}) *HTTPError { return httpError(http.StatusInternalServerError, fmtString, args...) } -func notFoundError(fmtString string, args ...interface{}) *HTTPError { +func notFoundError(reason, fmtString string, args ...interface{}) *HTTPError { return httpError(http.StatusNotFound, fmtString, args...) } -func expiredTokenError(fmtString string, args ...interface{}) *HTTPError { +func expiredTokenError(reason, fmtString string, args ...interface{}) *HTTPError { return httpError(http.StatusUnauthorized, fmtString, args...) } -func unauthorizedError(fmtString string, args ...interface{}) *HTTPError { +func unauthorizedError(reason, fmtString string, args ...interface{}) *HTTPError { return httpError(http.StatusUnauthorized, fmtString, args...) } -func forbiddenError(fmtString string, args ...interface{}) *HTTPError { +func forbiddenError(reason, fmtString string, args ...interface{}) *HTTPError { return httpError(http.StatusForbidden, fmtString, args...) } -func unprocessableEntityError(fmtString string, args ...interface{}) *HTTPError { +func unprocessableEntityError(reason, fmtString string, args ...interface{}) *HTTPError { return httpError(http.StatusUnprocessableEntity, fmtString, args...) } -func tooManyRequestsError(fmtString string, args ...interface{}) *HTTPError { +func tooManyRequestsError(reason, fmtString string, args ...interface{}) *HTTPError { return httpError(http.StatusTooManyRequests, fmtString, args...) } -func conflictError(fmtString string, args ...interface{}) *HTTPError { +func conflictError(reason, fmtString string, args ...interface{}) *HTTPError { return httpError(http.StatusConflict, fmtString, args...) } diff --git a/internal/api/external.go b/internal/api/external.go index 9ccfdb196f..bd8decea67 100644 --- a/internal/api/external.go +++ b/internal/api/external.go @@ -56,7 +56,7 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ p, err := a.Provider(ctx, providerType, scopes) if err != nil { - return "", badRequestError("Unsupported provider: %+v", err).WithInternalError(err) + return "", badRequestError("validation_failed", "Unsupported provider: %+v", err).WithInternalError(err) } inviteToken := query.Get("invite_token") @@ -64,7 +64,7 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ _, userErr := models.FindUserByConfirmationToken(db, inviteToken) if userErr != nil { if models.IsNotFoundError(userErr) { - return "", notFoundError(userErr.Error()) + return "", notFoundError("user_not_found", "User identified by token not found") } return "", internalServerError("Database error finding user").WithInternalError(userErr) } @@ -82,14 +82,12 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ if flowType == models.PKCEFlow { codeChallengeMethodType, err := models.ParseCodeChallengeMethod(codeChallengeMethod) if err != nil { - return "", err - } - flowState, err := models.NewFlowState(providerType, codeChallenge, codeChallengeMethodType, models.OAuth) - if err != nil { - return "", err + return "", badRequestError("validation_failed", "Code challenge not valid").WithInternalError(err) } + flowState := models.NewFlowState(providerType, codeChallenge, codeChallengeMethodType, models.OAuth) + if err := a.db.Create(flowState); err != nil { - return "", err + return "", internalServerError("Failed to create flow state").WithInternalError(err) } flowStateID = flowState.ID.String() } @@ -137,7 +135,7 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ switch externalProvider := p.(type) { case *provider.TwitterProvider: if err := storage.StoreInSession(providerType, externalProvider.Marshal(), r, w); err != nil { - return "", internalServerError("Error storing request token in session").WithInternalError(err) + return "", internalServerError("Error storing request token in cookies").WithInternalError(err) } } @@ -210,9 +208,12 @@ func (a *API) internalExternalProviderCallback(w http.ResponseWriter, r *http.Re // if there's a non-empty FlowStateID we perform PKCE Flow if flowStateID := getFlowStateID(ctx); flowStateID != "" { flowState, err = models.FindFlowStateByID(a.db, flowStateID) - if err != nil { - return err + if models.IsNotFoundError(err) { + return notFoundError("flow_state_not_found", "Flow state not found").WithInternalError(err) + } else if err != nil { + return internalServerError("Failed to find flow state").WithInternalError(err) } + } var user *models.User @@ -311,7 +312,7 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http. case models.CreateAccount: if config.DisableSignup { - return nil, forbiddenError("Signups not allowed for this instance") + return nil, forbiddenError("signup_disabled", "Signups not allowed for this instance") } params := &SignupParams{ @@ -358,14 +359,14 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http. } case models.MultipleAccounts: - return nil, internalServerError(fmt.Sprintf("Multiple accounts with the same email address in the same linking domain detected: %v", decision.LinkingDomain)) + return nil, internalServerError("Multiple accounts with the same email address in the same linking domain detected: %v", decision.LinkingDomain) default: - return nil, internalServerError(fmt.Sprintf("Unknown automatic linking decision: %v", decision.Decision)) + return nil, internalServerError("Unknown automatic linking decision: %v", decision.Decision) } if user.IsBanned() { - return nil, unauthorizedError("User is unauthorized") + return nil, unauthorizedError("user_banned", "User is banned") } if !user.IsConfirmed() { @@ -398,7 +399,7 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http. externalURL := getExternalHost(ctx) if terr = sendConfirmation(tx, user, mailer, config.SMTP.MaxFrequency, referrer, externalURL, config.Mailer.OtpLength, models.ImplicitFlow); terr != nil { if errors.Is(terr, MaxFrequencyLimitError) { - return nil, tooManyRequestsError("For security purposes, you can only request this once every minute") + return nil, tooManyRequestsError("over_email_send_rate", "For security purposes, you can only request this once every minute") } return nil, internalServerError("Error sending confirmation mail").WithInternalError(terr) } @@ -406,9 +407,9 @@ func (a *API) createAccountFromExternalIdentity(tx *storage.Connection, r *http. } if !config.Mailer.AllowUnverifiedEmailSignIns { if emailConfirmationSent { - return nil, storage.NewCommitWithError(unauthorizedError(fmt.Sprintf("Unverified email with %v. A confirmation email has been sent to your %v email", providerType, providerType))) + return nil, storage.NewCommitWithError(unauthorizedError("provider_email_needs_verification", fmt.Sprintf("Unverified email with %v. A confirmation email has been sent to your %v email", providerType, providerType))) } - return nil, storage.NewCommitWithError(unauthorizedError(fmt.Sprintf("Unverified email with %v. Verify the email with %v in order to sign in", providerType, providerType))) + return nil, storage.NewCommitWithError(unauthorizedError("provider_email_needs_verification", fmt.Sprintf("Unverified email with %v. Verify the email with %v in order to sign in", providerType, providerType))) } } } else { @@ -430,7 +431,7 @@ func (a *API) processInvite(r *http.Request, ctx context.Context, tx *storage.Co user, err := models.FindUserByConfirmationToken(tx, inviteToken) if err != nil { if models.IsNotFoundError(err) { - return nil, notFoundError(err.Error()) + return nil, notFoundError("invite_not_found", "Invite not found") } return nil, internalServerError("Database error finding user").WithInternalError(err) } @@ -446,7 +447,7 @@ func (a *API) processInvite(r *http.Request, ctx context.Context, tx *storage.Co } if emailData == nil { - return nil, badRequestError("Invited email does not match emails from external provider").WithInternalMessage("invited=%s external=%s", user.Email, strings.Join(emails, ", ")) + return nil, badRequestError("validation_failed", "Invited email does not match emails from external provider").WithInternalMessage("invited=%s external=%s", user.Email, strings.Join(emails, ", ")) } var identityData map[string]interface{} @@ -502,8 +503,11 @@ func (a *API) loadExternalState(ctx context.Context, state string) (context.Cont _, err := p.ParseWithClaims(state, &claims, func(token *jwt.Token) (interface{}, error) { return []byte(config.JWT.Secret), nil }) - if err != nil || claims.Provider == "" { - return nil, badRequestError("OAuth state is invalid: %v", err) + if err != nil { + return nil, badRequestError("bad_oauth_state", "OAuth callback with invalid state").WithInternalError(err) + } + if claims.Provider == "" { + return nil, badRequestError("bad_oauth_state", "OAuth callback with invalid state (missing provider)") } if claims.InviteToken != "" { ctx = withInviteToken(ctx, claims.InviteToken) @@ -517,12 +521,12 @@ func (a *API) loadExternalState(ctx context.Context, state string) (context.Cont if claims.LinkingTargetID != "" { linkingTargetUserID, err := uuid.FromString(claims.LinkingTargetID) if err != nil { - return nil, badRequestError("invalid target user id") + return nil, badRequestError("bad_oauth_state", "OAuth callback with invalid state (linking_target_id must be UUID)") } u, err := models.FindUserByID(a.db, linkingTargetUserID) if err != nil { if models.IsNotFoundError(err) { - return nil, notFoundError("Linking target user not found") + return nil, notFoundError("user_not_found", "Linking target user not found") } return nil, internalServerError("Database error loading user").WithInternalError(err) } diff --git a/internal/api/external_oauth.go b/internal/api/external_oauth.go index 2fd93d22ee..910745b931 100644 --- a/internal/api/external_oauth.go +++ b/internal/api/external_oauth.go @@ -2,6 +2,7 @@ package api import ( "context" + "fmt" "net/http" "net/url" @@ -31,7 +32,7 @@ func (a *API) loadFlowState(w http.ResponseWriter, r *http.Request) (context.Con } if state == "" { - return nil, badRequestError("OAuth state parameter missing") + return nil, badRequestError("bad_oauth_callback", "OAuth state parameter missing") } ctx := r.Context() @@ -61,12 +62,12 @@ func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType s oauthCode := rq.Get("code") if oauthCode == "" { - return nil, badRequestError("Authorization code missing") + return nil, badRequestError("bad_oauth_callback", "OAuth callback with missing authorization code missing") } oAuthProvider, err := a.OAuthProvider(ctx, providerType) if err != nil { - return nil, badRequestError("Unsupported provider: %+v", err).WithInternalError(err) + return nil, badRequestError("oauth_provider_not_supported", "Unsupported provider: %+v", err).WithInternalError(err) } log := observability.GetLogEntry(r) @@ -108,7 +109,7 @@ func (a *API) oAuthCallback(ctx context.Context, r *http.Request, providerType s func (a *API) oAuth1Callback(ctx context.Context, r *http.Request, providerType string) (*OAuthProviderData, error) { oAuthProvider, err := a.OAuthProvider(ctx, providerType) if err != nil { - return nil, badRequestError("Unsupported provider: %+v", err).WithInternalError(err) + return nil, badRequestError("oauth_provider_not_supported", "Unsupported provider: %+v", err).WithInternalError(err) } value, err := storage.GetFromSession(providerType, r) if err != nil { @@ -156,6 +157,6 @@ func (a *API) OAuthProvider(ctx context.Context, name string) (provider.OAuthPro case provider.OAuthProvider: return p, nil default: - return nil, badRequestError("Provider can not be used for OAuth") + return nil, fmt.Errorf("Provider %v cannot be used for OAuth", name) } } diff --git a/internal/api/hooks.go b/internal/api/hooks.go index 2fd31983e2..447dadc9b2 100644 --- a/internal/api/hooks.go +++ b/internal/api/hooks.go @@ -135,7 +135,7 @@ func (w *Webhook) trigger() (io.ReadCloser, error) { } hooklog.Infof("Failed to process webhook for %s after %d attempts", w.URL, w.Retries) - return nil, unprocessableEntityError("Failed to handle signup webhook") + return nil, internalServerError("Failed to handle signup webhook") } func (w *Webhook) generateSignature() (string, error) { diff --git a/internal/api/identity.go b/internal/api/identity.go index 14f2c167d9..4936e87f7f 100644 --- a/internal/api/identity.go +++ b/internal/api/identity.go @@ -17,22 +17,22 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error { claims := getClaims(ctx) if claims == nil { - return badRequestError("Could not read claims") + return internalServerError("Could not read claims") } aud := a.requestAud(ctx, r) if aud != claims.Audience { - return badRequestError("Token audience doesn't match request audience") + return unauthorizedError("unexpected_audience", "Token audience doesn't match request audience") } identityID, err := uuid.FromString(chi.URLParam(r, "identity_id")) if err != nil { - return badRequestError("identity_id must be an UUID") + return badRequestError("validation_failed", "identity_id must be an UUID") } user := getUser(ctx) if len(user.Identities) <= 1 { - return badRequestError("User must have at least 1 identity after unlinking") + return unprocessableEntityError("last_identity_not_deletable", "User must have at least 1 identity after unlinking") } var identityToBeDeleted *models.Identity for i := range user.Identities { @@ -43,7 +43,7 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error { } } if identityToBeDeleted == nil { - return badRequestError("Identity doesn't exist") + return notFoundError("identity_not_found", "Identity doesn't exist") } err = a.db.Transaction(func(tx *storage.Connection) error { @@ -59,7 +59,7 @@ func (a *API) DeleteIdentity(w http.ResponseWriter, r *http.Request) error { } if terr := user.UpdateUserEmail(tx); terr != nil { if models.IsUniqueConstraintViolatedError(terr) { - return forbiddenError("Unable to unlink identity due to email conflict").WithInternalError(terr) + return unprocessableEntityError("email_conflict_identity_not_deletable", "Unable to unlink identity due to email conflict").WithInternalError(terr) } return internalServerError("Database error updating user email").WithInternalError(terr) } @@ -102,9 +102,9 @@ func (a *API) linkIdentityToUser(ctx context.Context, tx *storage.Connection, us } if identity != nil { if identity.UserID == targetUser.ID { - return nil, badRequestError("Identity is already linked") + return nil, unprocessableEntityError("identity_already_exists", "Identity is already linked") } - return nil, badRequestError("Identity is already linked to another user") + return nil, unprocessableEntityError("identity_already_exists", "Identity is already linked to another user") } if _, terr := a.createNewIdentity(tx, targetUser, providerType, structs.Map(userData.Metadata)); terr != nil { return nil, terr diff --git a/internal/api/invite.go b/internal/api/invite.go index 2a0aeb51d9..212e7cefb5 100644 --- a/internal/api/invite.go +++ b/internal/api/invite.go @@ -27,11 +27,11 @@ func (a *API) Invite(w http.ResponseWriter, r *http.Request) error { body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read Invite params: %v", err) + return badRequestError("bad_json", "Could not read Invite params: %v", err).WithInternalError(err) } params.Email, err = validateEmail(params.Email) @@ -48,7 +48,7 @@ func (a *API) Invite(w http.ResponseWriter, r *http.Request) error { err = db.Transaction(func(tx *storage.Connection) error { if user != nil { if user.IsConfirmed() { - return unprocessableEntityError(DuplicateEmailMsg) + return unprocessableEntityError("email_exists", DuplicateEmailMsg) } } else { signupParams := SignupParams{ diff --git a/internal/api/logout.go b/internal/api/logout.go index ad95b22a49..2aae00a28b 100644 --- a/internal/api/logout.go +++ b/internal/api/logout.go @@ -36,7 +36,7 @@ func (a *API) Logout(w http.ResponseWriter, r *http.Request) error { scope = LogoutOthers default: - return badRequestError(fmt.Sprintf("Unsupported logout scope %q", r.URL.Query().Get("scope"))) + return badRequestError("validation_failed", fmt.Sprintf("Unsupported logout scope %q", r.URL.Query().Get("scope"))) } } diff --git a/internal/api/magic_link.go b/internal/api/magic_link.go index e1b12caafd..b807d8ab6d 100644 --- a/internal/api/magic_link.go +++ b/internal/api/magic_link.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "errors" + "fmt" "io" "net/http" "strings" @@ -24,7 +25,7 @@ type MagicLinkParams struct { func (p *MagicLinkParams) Validate() error { if p.Email == "" { - return unprocessableEntityError("Password recovery requires an email") + return unprocessableEntityError("validation_failed", "Password recovery requires an email") } var err error p.Email, err = validateEmail(p.Email) @@ -44,14 +45,14 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { config := a.config if !config.External.Email.Enabled { - return badRequestError("Email logins are disabled") + return unprocessableEntityError("email_provider_disabled", "Email logins are disabled") } params := &MagicLinkParams{} jsonDecoder := json.NewDecoder(r.Body) err := jsonDecoder.Decode(params) if err != nil { - return badRequestError("Could not read verification params: %v", err) + return badRequestError("bad_json", "Could not read verification params: %v", err).WithInternalError(err) } if err := params.Validate(); err != nil { @@ -82,7 +83,7 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { // Sign them up with temporary password. password, err := password.Generate(64, 10, 1, false, true) if err != nil { - internalServerError("error creating user").WithInternalError(err) + return internalServerError("error creating user").WithInternalError(err) } signUpParams := &SignupParams{ @@ -94,7 +95,8 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { } newBodyContent, err := json.Marshal(signUpParams) if err != nil { - return badRequestError("Could not parse metadata: %v", err) + // SignupParams must always be marshallable + panic(fmt.Errorf("Failed to marshal SignupParams: %w", err)) } r.Body = io.NopCloser(strings.NewReader(string(newBodyContent))) r.ContentLength = int64(len(string(newBodyContent))) @@ -113,7 +115,8 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { } metadata, err := json.Marshal(newBodyContent) if err != nil { - return badRequestError("Could not parse metadata: %v", err) + // SignupParams must always be marshallable + panic(fmt.Errorf("Failed to marshal SignupParams: %w", err)) } r.Body = io.NopCloser(bytes.NewReader(metadata)) return a.MagicLink(w, r) @@ -148,7 +151,7 @@ func (a *API) MagicLink(w http.ResponseWriter, r *http.Request) error { }) if err != nil { if errors.Is(err, MaxFrequencyLimitError) { - return tooManyRequestsError("For security purposes, you can only request this once every 60 seconds") + return tooManyRequestsError("over_email_send_rate", "For security purposes, you can only request this once every 60 seconds") } return internalServerError("Error sending magic link").WithInternalError(err) } diff --git a/internal/api/mail.go b/internal/api/mail.go index 6d4bd0817c..99fce69f81 100644 --- a/internal/api/mail.go +++ b/internal/api/mail.go @@ -52,11 +52,11 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not parse JSON: %v", err) + return badRequestError("bad_json", "Could not parse JSON: %v", err).WithInternalError(err) } params.Email, err = validateEmail(params.Email) @@ -72,14 +72,17 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { user, err := models.FindUserByEmailAndAudience(db, params.Email, aud) if err != nil { if models.IsNotFoundError(err) { - if params.Type == magicLinkVerification { + switch params.Type { + case magicLinkVerification: params.Type = signupVerification params.Password, err = password.Generate(64, 10, 1, false, true) if err != nil { - return internalServerError("error creating user").WithInternalError(err) + // password generation must always succeed + panic(err) } - } else if params.Type == recoveryVerification || params.Type == "email_change_current" || params.Type == "email_change_new" { - return notFoundError(err.Error()) + + default: + return notFoundError("user_not_found", "User with this email not found") } } else { return internalServerError("Database error finding user").WithInternalError(err) @@ -90,7 +93,8 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { now := time.Now() otp, err := crypto.GenerateOtp(config.Mailer.OtpLength) if err != nil { - return err + // OTP generation must always succeed + panic(err) } hashedToken := crypto.GenerateTokenHash(params.Email, otp) @@ -124,11 +128,14 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { } user.RecoveryToken = hashedToken user.RecoverySentAt = &now - terr = errors.Wrap(tx.UpdateOnly(user, "recovery_token", "recovery_sent_at"), "Database error updating user for recovery") + terr = tx.UpdateOnly(user, "recovery_token", "recovery_sent_at") + if terr != nil { + terr = errors.Wrap(terr, "Database error updating user for recovery") + } case inviteVerification: if user != nil { if user.IsConfirmed() { - return unprocessableEntityError(DuplicateEmailMsg) + return unprocessableEntityError("email_exists", DuplicateEmailMsg) } } else { signupParams := &SignupParams{ @@ -168,11 +175,14 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { user.ConfirmationToken = hashedToken user.ConfirmationSentAt = &now user.InvitedAt = &now - terr = errors.Wrap(tx.UpdateOnly(user, "confirmation_token", "confirmation_sent_at", "invited_at"), "Database error updating user for invite") + terr = tx.UpdateOnly(user, "confirmation_token", "confirmation_sent_at", "invited_at") + if terr != nil { + terr = errors.Wrap(terr, "Database error updating user for invite") + } case signupVerification: if user != nil { if user.IsConfirmed() { - return unprocessableEntityError(DuplicateEmailMsg) + return unprocessableEntityError("email_exists", DuplicateEmailMsg) } if err := user.UpdateUserMetaData(tx, params.Data); err != nil { return internalServerError("Database error updating user").WithInternalError(err) @@ -197,19 +207,22 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { } user.ConfirmationToken = hashedToken user.ConfirmationSentAt = &now - terr = errors.Wrap(tx.UpdateOnly(user, "confirmation_token", "confirmation_sent_at"), "Database error updating user for confirmation") + terr = tx.UpdateOnly(user, "confirmation_token", "confirmation_sent_at") + if terr != nil { + terr = errors.Wrap(terr, "Database error updating user for confirmation") + } case "email_change_current", "email_change_new": if !config.Mailer.SecureEmailChangeEnabled && params.Type == "email_change_current" { - return unprocessableEntityError("Enable secure email change to generate link for current email") + return badRequestError("validation_failed", "Enable secure email change to generate link for current email") } params.NewEmail, terr = validateEmail(params.NewEmail) if terr != nil { - return unprocessableEntityError("The new email address provided is invalid") + return terr } if duplicateUser, terr := models.IsDuplicatedEmail(tx, params.NewEmail, user.Aud, user); terr != nil { return internalServerError("Database error checking email").WithInternalError(terr) } else if duplicateUser != nil { - return unprocessableEntityError(DuplicateEmailMsg) + return unprocessableEntityError("email_exists", DuplicateEmailMsg) } now := time.Now() user.EmailChangeSentAt = &now @@ -220,9 +233,12 @@ func (a *API) adminGenerateLink(w http.ResponseWriter, r *http.Request) error { } else if params.Type == "email_change_new" { user.EmailChangeTokenNew = crypto.GenerateTokenHash(params.NewEmail, otp) } - terr = errors.Wrap(tx.UpdateOnly(user, "email_change_token_current", "email_change_token_new", "email_change", "email_change_sent_at", "email_change_confirm_status"), "Database error updating user for email change") + terr = tx.UpdateOnly(user, "email_change_token_current", "email_change_token_new", "email_change", "email_change_sent_at", "email_change_confirm_status") + if terr != nil { + terr = errors.Wrap(terr, "Database error updating user for email change") + } default: - return badRequestError("Invalid email action link type requested: %v", params.Type) + return badRequestError("validation_failed", "Invalid email action link type requested: %v", params.Type) } if terr != nil { @@ -261,7 +277,8 @@ func sendConfirmation(tx *storage.Connection, u *models.User, mailer mailer.Mail oldToken := u.ConfirmationToken otp, err := crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeeed + panic(err) } token := crypto.GenerateTokenHash(u.GetEmail(), otp) u.ConfirmationToken = addFlowPrefixToToken(token, flowType) @@ -271,7 +288,12 @@ func sendConfirmation(tx *storage.Connection, u *models.User, mailer mailer.Mail return errors.Wrap(err, "Error sending confirmation email") } u.ConfirmationSentAt = &now - return errors.Wrap(tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at"), "Database error updating user for confirmation") + err = tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at") + if err != nil { + return errors.Wrap(err, "Database error updating user for confirmation") + } + + return nil } func sendInvite(tx *storage.Connection, u *models.User, mailer mailer.Mailer, referrerURL string, externalURL *url.URL, otpLength int) error { @@ -279,7 +301,8 @@ func sendInvite(tx *storage.Connection, u *models.User, mailer mailer.Mailer, re oldToken := u.ConfirmationToken otp, err := crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeed + panic(err) } u.ConfirmationToken = crypto.GenerateTokenHash(u.GetEmail(), otp) now := time.Now() @@ -289,7 +312,12 @@ func sendInvite(tx *storage.Connection, u *models.User, mailer mailer.Mailer, re } u.InvitedAt = &now u.ConfirmationSentAt = &now - return errors.Wrap(tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at", "invited_at"), "Database error updating user for invite") + err = tx.UpdateOnly(u, "confirmation_token", "confirmation_sent_at", "invited_at") + if err != nil { + return errors.Wrap(err, "Database error updating user for invite") + } + + return nil } func (a *API) sendPasswordRecovery(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string, externalURL *url.URL, otpLength int, flowType models.FlowType) error { @@ -301,7 +329,8 @@ func (a *API) sendPasswordRecovery(tx *storage.Connection, u *models.User, maile oldToken := u.RecoveryToken otp, err := crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeed + panic(err) } token := crypto.GenerateTokenHash(u.GetEmail(), otp) u.RecoveryToken = addFlowPrefixToToken(token, flowType) @@ -311,7 +340,12 @@ func (a *API) sendPasswordRecovery(tx *storage.Connection, u *models.User, maile return errors.Wrap(err, "Error sending recovery email") } u.RecoverySentAt = &now - return errors.Wrap(tx.UpdateOnly(u, "recovery_token", "recovery_sent_at"), "Database error updating user for recovery") + err = tx.UpdateOnly(u, "recovery_token", "recovery_sent_at") + if err != nil { + return errors.Wrap(err, "Database error updating user for recovery") + } + + return nil } func (a *API) sendReauthenticationOtp(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, otpLength int) error { @@ -323,19 +357,22 @@ func (a *API) sendReauthenticationOtp(tx *storage.Connection, u *models.User, ma oldToken := u.ReauthenticationToken otp, err := crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeed + panic(err) } u.ReauthenticationToken = crypto.GenerateTokenHash(u.GetEmail(), otp) - if err != nil { - return err - } now := time.Now() if err := mailer.ReauthenticateMail(u, otp); err != nil { u.ReauthenticationToken = oldToken return errors.Wrap(err, "Error sending reauthentication email") } u.ReauthenticationSentAt = &now - return errors.Wrap(tx.UpdateOnly(u, "reauthentication_token", "reauthentication_sent_at"), "Database error updating user for reauthentication") + err = tx.UpdateOnly(u, "reauthentication_token", "reauthentication_sent_at") + if err != nil { + return errors.Wrap(err, "Database error updating user for reauthentication") + } + + return nil } func (a *API) sendMagicLink(tx *storage.Connection, u *models.User, mailer mailer.Mailer, maxFrequency time.Duration, referrerURL string, externalURL *url.URL, otpLength int, flowType models.FlowType) error { @@ -348,7 +385,8 @@ func (a *API) sendMagicLink(tx *storage.Connection, u *models.User, mailer maile oldToken := u.RecoveryToken otp, err := crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeed + panic(err) } token := crypto.GenerateTokenHash(u.GetEmail(), otp) u.RecoveryToken = addFlowPrefixToToken(token, flowType) @@ -359,7 +397,12 @@ func (a *API) sendMagicLink(tx *storage.Connection, u *models.User, mailer maile return errors.Wrap(err, "Error sending magic link email") } u.RecoverySentAt = &now - return errors.Wrap(tx.UpdateOnly(u, "recovery_token", "recovery_sent_at"), "Database error updating user for recovery") + err = tx.UpdateOnly(u, "recovery_token", "recovery_sent_at") + if err != nil { + return errors.Wrap(err, "Database error updating user for recovery") + } + + return nil } // sendEmailChange sends out an email change token to the new email. @@ -370,7 +413,8 @@ func (a *API) sendEmailChange(tx *storage.Connection, config *conf.GlobalConfigu } otpNew, err := crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeed + panic(err) } u.EmailChange = email token := crypto.GenerateTokenHash(u.EmailChange, otpNew) @@ -380,7 +424,8 @@ func (a *API) sendEmailChange(tx *storage.Connection, config *conf.GlobalConfigu if config.Mailer.SecureEmailChangeEnabled && u.GetEmail() != "" { otpCurrent, err = crypto.GenerateOtp(otpLength) if err != nil { - return err + // OTP generation must succeed + panic(err) } currentToken := crypto.GenerateTokenHash(u.GetEmail(), otpCurrent) u.EmailChangeTokenCurrent = addFlowPrefixToToken(currentToken, flowType) @@ -396,22 +441,28 @@ func (a *API) sendEmailChange(tx *storage.Connection, config *conf.GlobalConfigu } u.EmailChangeSentAt = &now - return errors.Wrap(tx.UpdateOnly( + err = tx.UpdateOnly( u, "email_change_token_current", "email_change_token_new", "email_change", "email_change_sent_at", "email_change_confirm_status", - ), "Database error updating user for email change") + ) + + if err != nil { + return errors.Wrap(err, "Database error updating user for email change") + } + + return nil } func validateEmail(email string) (string, error) { if email == "" { - return "", unprocessableEntityError("An email address is required") + return "", badRequestError("validation_failed", "An email address is required") } if err := checkmail.ValidateFormat(email); err != nil { - return "", unprocessableEntityError("Unable to validate email address: " + err.Error()) + return "", badRequestError("validation_failed", "Unable to validate email address: "+err.Error()) } return strings.ToLower(email), nil } diff --git a/internal/api/mfa.go b/internal/api/mfa.go index 03fb00e6be..2c23a323e6 100644 --- a/internal/api/mfa.go +++ b/internal/api/mfa.go @@ -72,7 +72,7 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("invalid body: unable to parse JSON").WithInternalError(err) + return badRequestError("bad_json", "invalid body: unable to parse JSON").WithInternalError(err) } if user.IsSSOUser { @@ -80,7 +80,7 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { } if params.FactorType != models.TOTP { - return badRequestError("factor_type needs to be totp") + return badRequestError("validation_failed", "factor_type needs to be totp") } if params.Issuer == "" { @@ -100,7 +100,7 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { } if len(factors) >= int(config.MFA.MaxEnrolledFactors) { - return forbiddenError("Enrolled factors exceed allowed limit, unenroll to continue") + return unprocessableEntityError("too_many_enrolled_mfa_factors", "Enrolled factors exceed allowed limit, unenroll to continue") } numVerifiedFactors := 0 @@ -111,7 +111,7 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { } if numVerifiedFactors >= config.MFA.MaxVerifiedFactors { - return forbiddenError("Maximum number of enrolled factors reached, unenroll to continue") + return forbiddenError("too_many_enrolled_mfa_factors", "Maximum number of enrolled factors reached, unenroll to continue") } key, err := totp.Generate(totp.GenerateOpts{ @@ -137,7 +137,7 @@ func (a *API) EnrollFactor(w http.ResponseWriter, r *http.Request) error { if terr := tx.Create(factor); terr != nil { pgErr := utilities.NewPostgresError(terr) if pgErr.IsUniqueConstraintViolated() { - return badRequestError(fmt.Sprintf("a factor with the friendly name %q for this user likely already exists", factor.FriendlyName)) + return unprocessableEntityError("mfa_factor_name_conflict", fmt.Sprintf("A factor with the friendly name %q for this user likely already exists", factor.FriendlyName)) } return terr @@ -213,7 +213,7 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("invalid body: unable to parse JSON").WithInternalError(err) + return badRequestError("bad_json", "invalid body: unable to parse JSON").WithInternalError(err) } if !factor.IsOwnedBy(user) { @@ -223,13 +223,13 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { challenge, err := models.FindChallengeByChallengeID(a.db, params.ChallengeID) if err != nil { if models.IsNotFoundError(err) { - return notFoundError(err.Error()) + return notFoundError("mfa_factor_not_found", "MFA factor with the provided challenge ID not found") } return internalServerError("Database error finding Challenge").WithInternalError(err) } if challenge.VerifiedAt != nil || challenge.IPAddress != currentIP { - return badRequestError("Challenge and verify IP addresses mismatch") + return unprocessableEntityError("mfa_ip_address_mismatch", "Challenge and verify IP addresses mismatch") } if challenge.HasExpired(config.MFA.ChallengeExpiryDuration) { @@ -243,7 +243,7 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { if err != nil { return err } - return badRequestError("%v has expired, verify against another challenge or create a new challenge.", challenge.ID) + return unprocessableEntityError("mfa_challenge_expired", "MFA challenge %v has expired, verify against another challenge or create a new challenge.", challenge.ID) } valid := totp.Validate(params.Code, factor.Secret) @@ -271,11 +271,11 @@ func (a *API) VerifyFactor(w http.ResponseWriter, r *http.Request) error { output.Message = hooks.DefaultMFAHookRejectionMessage } - return forbiddenError(output.Message) + return forbiddenError("mfa_verification_rejected", output.Message) } } if !valid { - return badRequestError("Invalid TOTP code entered") + return unprocessableEntityError("mfa_verification_failed", "Invalid TOTP code entered") } var token *AccessTokenResponse @@ -336,7 +336,7 @@ func (a *API) UnenrollFactor(w http.ResponseWriter, r *http.Request) error { } if factor.IsVerified() && !session.IsAAL2() { - return badRequestError("AAL2 required to unenroll verified factor") + return unprocessableEntityError("insufficient_aal", "AAL2 required to unenroll verified factor") } if !factor.IsOwnedBy(user) { return internalServerError(InvalidFactorOwnerErrorMessage) diff --git a/internal/api/middleware.go b/internal/api/middleware.go index c0f3d857a1..f39024e216 100644 --- a/internal/api/middleware.go +++ b/internal/api/middleware.go @@ -106,7 +106,7 @@ func (a *API) limitEmailOrPhoneSentHandler() middlewareHandler { } if err := json.Unmarshal(bodyBytes, &requestBody); err != nil { - return c, badRequestError("Error invalid request body").WithInternalError(err) + return c, badRequestError("bad_json", "Error invalid request body").WithInternalError(err) } if shouldRateLimitEmail { @@ -156,7 +156,7 @@ func (a *API) requireEmailProvider(w http.ResponseWriter, req *http.Request) (co config := a.config if !config.External.Email.Enabled { - return nil, badRequestError("Email logins are disabled") + return nil, badRequestError("email_provider_disabled", "Email logins are disabled") } return ctx, nil @@ -183,8 +183,7 @@ func (a *API) verifyCaptcha(w http.ResponseWriter, req *http.Request) (context.C } if !verificationResult.Success { - return nil, badRequestError("captcha protection: request disallowed (%s)", strings.Join(verificationResult.ErrorCodes, ", ")) - + return nil, badRequestError("captcha_failed", "captcha protection: request disallowed (%s)", strings.Join(verificationResult.ErrorCodes, ", ")) } return ctx, nil @@ -228,7 +227,7 @@ func (a *API) isValidExternalHost(w http.ResponseWriter, req *http.Request) (con func (a *API) requireSAMLEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) { ctx := req.Context() if !a.config.SAML.Enabled { - return nil, notFoundError("SAML 2.0 is disabled") + return nil, notFoundError("saml_provider_disabled", "SAML 2.0 is disabled") } return ctx, nil } @@ -236,7 +235,7 @@ func (a *API) requireSAMLEnabled(w http.ResponseWriter, req *http.Request) (cont func (a *API) requireManualLinkingEnabled(w http.ResponseWriter, req *http.Request) (context.Context, error) { ctx := req.Context() if !a.config.Security.ManualLinkingEnabled { - return nil, notFoundError("Manual linking is disabled") + return nil, notFoundError("manual_linking_disabled", "Manual linking is disabled") } return ctx, nil } diff --git a/internal/api/otp.go b/internal/api/otp.go index 0700f09705..e9e58337cb 100644 --- a/internal/api/otp.go +++ b/internal/api/otp.go @@ -34,10 +34,10 @@ type SmsParams struct { func (p *OtpParams) Validate() error { if p.Email != "" && p.Phone != "" { - return badRequestError("Only an email address or phone number should be provided") + return badRequestError("validation_failed", "Only an email address or phone number should be provided") } if p.Email != "" && p.Channel != "" { - return badRequestError("Channel should only be specified with Phone OTP") + return badRequestError("validation_failed", "Channel should only be specified with Phone OTP") } if err := validatePKCEParams(p.CodeChallengeMethod, p.CodeChallenge); err != nil { return err @@ -47,7 +47,7 @@ func (p *OtpParams) Validate() error { func (p *SmsParams) Validate(smsProvider string) error { if p.Phone != "" && !sms_provider.IsValidMessageChannel(p.Channel, smsProvider) { - return badRequestError(InvalidChannelError) + return badRequestError("validation_failed", InvalidChannelError) } var err error @@ -74,7 +74,7 @@ func (a *API) Otp(w http.ResponseWriter, r *http.Request) error { } if err = json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read verification params: %v", err) + return badRequestError("bad_json", "Could not read verification params: %v", err) } if err := params.Validate(); err != nil { @@ -85,7 +85,7 @@ func (a *API) Otp(w http.ResponseWriter, r *http.Request) error { } if ok, err := a.shouldCreateUser(r, params); !ok { - return badRequestError("Signups not allowed for otp") + return badRequestError("signup_disabled", "Signups not allowed for otp") } else if err != nil { return err } @@ -110,7 +110,7 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { config := a.config if !config.External.Phone.Enabled { - return badRequestError("Unsupported phone provider") + return badRequestError("phone_provider_disabled", "Unsupported phone provider") } var err error @@ -118,11 +118,11 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read sms otp params: %v", err) + return badRequestError("bad_json", "Could not read sms otp params: %v", err) } // For backwards compatibility, we default to SMS if params Channel is not specified if params.Phone != "" && params.Channel == "" { @@ -151,7 +151,7 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { // Sign them up with temporary password. password, err := password.Generate(64, 10, 1, false, true) if err != nil { - internalServerError("error creating user").WithInternalError(err) + return internalServerError("error creating user").WithInternalError(err) } signUpParams := &SignupParams{ @@ -162,7 +162,8 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { } newBodyContent, err := json.Marshal(signUpParams) if err != nil { - return badRequestError("Could not parse metadata: %v", err) + // SignupParams must be marshallable + panic(err) } r.Body = io.NopCloser(bytes.NewReader(newBodyContent)) @@ -180,7 +181,8 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { } newBodyContent, err := json.Marshal(signUpParams) if err != nil { - return badRequestError("Could not parse metadata: %v", err) + // SignupParams must be marshallable + panic(err) } r.Body = io.NopCloser(bytes.NewReader(newBodyContent)) return a.SmsOtp(w, r) @@ -201,11 +203,11 @@ func (a *API) SmsOtp(w http.ResponseWriter, r *http.Request) error { } smsProvider, terr := sms_provider.GetSmsProvider(*config) if terr != nil { - return badRequestError("Error sending sms: %v", terr) + return internalServerError("Error finding SMS provider").WithInternalError(err) } mID, serr := a.sendPhoneConfirmation(ctx, tx, user, params.Phone, phoneConfirmationOtp, smsProvider, params.Channel) if serr != nil { - return badRequestError("Error sending sms OTP: %v", serr) + return badRequestError("sms_send_failed", "Error sending sms OTP: %v", serr).WithInternalError(serr) } messageID = mID return nil diff --git a/internal/api/phone.go b/internal/api/phone.go index 5d6e9bda30..9866664142 100644 --- a/internal/api/phone.go +++ b/internal/api/phone.go @@ -25,7 +25,7 @@ const ( func validatePhone(phone string) (string, error) { phone = formatPhoneNumber(phone) if isValid := validateE164Format(phone); !isValid { - return "", unprocessableEntityError("Invalid phone number format (E.164 required)") + return "", badRequestError("validation_failed", "Invalid phone number format (E.164 required)") } return phone, nil } diff --git a/internal/api/pkce.go b/internal/api/pkce.go index a186aa4649..9e05a54192 100644 --- a/internal/api/pkce.go +++ b/internal/api/pkce.go @@ -21,9 +21,9 @@ func isValidCodeChallenge(codeChallenge string) (bool, error) { // See RFC 7636 Section 4.2: https://www.rfc-editor.org/rfc/rfc7636#section-4.2 switch codeChallengeLength := len(codeChallenge); { case codeChallengeLength < MinCodeChallengeLength, codeChallengeLength > MaxCodeChallengeLength: - return false, badRequestError("code challenge has to be between %v and %v characters", MinCodeChallengeLength, MaxCodeChallengeLength) + return false, badRequestError("validation_failed", "code challenge has to be between %v and %v characters", MinCodeChallengeLength, MaxCodeChallengeLength) case !codeChallengePattern.MatchString(codeChallenge): - return false, badRequestError("code challenge can only contain alphanumeric characters, hyphens, periods, underscores and tildes") + return false, badRequestError("validation_failed", "code challenge can only contain alphanumeric characters, hyphens, periods, underscores and tildes") default: return true, nil } @@ -41,7 +41,7 @@ func addFlowPrefixToToken(token string, flowType models.FlowType) string { func issueAuthCode(tx *storage.Connection, user *models.User, expiryDuration time.Duration, authenticationMethod models.AuthenticationMethod) (string, error) { flowState, err := models.FindFlowStateByUserID(tx, user.ID.String(), authenticationMethod) if err != nil && models.IsNotFoundError(err) { - return "", badRequestError("No valid flow state found for user.") + return "", unprocessableEntityError("flow_state_not_found", "No valid flow state found for user.") } else if err != nil { return "", err } @@ -59,7 +59,7 @@ func isImplicitFlow(flowType models.FlowType) bool { func validatePKCEParams(codeChallengeMethod, codeChallenge string) error { switch true { case (codeChallenge == "") != (codeChallengeMethod == ""): - return badRequestError(InvalidPKCEParamsErrorMessage) + return badRequestError("validation_failed", InvalidPKCEParamsErrorMessage) case codeChallenge != "": if valid, err := isValidCodeChallenge(codeChallenge); !valid { return err diff --git a/internal/api/reauthenticate.go b/internal/api/reauthenticate.go index e29f2f7e32..ca552f6b9f 100644 --- a/internal/api/reauthenticate.go +++ b/internal/api/reauthenticate.go @@ -23,16 +23,16 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error { email, phone := user.GetEmail(), user.GetPhone() if email == "" && phone == "" { - return unprocessableEntityError("Reauthentication requires the user to have an email or a phone number") + return badRequestError("validation_failed", "Reauthentication requires the user to have an email or a phone number") } if email != "" { if !user.IsConfirmed() { - return badRequestError("Please verify your email first.") + return unprocessableEntityError("email_not_confirmed", "Please verify your email first.") } } else if phone != "" { if !user.IsPhoneConfirmed() { - return badRequestError("Please verify your phone first.") + return unprocessableEntityError("phone_not_confirmed", "Please verify your phone first.") } } @@ -47,7 +47,7 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error { } else if phone != "" { smsProvider, terr := sms_provider.GetSmsProvider(*config) if terr != nil { - return badRequestError("Error sending sms: %v", terr) + return internalServerError("Failed to get SMS provider").WithInternalError(terr) } mID, err := a.sendPhoneConfirmation(ctx, tx, user, phone, phoneReauthenticationOtp, smsProvider, sms_provider.SMSProvider) if err != nil { @@ -60,7 +60,12 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error { }) if err != nil { if errors.Is(err, MaxFrequencyLimitError) { - return tooManyRequestsError("For security purposes, you can only request this once every 60 seconds") + reason := "over_email_send_rate" + if phone != "" { + reason = "over_sms_send_rate" + } + + return tooManyRequestsError(reason, "For security purposes, you can only request this once every 60 seconds") } return err } @@ -77,7 +82,7 @@ func (a *API) Reauthenticate(w http.ResponseWriter, r *http.Request) error { // verifyReauthentication checks if the nonce provided is valid func (a *API) verifyReauthentication(nonce string, tx *storage.Connection, config *conf.GlobalConfiguration, user *models.User) error { if user.ReauthenticationToken == "" || user.ReauthenticationSentAt == nil { - return badRequestError(InvalidNonceMessage) + return badRequestError("CHECK", InvalidNonceMessage) } var isValid bool if user.GetEmail() != "" { @@ -87,7 +92,7 @@ func (a *API) verifyReauthentication(nonce string, tx *storage.Connection, confi if config.Sms.IsTwilioVerifyProvider() { smsProvider, _ := sms_provider.GetSmsProvider(*config) if err := smsProvider.(*sms_provider.TwilioVerifyProvider).VerifyOTP(string(user.Phone), nonce); err != nil { - return expiredTokenError("Token has expired or is invalid").WithInternalError(err) + return expiredTokenError("CHECK", "Token has expired or is invalid").WithInternalError(err) } return nil } else { @@ -95,10 +100,10 @@ func (a *API) verifyReauthentication(nonce string, tx *storage.Connection, confi isValid = isOtpValid(tokenHash, user.ReauthenticationToken, user.ReauthenticationSentAt, config.Sms.OtpExp) } } else { - return unprocessableEntityError("Reauthentication requires an email or a phone number") + return unprocessableEntityError("CHECK", "Reauthentication requires an email or a phone number") } if !isValid { - return badRequestError(InvalidNonceMessage) + return badRequestError("CHECK", InvalidNonceMessage) } if err := user.ConfirmReauthentication(tx); err != nil { return internalServerError("Error during reauthentication").WithInternalError(err) diff --git a/internal/api/recover.go b/internal/api/recover.go index 9a57575650..49c98b8e76 100644 --- a/internal/api/recover.go +++ b/internal/api/recover.go @@ -19,7 +19,7 @@ type RecoverParams struct { func (p *RecoverParams) Validate() error { if p.Email == "" { - return unprocessableEntityError("Password recovery requires an email") + return badRequestError("validation_failed", "Password recovery requires an email") } var err error if p.Email, err = validateEmail(p.Email); err != nil { @@ -40,11 +40,11 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error { body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read verification params: %v", err) + return badRequestError("bad_json", "Could not read verification params: %v", err) } flowType := getFlowFromChallenge(params.CodeChallenge) @@ -83,7 +83,7 @@ func (a *API) Recover(w http.ResponseWriter, r *http.Request) error { }) if err != nil { if errors.Is(err, MaxFrequencyLimitError) { - return tooManyRequestsError("For security purposes, you can only request this once every 60 seconds") + return tooManyRequestsError("over_email_send_rate", "For security purposes, you can only request this once every 60 seconds") } return internalServerError("Unable to process request").WithInternalError(err) } diff --git a/internal/api/resend.go b/internal/api/resend.go index a2fb4a52be..43cc46977a 100644 --- a/internal/api/resend.go +++ b/internal/api/resend.go @@ -26,22 +26,22 @@ func (p *ResendConfirmationParams) Validate(config *conf.GlobalConfiguration) er break default: // type does not match one of the above - return badRequestError("Missing one of these types: signup, email_change, sms, phone_change") + return badRequestError("validation_failed", "Missing one of these types: signup, email_change, sms, phone_change") } if p.Email == "" && p.Type == signupVerification { - return badRequestError("Type provided requires an email address") + return badRequestError("validation_failed", "Type provided requires an email address") } if p.Phone == "" && p.Type == smsVerification { - return badRequestError("Type provided requires a phone number") + return badRequestError("validation_failed", "Type provided requires a phone number") } var err error if p.Email != "" && p.Phone != "" { - return badRequestError("Only an email address or phone number should be provided.") + return badRequestError("validation_failed", "Only an email address or phone number should be provided.") } else if p.Email != "" { if !config.External.Email.Enabled { - return badRequestError("Email logins are disabled") + return badRequestError("email_provider_disabled", "Email logins are disabled") } p.Email, err = validateEmail(p.Email) if err != nil { @@ -49,7 +49,7 @@ func (p *ResendConfirmationParams) Validate(config *conf.GlobalConfiguration) er } } else if p.Phone != "" { if !config.External.Phone.Enabled { - return badRequestError("Phone logins are disabled") + return badRequestError("phone_provider_disabled", "Phone logins are disabled") } p.Phone, err = validatePhone(p.Phone) if err != nil { @@ -57,7 +57,7 @@ func (p *ResendConfirmationParams) Validate(config *conf.GlobalConfiguration) er } } else { // both email and phone are empty - return badRequestError("Missing email address or phone number") + return badRequestError("validation_failed", "Missing email address or phone number") } return nil } @@ -71,11 +71,11 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error { body, err := getBodyBytes(r) if err != nil { - return badRequestError("Could not read body").WithInternalError(err) + return internalServerError("Could not read body").WithInternalError(err) } if err := json.Unmarshal(body, params); err != nil { - return badRequestError("Could not read params: %v", err) + return badRequestError("bad_json", "Could not read params: %v", err) } if err := params.Validate(config); err != nil { @@ -162,8 +162,13 @@ func (a *API) Resend(w http.ResponseWriter, r *http.Request) error { }) if err != nil { if errors.Is(err, MaxFrequencyLimitError) { + reason := "over_email_send_rate" + if params.Type == smsVerification || params.Type == phoneChangeVerification { + reason = "over_sms_send_rate" + } + until := time.Until(user.ConfirmationSentAt.Add(config.SMTP.MaxFrequency)) / time.Second - return tooManyRequestsError("For security purposes, you can only request this once every %d seconds.", until) + return tooManyRequestsError(reason, "For security purposes, you can only request this once every %d seconds.", until) } return internalServerError("Unable to process request").WithInternalError(err) } diff --git a/internal/models/flow_state.go b/internal/models/flow_state.go index 6aced0b59e..d18a5bd32b 100644 --- a/internal/models/flow_state.go +++ b/internal/models/flow_state.go @@ -81,7 +81,7 @@ func (FlowState) TableName() string { return tableName } -func NewFlowState(providerType, codeChallenge string, codeChallengeMethod CodeChallengeMethod, authenticationMethod AuthenticationMethod) (*FlowState, error) { +func NewFlowState(providerType, codeChallenge string, codeChallengeMethod CodeChallengeMethod, authenticationMethod AuthenticationMethod) *FlowState { id := uuid.Must(uuid.NewV4()) authCode := uuid.Must(uuid.NewV4()) flowState := &FlowState{ @@ -92,7 +92,7 @@ func NewFlowState(providerType, codeChallenge string, codeChallengeMethod CodeCh AuthCode: authCode.String(), AuthenticationMethod: authenticationMethod.String(), } - return flowState, nil + return flowState } func NewFlowStateWithUserID(tx *storage.Connection, providerType, codeChallenge string, codeChallengeMethod CodeChallengeMethod, authenticationMethod AuthenticationMethod, userID *uuid.UUID) error {