Skip to content

Commit

Permalink
feat: add error codes
Browse files Browse the repository at this point in the history
  • Loading branch information
hf committed Jan 18, 2024
1 parent 6284d99 commit b9799d5
Show file tree
Hide file tree
Showing 30 changed files with 379 additions and 363 deletions.
36 changes: 18 additions & 18 deletions internal/api/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context,

userID, err := uuid.FromString(chi.URLParam(r, "user_id"))
if err != nil {
return nil, badRequestError("user_id must be an UUID")
return nil, badRequestError("validation_failed", "user_id must be an UUID")
}

observability.LogEntrySetField(r, "user_id", userID)

u, err := models.FindUserByID(db, userID)
if err != nil {
if models.IsNotFoundError(err) {
return nil, notFoundError("User not found")
return nil, notFoundError("not_found", "User not found")
}
return nil, internalServerError("Database error loading user").WithInternalError(err)
}
Expand All @@ -69,15 +69,15 @@ func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context,
func (a *API) loadFactor(w http.ResponseWriter, r *http.Request) (context.Context, error) {
factorID, err := uuid.FromString(chi.URLParam(r, "factor_id"))
if err != nil {
return nil, badRequestError("factor_id must be an UUID")
return nil, badRequestError("validation_failed", "factor_id must be an UUID")
}

observability.LogEntrySetField(r, "factor_id", factorID)

f, err := models.FindFactorByFactorID(a.db, factorID)
if err != nil {
if models.IsNotFoundError(err) {
return nil, notFoundError("Factor not found")
return nil, notFoundError("not_found", "Factor not found")
}
return nil, internalServerError("Database error loading factor").WithInternalError(err)
}
Expand All @@ -89,11 +89,11 @@ func (a *API) getAdminParams(r *http.Request) (*AdminUserParams, error) {

body, err := getBodyBytes(r)
if err != nil {
return nil, badRequestError("Could not read body").WithInternalError(err)
return nil, internalServerError("Could not read body").WithInternalError(err)
}

if err := json.Unmarshal(body, &params); err != nil {
return nil, badRequestError("Could not decode admin user params: %v", err)
return nil, badRequestError("bad_json", "Could not decode admin user params").WithInternalError(err)
}

return &params, nil
Expand All @@ -107,12 +107,12 @@ func (a *API) adminUsers(w http.ResponseWriter, r *http.Request) error {

pageParams, err := paginate(r)
if err != nil {
return badRequestError("Bad Pagination Parameters: %v", err)
return badRequestError("validation_failed", "Bad Pagination Parameters: %v", err).WithInternalError(err)
}

sortParams, err := sort(r, map[string]bool{models.CreatedAt: true}, []models.SortField{{Name: models.CreatedAt, Dir: models.Descending}})
if err != nil {
return badRequestError("Bad Sort Parameters: %v", err)
return badRequestError("validation_failed", "Bad Sort Parameters: %v", err)
}

filter := r.URL.Query().Get("filter")
Expand Down Expand Up @@ -166,7 +166,7 @@ func (a *API) adminUserUpdate(w http.ResponseWriter, r *http.Request) error {
if params.BanDuration != "none" {
duration, err = time.ParseDuration(params.BanDuration)
if err != nil {
return badRequestError("invalid format for ban duration: %v", err)
return badRequestError("validation_failed", "invalid format for ban duration: %v", err)
}
}
if terr := user.Ban(a.db, duration); terr != nil {
Expand Down Expand Up @@ -314,7 +314,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
}

if params.Email == "" && params.Phone == "" {
return unprocessableEntityError("Cannot create a user without either an email or phone")
return badRequestError("validation_failed", "Cannot create a user without either an email or phone")
}

var providers []string
Expand All @@ -326,7 +326,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
if user, err := models.IsDuplicatedEmail(db, params.Email, aud, nil); err != nil {
return internalServerError("Database error checking email").WithInternalError(err)
} else if user != nil {
return unprocessableEntityError(DuplicateEmailMsg)
return unprocessableEntityError("email_exists", DuplicateEmailMsg)
}
providers = append(providers, "email")
}
Expand All @@ -339,7 +339,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
if exists, err := models.IsDuplicatedPhone(db, params.Phone, aud); err != nil {
return internalServerError("Database error checking phone").WithInternalError(err)
} else if exists {
return unprocessableEntityError("Phone number already registered by another user")
return unprocessableEntityError("phone_exists", "Phone number already registered by another user")
}
providers = append(providers, "phone")
}
Expand Down Expand Up @@ -435,7 +435,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
if params.BanDuration != "none" {
duration, err = time.ParseDuration(params.BanDuration)
if err != nil {
return badRequestError("invalid format for ban duration: %v", err)
return badRequestError("validation_failed", "invalid format for ban duration: %v", err)
}
}
if terr := user.Ban(a.db, duration); terr != nil {
Expand Down Expand Up @@ -466,11 +466,11 @@ func (a *API) adminUserDelete(w http.ResponseWriter, r *http.Request) error {
params := &adminUserDeleteParams{}
body, err := getBodyBytes(r)
if err != nil {
return badRequestError("Could not read body").WithInternalError(err)
return internalServerError("Could not read body").WithInternalError(err)
}
if len(body) > 0 {
if err := json.Unmarshal(body, params); err != nil {
return badRequestError("Could not read params: %v", err)
return badRequestError("bad_json", "Could not read params: %v", err)
}
} else {
params.ShouldSoftDelete = false
Expand Down Expand Up @@ -567,11 +567,11 @@ func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) erro
params := &adminUserUpdateFactorParams{}
body, err := getBodyBytes(r)
if err != nil {
return badRequestError("Could not read body").WithInternalError(err)
return internalServerError("Could not read body").WithInternalError(err)
}

if err := json.Unmarshal(body, params); err != nil {
return badRequestError("Could not read factor update params: %v", err)
return badRequestError("bad_json", "Could not read factor update params: %v", err).WithInternalError(err)
}

err = a.db.Transaction(func(tx *storage.Connection) error {
Expand All @@ -582,7 +582,7 @@ func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) erro
}
if params.FactorType != "" {
if params.FactorType != models.TOTP {
return badRequestError("Factor Type not valid")
return badRequestError("validation_failed", "Factor Type not valid")
}
if terr := factor.UpdateFactorType(tx, params.FactorType); terr != nil {
return terr
Expand Down
4 changes: 2 additions & 2 deletions internal/api/audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func (a *API) adminAuditLog(w http.ResponseWriter, r *http.Request) error {
// aud := a.requestAud(ctx, r)
pageParams, err := paginate(r)
if err != nil {
return badRequestError("Bad Pagination Parameters: %v", err)
return badRequestError("validation_failed", "Bad Pagination Parameters: %v", err)
}

var col []string
Expand All @@ -31,7 +31,7 @@ func (a *API) adminAuditLog(w http.ResponseWriter, r *http.Request) error {
qparts := strings.SplitN(q, ":", 2)
col, exists = filterColumnMap[qparts[0]]
if !exists || len(qparts) < 2 {
return badRequestError("Invalid query scope: %s", q)
return badRequestError("validation_failed", "Invalid query scope: %s", q)
}
qval = qparts[1]
}
Expand Down
20 changes: 10 additions & 10 deletions internal/api/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func (a *API) requireAdmin(ctx context.Context, w http.ResponseWriter, r *http.R
claims := getClaims(ctx)
if claims == nil {
fmt.Printf("[%s] %s %s %d %s\n", time.Now().Format("2006-01-02 15:04:05"), r.Method, r.RequestURI, http.StatusForbidden, "Invalid token")
return nil, unauthorizedError("Invalid token")
return nil, unauthorizedError("bad_jwt", "Invalid token")
}

adminRoles := a.config.JWT.AdminRoles
Expand All @@ -51,14 +51,14 @@ func (a *API) requireAdmin(ctx context.Context, w http.ResponseWriter, r *http.R
}

fmt.Printf("[%s] %s %s %d %s\n", time.Now().Format("2006-01-02 15:04:05"), r.Method, r.RequestURI, http.StatusForbidden, "this token needs role 'supabase_admin' or 'service_role'")
return nil, unauthorizedError("User not allowed")
return nil, unauthorizedError("not_admin", "User not allowed")
}

func (a *API) extractBearerToken(r *http.Request) (string, error) {
authHeader := r.Header.Get("Authorization")
matches := bearerRegexp.FindStringSubmatch(authHeader)
if len(matches) != 2 {
return "", unauthorizedError("This endpoint requires a Bearer token")
return "", unauthorizedError("no_authorization", "This endpoint requires a Bearer token")
}

return matches[1], nil
Expand All @@ -73,7 +73,7 @@ func (a *API) parseJWTClaims(bearer string, r *http.Request) (context.Context, e
return []byte(config.JWT.Secret), nil
})
if err != nil {
return nil, unauthorizedError("invalid JWT: unable to parse or verify signature, %v", err)
return nil, unauthorizedError("bad_jwt", "invalid JWT: unable to parse or verify signature, %v", err).WithInternalError(err)
}

return withToken(ctx, token), nil
Expand All @@ -84,23 +84,23 @@ func (a *API) maybeLoadUserOrSession(ctx context.Context) (context.Context, erro
claims := getClaims(ctx)

if claims == nil {
return ctx, unauthorizedError("invalid token: missing claims")
return ctx, unauthorizedError("bad_jwt", "invalid token: missing claims")
}

if claims.Subject == "" {
return nil, unauthorizedError("invalid claim: missing sub claim")
return nil, unauthorizedError("bad_jwt", "invalid claim: missing sub claim")
}

var user *models.User
if claims.Subject != "" {
userId, err := uuid.FromString(claims.Subject)
if err != nil {
return ctx, badRequestError("invalid claim: sub claim must be a UUID").WithInternalError(err)
return ctx, badRequestError("bad_jwt", "invalid claim: sub claim must be a UUID").WithInternalError(err)
}
user, err = models.FindUserByID(db, userId)
if err != nil {
if models.IsNotFoundError(err) {
return ctx, notFoundError(err.Error())
return ctx, unauthorizedError("user_not_found", "User from sub claim in JWT does not exist")
}
return ctx, err
}
Expand All @@ -111,11 +111,11 @@ func (a *API) maybeLoadUserOrSession(ctx context.Context) (context.Context, erro
if claims.SessionId != "" && claims.SessionId != uuid.Nil.String() {
sessionId, err := uuid.FromString(claims.SessionId)
if err != nil {
return ctx, err
return ctx, badRequestError("bad_jwt", "invalid claim: session_id claim must be a UUID").WithInternalError(err)
}
session, err = models.FindSessionByID(db, sessionId, false)
if err != nil && !models.IsNotFoundError(err) {
return ctx, err
return ctx, unauthorizedError("session_not_found", "Session from session_id claim in JWT does not exist")
}
ctx = withSession(ctx, session)
}
Expand Down
74 changes: 7 additions & 67 deletions internal/api/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"runtime/debug"

"github.com/pkg/errors"
"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/observability"
"github.com/supabase/auth/internal/utilities"
)
Expand Down Expand Up @@ -65,54 +64,39 @@ func (e *OAuthError) Cause() error {
return e
}

func invalidSignupError(config *conf.GlobalConfiguration) *HTTPError {
var msg string
if config.External.Email.Enabled && config.External.Phone.Enabled {
msg = "To signup, please provide your email or phone number"
} else if config.External.Email.Enabled {
msg = "To signup, please provide your email"
} else if config.External.Phone.Enabled {
msg = "To signup, please provide your phone number"
} else {
// 3rd party OAuth signups
msg = "To signup, please provide required fields"
}
return unprocessableEntityError(msg)
}

func oauthError(err string, description string) *OAuthError {
return &OAuthError{Err: err, Description: description}
}

func badRequestError(fmtString string, args ...interface{}) *HTTPError {
func badRequestError(reason, fmtString string, args ...interface{}) *HTTPError {
return httpError(http.StatusBadRequest, fmtString, args...)
}

func internalServerError(fmtString string, args ...interface{}) *HTTPError {
return httpError(http.StatusInternalServerError, fmtString, args...)
}

func notFoundError(fmtString string, args ...interface{}) *HTTPError {
func notFoundError(reason, fmtString string, args ...interface{}) *HTTPError {
return httpError(http.StatusNotFound, fmtString, args...)
}

func expiredTokenError(fmtString string, args ...interface{}) *HTTPError {
func expiredTokenError(reason, fmtString string, args ...interface{}) *HTTPError {
return httpError(http.StatusUnauthorized, fmtString, args...)
}

func unauthorizedError(fmtString string, args ...interface{}) *HTTPError {
func unauthorizedError(reason, fmtString string, args ...interface{}) *HTTPError {
return httpError(http.StatusUnauthorized, fmtString, args...)
}

func forbiddenError(fmtString string, args ...interface{}) *HTTPError {
func forbiddenError(reason, fmtString string, args ...interface{}) *HTTPError {
return httpError(http.StatusForbidden, fmtString, args...)
}

func unprocessableEntityError(fmtString string, args ...interface{}) *HTTPError {
func unprocessableEntityError(reason, fmtString string, args ...interface{}) *HTTPError {
return httpError(http.StatusUnprocessableEntity, fmtString, args...)
}

func tooManyRequestsError(fmtString string, args ...interface{}) *HTTPError {
func tooManyRequestsError(reason, fmtString string, args ...interface{}) *HTTPError {
return httpError(http.StatusTooManyRequests, fmtString, args...)
}

Expand Down Expand Up @@ -167,45 +151,6 @@ func httpError(code int, fmtString string, args ...interface{}) *HTTPError {
}
}

// OTPError is a custom error struct for phone auth errors
type OTPError struct {
Err string `json:"error"`
Description string `json:"error_description,omitempty"`
InternalError error `json:"-"`
InternalMessage string `json:"-"`
}

func (e *OTPError) Error() string {
if e.InternalMessage != "" {
return e.InternalMessage
}
return fmt.Sprintf("%s: %s", e.Err, e.Description)
}

// WithInternalError adds internal error information to the error
func (e *OTPError) WithInternalError(err error) *OTPError {
e.InternalError = err
return e
}

// WithInternalMessage adds internal message information to the error
func (e *OTPError) WithInternalMessage(fmtString string, args ...interface{}) *OTPError {
e.InternalMessage = fmt.Sprintf(fmtString, args...)
return e
}

// Cause returns the root cause error
func (e *OTPError) Cause() error {
if e.InternalError != nil {
return e.InternalError
}
return e
}

func otpError(err string, description string) *OTPError {
return &OTPError{Err: err, Description: description}
}

// Recoverer is a middleware that recovers from panics, logs the panic (and a
// backtrace), and returns a HTTP 500 (Internal Server Error) status if
// possible. Recoverer prints a request ID if one is provided.
Expand Down Expand Up @@ -282,11 +227,6 @@ func handleError(err error, w http.ResponseWriter, r *http.Request) {
if jsonErr := sendJSON(w, http.StatusBadRequest, e); jsonErr != nil {
handleError(jsonErr, w, r)
}
case *OTPError:
log.WithError(e.Cause()).Info(e.Error())
if jsonErr := sendJSON(w, http.StatusBadRequest, e); jsonErr != nil {
handleError(jsonErr, w, r)
}
case ErrorCause:
handleError(e.Cause(), w, r)
default:
Expand Down
Loading

0 comments on commit b9799d5

Please sign in to comment.