Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add error codes #1377

Merged
merged 5 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion .github/workflows/conventional-commits-lint.js
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,12 @@ let failed = false;

validate.forEach((payload) => {
if (payload.title) {
const { groups } = payload.title.match(TITLE_PATTERN);
const match = payload.title.match(TITLE_PATTERN);
if (!match) {
return
}

const { groups } = match

if (groups) {
if (groups.breaking) {
Expand Down
29 changes: 15 additions & 14 deletions internal/api/admin.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context,

userID, err := uuid.FromString(chi.URLParam(r, "user_id"))
if err != nil {
return nil, badRequestError("user_id must be an UUID")
return nil, notFoundError(ErrorCodeValidationFailed, "user_id must be an UUID")
}

observability.LogEntrySetField(r, "user_id", userID)

u, err := models.FindUserByID(db, userID)
if err != nil {
if models.IsNotFoundError(err) {
return nil, notFoundError("User not found")
return nil, notFoundError(ErrorCodeUserNotFound, "User not found")
}
return nil, internalServerError("Database error loading user").WithInternalError(err)
}
Expand All @@ -69,15 +69,15 @@ func (a *API) loadUser(w http.ResponseWriter, r *http.Request) (context.Context,
func (a *API) loadFactor(w http.ResponseWriter, r *http.Request) (context.Context, error) {
factorID, err := uuid.FromString(chi.URLParam(r, "factor_id"))
if err != nil {
return nil, badRequestError("factor_id must be an UUID")
return nil, notFoundError(ErrorCodeValidationFailed, "factor_id must be an UUID")
}

observability.LogEntrySetField(r, "factor_id", factorID)

f, err := models.FindFactorByFactorID(a.db, factorID)
if err != nil {
if models.IsNotFoundError(err) {
return nil, notFoundError("Factor not found")
return nil, notFoundError(ErrorCodeMFAFactorNotFound, "Factor not found")
}
return nil, internalServerError("Database error loading factor").WithInternalError(err)
}
Expand All @@ -101,12 +101,12 @@ func (a *API) adminUsers(w http.ResponseWriter, r *http.Request) error {

pageParams, err := paginate(r)
if err != nil {
return badRequestError("Bad Pagination Parameters: %v", err)
return badRequestError(ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err).WithInternalError(err)
}

sortParams, err := sort(r, map[string]bool{models.CreatedAt: true}, []models.SortField{{Name: models.CreatedAt, Dir: models.Descending}})
if err != nil {
return badRequestError("Bad Sort Parameters: %v", err)
return badRequestError(ErrorCodeValidationFailed, "Bad Sort Parameters: %v", err)
}

filter := r.URL.Query().Get("filter")
Expand Down Expand Up @@ -160,7 +160,7 @@ func (a *API) adminUserUpdate(w http.ResponseWriter, r *http.Request) error {
if params.BanDuration != "none" {
duration, err = time.ParseDuration(params.BanDuration)
if err != nil {
return badRequestError("invalid format for ban duration: %v", err)
return badRequestError(ErrorCodeValidationFailed, "invalid format for ban duration: %v", err)
}
}
if terr := user.Ban(a.db, duration); terr != nil {
Expand Down Expand Up @@ -308,7 +308,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
}

if params.Email == "" && params.Phone == "" {
return unprocessableEntityError("Cannot create a user without either an email or phone")
return badRequestError(ErrorCodeValidationFailed, "Cannot create a user without either an email or phone")
}

var providers []string
Expand All @@ -320,7 +320,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
if user, err := models.IsDuplicatedEmail(db, params.Email, aud, nil); err != nil {
return internalServerError("Database error checking email").WithInternalError(err)
} else if user != nil {
return unprocessableEntityError(DuplicateEmailMsg)
return unprocessableEntityError(ErrorCodeEmailExists, DuplicateEmailMsg)
}
providers = append(providers, "email")
}
Expand All @@ -333,7 +333,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
if exists, err := models.IsDuplicatedPhone(db, params.Phone, aud); err != nil {
return internalServerError("Database error checking phone").WithInternalError(err)
} else if exists {
return unprocessableEntityError("Phone number already registered by another user")
return unprocessableEntityError(ErrorCodePhoneExists, "Phone number already registered by another user")
}
providers = append(providers, "phone")
}
Expand Down Expand Up @@ -429,7 +429,7 @@ func (a *API) adminUserCreate(w http.ResponseWriter, r *http.Request) error {
if params.BanDuration != "none" {
duration, err = time.ParseDuration(params.BanDuration)
if err != nil {
return badRequestError("invalid format for ban duration: %v", err)
return badRequestError(ErrorCodeValidationFailed, "invalid format for ban duration: %v", err)
}
}
if terr := user.Ban(a.db, duration); terr != nil {
Expand Down Expand Up @@ -460,11 +460,11 @@ func (a *API) adminUserDelete(w http.ResponseWriter, r *http.Request) error {
params := &adminUserDeleteParams{}
body, err := getBodyBytes(r)
if err != nil {
return badRequestError("Could not read body").WithInternalError(err)
return internalServerError("Could not read body").WithInternalError(err)
}
if len(body) > 0 {
if err := json.Unmarshal(body, params); err != nil {
return badRequestError("Could not read params: %v", err)
return badRequestError(ErrorCodeBadJSON, "Could not read params: %v", err)
}
} else {
params.ShouldSoftDelete = false
Expand Down Expand Up @@ -559,6 +559,7 @@ func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) erro
user := getUser(ctx)
adminUser := getAdminUser(ctx)
params := &adminUserUpdateFactorParams{}

if err := retrieveRequestParams(r, params); err != nil {
return err
}
Expand All @@ -571,7 +572,7 @@ func (a *API) adminUserUpdateFactor(w http.ResponseWriter, r *http.Request) erro
}
if params.FactorType != "" {
if params.FactorType != models.TOTP {
return badRequestError("Factor Type not valid")
return badRequestError(ErrorCodeValidationFailed, "Factor Type not valid")
}
if terr := factor.UpdateFactorType(tx, params.FactorType); terr != nil {
return terr
Expand Down
2 changes: 1 addition & 1 deletion internal/api/anonymous.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func (a *API) SignupAnonymously(w http.ResponseWriter, r *http.Request) error {
aud := a.requestAud(ctx, r)

if config.DisableSignup {
return forbiddenError("Signups not allowed for this instance")
return unprocessableEntityError(ErrorCodeSignupDisabled, "Signups not allowed for this instance")
}

params := &SignupParams{}
Expand Down
2 changes: 1 addition & 1 deletion internal/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ func NewAPIWithVersion(ctx context.Context, globalConfig *conf.GlobalConfigurati
}
if params.Email == "" && params.Phone == "" {
if !api.config.External.AnonymousUsers.Enabled {
return unprocessableEntityError("Anonymous sign-ins are disabled")
return unprocessableEntityError(ErrorCodeAnonymousProviderDisabled, "Anonymous sign-ins are disabled")
}
if _, err := api.limitHandler(limiter)(w, r); err != nil {
return err
Expand Down
35 changes: 35 additions & 0 deletions internal/api/apiversions.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package api

import (
"time"
)

const APIVersionHeaderName = "X-Supabase-Api-Version"

type APIVersion = time.Time

var (
APIVersionInitial = time.Time{}
APIVersion20240101 = time.Date(2024, time.January, 1, 0, 0, 0, 0, time.UTC)
)

func DetermineClosestAPIVersion(date string) (APIVersion, error) {
if date == "" {
return APIVersionInitial, nil
}

parsed, err := time.ParseInLocation("2006-01-02", date, time.UTC)
if err != nil {
return APIVersionInitial, err
}

if parsed.Compare(APIVersion20240101) >= 0 {
return APIVersion20240101, nil
}

return APIVersionInitial, nil
}

func FormatAPIVersion(apiVersion APIVersion) string {
return apiVersion.Format("2006-01-02")
}
29 changes: 29 additions & 0 deletions internal/api/apiversions_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package api

import (
"testing"

"github.com/stretchr/testify/require"
)

func TestDetermineClosestAPIVersion(t *testing.T) {
version, err := DetermineClosestAPIVersion("")
require.NoError(t, err)
require.Equal(t, APIVersionInitial, version)

version, err = DetermineClosestAPIVersion("Not a date")
require.Error(t, err)
require.Equal(t, APIVersionInitial, version)

version, err = DetermineClosestAPIVersion("2023-12-31")
require.NoError(t, err)
require.Equal(t, APIVersionInitial, version)

version, err = DetermineClosestAPIVersion("2024-01-01")
require.NoError(t, err)
require.Equal(t, APIVersion20240101, version)

version, err = DetermineClosestAPIVersion("2024-01-02")
require.NoError(t, err)
require.Equal(t, APIVersion20240101, version)
}
4 changes: 2 additions & 2 deletions internal/api/audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func (a *API) adminAuditLog(w http.ResponseWriter, r *http.Request) error {
// aud := a.requestAud(ctx, r)
pageParams, err := paginate(r)
if err != nil {
return badRequestError("Bad Pagination Parameters: %v", err)
return badRequestError(ErrorCodeValidationFailed, "Bad Pagination Parameters: %v", err)
}

var col []string
Expand All @@ -31,7 +31,7 @@ func (a *API) adminAuditLog(w http.ResponseWriter, r *http.Request) error {
qparts := strings.SplitN(q, ":", 2)
col, exists = filterColumnMap[qparts[0]]
if !exists || len(qparts) < 2 {
return badRequestError("Invalid query scope: %s", q)
return badRequestError(ErrorCodeValidationFailed, "Invalid query scope: %s", q)
}
qval = qparts[1]
}
Expand Down
22 changes: 11 additions & 11 deletions internal/api/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func (a *API) requireNotAnonymous(w http.ResponseWriter, r *http.Request) (conte
ctx := r.Context()
claims := getClaims(ctx)
if claims.IsAnonymous {
return nil, forbiddenError("Anonymous user not allowed to perform these actions")
return nil, forbiddenError(ErrorCodeNoAuthorization, "Anonymous user not allowed to perform these actions")
}
return ctx, nil
}
Expand All @@ -49,7 +49,7 @@ func (a *API) requireAdmin(ctx context.Context, r *http.Request) (context.Contex
claims := getClaims(ctx)
if claims == nil {
fmt.Printf("[%s] %s %s %d %s\n", time.Now().Format("2006-01-02 15:04:05"), r.Method, r.RequestURI, http.StatusForbidden, "Invalid token")
return nil, unauthorizedError("Invalid token")
return nil, forbiddenError(ErrorCodeBadJWT, "Invalid token")
}

adminRoles := a.config.JWT.AdminRoles
Expand All @@ -60,14 +60,14 @@ func (a *API) requireAdmin(ctx context.Context, r *http.Request) (context.Contex
}

fmt.Printf("[%s] %s %s %d %s\n", time.Now().Format("2006-01-02 15:04:05"), r.Method, r.RequestURI, http.StatusForbidden, "this token needs role 'supabase_admin' or 'service_role'")
return nil, unauthorizedError("User not allowed")
return nil, forbiddenError(ErrorCodeNotAdmin, "User not allowed")
}

func (a *API) extractBearerToken(r *http.Request) (string, error) {
authHeader := r.Header.Get("Authorization")
matches := bearerRegexp.FindStringSubmatch(authHeader)
if len(matches) != 2 {
return "", unauthorizedError("This endpoint requires a Bearer token")
return "", httpError(http.StatusUnauthorized, ErrorCodeNoAuthorization, "This endpoint requires a Bearer token")
}

return matches[1], nil
Expand All @@ -82,7 +82,7 @@ func (a *API) parseJWTClaims(bearer string, r *http.Request) (context.Context, e
return []byte(config.JWT.Secret), nil
})
if err != nil {
return nil, unauthorizedError("invalid JWT: unable to parse or verify signature, %v", err)
return nil, forbiddenError(ErrorCodeBadJWT, "invalid JWT: unable to parse or verify signature, %v", err).WithInternalError(err)
}

return withToken(ctx, token), nil
Expand All @@ -93,23 +93,23 @@ func (a *API) maybeLoadUserOrSession(ctx context.Context) (context.Context, erro
claims := getClaims(ctx)

if claims == nil {
return ctx, unauthorizedError("invalid token: missing claims")
return ctx, forbiddenError(ErrorCodeBadJWT, "invalid token: missing claims")
}

if claims.Subject == "" {
return nil, unauthorizedError("invalid claim: missing sub claim")
return nil, forbiddenError(ErrorCodeBadJWT, "invalid claim: missing sub claim")
}

var user *models.User
if claims.Subject != "" {
userId, err := uuid.FromString(claims.Subject)
if err != nil {
return ctx, badRequestError("invalid claim: sub claim must be a UUID").WithInternalError(err)
return ctx, badRequestError(ErrorCodeBadJWT, "invalid claim: sub claim must be a UUID").WithInternalError(err)
J0 marked this conversation as resolved.
Show resolved Hide resolved
}
user, err = models.FindUserByID(db, userId)
if err != nil {
if models.IsNotFoundError(err) {
return ctx, notFoundError(err.Error())
return ctx, forbiddenError(ErrorCodeUserNotFound, "User from sub claim in JWT does not exist")
}
return ctx, err
}
Expand All @@ -120,11 +120,11 @@ func (a *API) maybeLoadUserOrSession(ctx context.Context) (context.Context, erro
if claims.SessionId != "" && claims.SessionId != uuid.Nil.String() {
sessionId, err := uuid.FromString(claims.SessionId)
if err != nil {
return ctx, err
return ctx, forbiddenError(ErrorCodeBadJWT, "invalid claim: session_id claim must be a UUID").WithInternalError(err)
}
session, err = models.FindSessionByID(db, sessionId, false)
if err != nil && !models.IsNotFoundError(err) {
return ctx, err
return ctx, forbiddenError(ErrorCodeSessionNotFound, "Session from session_id claim in JWT does not exist")
}
ctx = withSession(ctx, session)
}
Expand Down
4 changes: 2 additions & 2 deletions internal/api/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func (ts *AuthTestSuite) TestMaybeLoadUserOrSession() {
},
Role: "authenticated",
},
ExpectedError: unauthorizedError("invalid claim: missing sub claim"),
ExpectedError: forbiddenError(ErrorCodeBadJWT, "invalid claim: missing sub claim"),
ExpectedUser: nil,
},
{
Expand All @@ -118,7 +118,7 @@ func (ts *AuthTestSuite) TestMaybeLoadUserOrSession() {
},
Role: "authenticated",
},
ExpectedError: badRequestError("invalid claim: sub claim must be a UUID"),
ExpectedError: badRequestError(ErrorCodeBadJWT, "invalid claim: sub claim must be a UUID"),
ExpectedUser: nil,
},
{
Expand Down
Loading
Loading