Skip to content

Commit

Permalink
fix: use signing jwk to sign oauth state (supabase#1728)
Browse files Browse the repository at this point in the history
## What kind of change does this PR introduce?
* OAuth state is now signed with the same JWK that is used to sign the
access tokens

## What is the current behavior?
* currently, it's weird for the `GOTRUE_JWT_SECRET` to be set (other
than it being a fallback option) just for the sake of signing the oauth
state

## What is the new behavior?

Feel free to include screenshots if it includes visual changes.

## Additional context

Add any other context or screenshots.
  • Loading branch information
kangmingtay authored Aug 21, 2024
1 parent 4d0e935 commit 5f7ac2a
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 29 deletions.
19 changes: 15 additions & 4 deletions internal/api/external.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
jwt "github.com/golang-jwt/jwt/v5"
"github.com/sirupsen/logrus"
"github.com/supabase/auth/internal/api/provider"
"github.com/supabase/auth/internal/conf"
"github.com/supabase/auth/internal/models"
"github.com/supabase/auth/internal/observability"
"github.com/supabase/auth/internal/storage"
Expand Down Expand Up @@ -106,8 +107,7 @@ func (a *API) GetExternalProviderRedirectURL(w http.ResponseWriter, r *http.Requ
claims.LinkingTargetID = linkingTargetUser.ID.String()
}

token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString([]byte(config.JWT.Secret))
tokenString, err := signJwt(&config.JWT, claims)
if err != nil {
return "", internalServerError("Error creating state").WithInternalError(err)
}
Expand Down Expand Up @@ -491,9 +491,20 @@ func (a *API) loadExternalState(ctx context.Context, r *http.Request) (context.C
}
config := a.config
claims := ExternalProviderClaims{}
p := jwt.NewParser(jwt.WithValidMethods([]string{jwt.SigningMethodHS256.Name}))
p := jwt.NewParser(jwt.WithValidMethods(config.JWT.ValidMethods))
_, err := p.ParseWithClaims(state, &claims, func(token *jwt.Token) (interface{}, error) {
return []byte(config.JWT.Secret), nil
if kid, ok := token.Header["kid"]; ok {
if kidStr, ok := kid.(string); ok {
return conf.FindPublicKeyByKid(kidStr, &config.JWT)
}
}
if alg, ok := token.Header["alg"]; ok {
if alg == jwt.SigningMethodHS256.Name {
// preserve backward compatibility for cases where the kid is not set
return []byte(config.JWT.Secret), nil
}
}
return nil, fmt.Errorf("missing kid")
})
if err != nil {
return ctx, badRequestError(ErrorCodeBadOAuthState, "OAuth callback with invalid state").WithInternalError(err)
Expand Down
31 changes: 31 additions & 0 deletions internal/api/jwks.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@ package api
import (
"net/http"

jwt "github.com/golang-jwt/jwt/v5"
"github.com/lestrrat-go/jwx/v2/jwa"
jwk "github.com/lestrrat-go/jwx/v2/jwk"
"github.com/supabase/auth/internal/conf"
)

type JwksResponse struct {
Expand All @@ -28,3 +30,32 @@ func (a *API) Jwks(w http.ResponseWriter, r *http.Request) error {
w.Header().Set("Cache-Control", "public, max-age=600")
return sendJSON(w, http.StatusOK, resp)
}

func signJwt(config *conf.JWTConfiguration, claims jwt.Claims) (string, error) {
signingJwk, err := conf.GetSigningJwk(config)
if err != nil {
return "", err
}
signingMethod := conf.GetSigningAlg(signingJwk)
token := jwt.NewWithClaims(signingMethod, claims)
if token.Header == nil {
token.Header = make(map[string]interface{})
}

if _, ok := token.Header["kid"]; !ok {
if kid := signingJwk.KeyID(); kid != "" {
token.Header["kid"] = kid
}
}
// this serializes the aud claim to a string
jwt.MarshalSingleStringAsArray = false
signingKey, err := conf.GetSigningKey(signingJwk)
if err != nil {
return "", err
}
signed, err := token.SignedString(signingKey)
if err != nil {
return "", err
}
return signed, nil
}
26 changes: 1 addition & 25 deletions internal/api/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,6 @@ func (a *API) generateAccessToken(r *http.Request, tx *storage.Connection, user
IsAnonymous: user.IsAnonymous,
}

var token *jwt.Token
var gotrueClaims jwt.Claims = claims
if config.Hook.CustomAccessToken.Enabled {
input := hooks.CustomAccessTokenInput{
Expand All @@ -367,30 +366,7 @@ func (a *API) generateAccessToken(r *http.Request, tx *storage.Connection, user
gotrueClaims = jwt.MapClaims(output.Claims)
}

signingJwk, err := conf.GetSigningJwk(&config.JWT)
if err != nil {
return "", 0, err
}

signingMethod := conf.GetSigningAlg(signingJwk)
token = jwt.NewWithClaims(signingMethod, gotrueClaims)
if token.Header == nil {
token.Header = make(map[string]interface{})
}

if _, ok := token.Header["kid"]; !ok {
if kid := signingJwk.KeyID(); kid != "" {
token.Header["kid"] = kid
}
}

// this serializes the aud claim to a string
jwt.MarshalSingleStringAsArray = false
signingKey, err := conf.GetSigningKey(signingJwk)
if err != nil {
return "", 0, err
}
signed, err := token.SignedString(signingKey)
signed, err := signJwt(&config.JWT, gotrueClaims)
if err != nil {
return "", 0, err
}
Expand Down

0 comments on commit 5f7ac2a

Please sign in to comment.