Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion internal/api/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net/http"
"slices"
"strings"

"github.com/gofrs/uuid"
Expand Down Expand Up @@ -51,7 +52,7 @@ func (a *API) requireAdmin(ctx context.Context) (context.Context, error) {

adminRoles := a.config.JWT.AdminRoles

if isStringInSlice(claims.Role, adminRoles) {
if slices.Contains(adminRoles, claims.Role) {
// successful authentication
return withAdminUser(ctx, &models.User{Role: claims.Role, Email: storage.NullString(claims.Role)}), nil
}
Expand Down
12 changes: 2 additions & 10 deletions internal/api/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"net/http"
"slices"

"github.com/supabase/auth/internal/api/apierrors"
"github.com/supabase/auth/internal/api/shared"
Expand Down Expand Up @@ -34,7 +35,7 @@ func (a *API) requestAud(ctx context.Context, r *http.Request) string {

// ignore the JWT's aud claim if the role is admin
// this is because anon, service_role never had an aud claim to begin with
if claims != nil && !isStringInSlice(claims.Role, config.JWT.AdminRoles) {
if claims != nil && !slices.Contains(config.JWT.AdminRoles, claims.Role) {
aud, _ := claims.GetAudience()
if len(aud) != 0 && aud[0] != "" {
return aud[0]
Expand All @@ -45,15 +46,6 @@ func (a *API) requestAud(ctx context.Context, r *http.Request) string {
return config.JWT.Aud
}

func isStringInSlice(checkValue string, list []string) bool {
for _, val := range list {
if val == checkValue {
return true
}
}
return false
}

type RequestParams interface {
AdminUserParams |
CreateSSOProviderParams |
Expand Down
15 changes: 5 additions & 10 deletions internal/api/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"fmt"
"net/http"
"net/url"
"slices"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -214,18 +215,12 @@ func (a *API) isValidExternalHost(w http.ResponseWriter, req *http.Request) (con
protocol := "https"

if xForwardedHost != "" {
for _, host := range config.Mailer.ExternalHosts {
if host == xForwardedHost {
hostname = host
break
}
if slices.Contains(config.Mailer.ExternalHosts, xForwardedHost) {
hostname = xForwardedHost
}
} else if reqHost != "" {
for _, host := range config.Mailer.ExternalHosts {
if host == reqHost {
hostname = host
break
}
if slices.Contains(config.Mailer.ExternalHosts, reqHost) {
hostname = reqHost
}
}

Expand Down
24 changes: 8 additions & 16 deletions internal/api/token_oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"crypto/sha256"
"fmt"
"net/http"
"slices"

"github.com/coreos/go-oidc/v3/oidc"
"github.com/supabase/auth/internal/api/apierrors"
Expand Down Expand Up @@ -112,14 +113,11 @@ func (p *IdTokenGrantParams) getProvider(ctx context.Context, config *conf.Globa
log.WithField("issuer", p.Issuer).WithField("client_id", p.ClientID).Warn("Use of POST /token with arbitrary issuer and client_id is deprecated for security reasons. Please switch to using the API with provider only!")

allowed := false
for _, allowedIssuer := range config.External.AllowedIdTokenIssuers {
if p.Issuer == allowedIssuer {
allowed = true
providerType = allowedIssuer
acceptableClientIDs = []string{p.ClientID}
issuer = allowedIssuer
break
}
if slices.Contains(config.External.AllowedIdTokenIssuers, p.Issuer) {
allowed = true
providerType = p.Issuer
acceptableClientIDs = []string{p.ClientID}
issuer = p.Issuer
}

if !allowed {
Expand Down Expand Up @@ -213,14 +211,8 @@ func (a *API) IdTokenGrant(ctx context.Context, w http.ResponseWriter, r *http.R
continue
}

for _, aud := range idToken.Audience {
if aud == clientID {
correctAudience = true
break
}
}

if correctAudience {
if slices.Contains(idToken.Audience, clientID) {
correctAudience = true
break
}
}
Expand Down
16 changes: 6 additions & 10 deletions internal/conf/jwk.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package conf
import (
"encoding/json"
"fmt"
"slices"

"github.com/golang-jwt/jwt/v5"
"github.com/lestrrat-go/jwx/v2/jwk"
Expand Down Expand Up @@ -97,11 +98,8 @@ func (j *JwtKeysDecoder) Validate() error {
}
}

for _, op := range key.PrivateKey.KeyOps() {
if op == jwk.KeyOpSign {
signingKeys = append(signingKeys, key.PrivateKey)
break
}
if slices.Contains(key.PrivateKey.KeyOps(), jwk.KeyOpSign) {
signingKeys = append(signingKeys, key.PrivateKey)
}
}

Expand All @@ -117,11 +115,9 @@ func (j *JwtKeysDecoder) Validate() error {

func GetSigningJwk(config *JWTConfiguration) (jwk.Key, error) {
for _, key := range config.Keys {
for _, op := range key.PrivateKey.KeyOps() {
// the private JWK with key_ops "sign" should be used as the signing key
if op == jwk.KeyOpSign {
return key.PrivateKey, nil
}
// the private JWK with key_ops "sign" should be used as the signing key
if slices.Contains(key.PrivateKey.KeyOps(), jwk.KeyOpSign) {
return key.PrivateKey, nil
}
}
return nil, fmt.Errorf("no signing key found")
Expand Down
7 changes: 3 additions & 4 deletions internal/models/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package models
import (
"database/sql"
"fmt"
"slices"
"sort"
"strings"
"time"
Expand Down Expand Up @@ -174,10 +175,8 @@ func (s *Session) DetermineTag(tags []string) string {
return tags[0]
}

for _, t := range tags {
if t == tag {
return tag
}
if slices.Contains(tags, tag) {
return tag
}

return tags[0]
Expand Down