From 63402dee7604141118fead91491abe6763150f1c Mon Sep 17 00:00:00 2001 From: Flori <40140792+fl0lli@users.noreply.github.com> Date: Fri, 11 Jun 2021 12:29:08 +0200 Subject: [PATCH] feat: add custom claims to top-level JWT payload (#2545) Closes #1974 Co-authored-by: hackerman <3372410+aeneasr@users.noreply.github.com> --- docs/docs/advanced.md | 43 ++++++ driver/config/provider.go | 5 + oauth2/handler.go | 7 +- oauth2/session.go | 36 ++++- oauth2/session_custom_claims_test.go | 217 +++++++++++++++++++++++++++ spec/config.json | 14 ++ x/oauth2cors/cors.go | 2 +- 7 files changed, 316 insertions(+), 8 deletions(-) create mode 100644 oauth2/session_custom_claims_test.go diff --git a/docs/docs/advanced.md b/docs/docs/advanced.md index 47084302db1..5cc7c04fe2e 100644 --- a/docs/docs/advanced.md +++ b/docs/docs/advanced.md @@ -141,6 +141,49 @@ library (e.g. [node-jwks-rsa](https://github.com/auth0/node-jwks-rsa)) to `http://ory-hydra-public-api/.well-known/jwks.json`. All necessary keys are available there. +#### Adding custom claims top-level to the Access Token + +Assume you want to add custom claims to the access token with the following +code: + +```typescript +let session: ConsentRequestSession = { + access_token: { + foo: 'bar' + } +} +``` + +Then part of the resulting access token will look like this: + +```json +{ + "ext": { + "foo": "bar" + } +} +``` + +If you instead want "foo" to be added top-level in the access token, you need to +set the configuration flag `oauth2.allowed_top_level_claims` like described in +[the reference Configuration](https://www.ory.sh/hydra/docs/reference/configuration). + +Note: Any user defined allowed top level claim may not override standardized +access token claim names. + +Configuring Hydra to allow "foo" as a top-level claim will result in the +following access token part (allowed claims get mirrored, for backwards +compatibility): + +```json +{ + "foo": "bar", + "ext": { + "foo": "bar" + } +} +``` + ### OAuth 2.0 Client Authentication with private/public keypairs ORY Hydra supports OAuth 2.0 Client Authentication with RSA and ECDSA diff --git a/driver/config/provider.go b/driver/config/provider.go index 56cbb88b7ce..e8bff7ef09e 100644 --- a/driver/config/provider.go +++ b/driver/config/provider.go @@ -64,6 +64,7 @@ const ( KeyExposeOAuth2Debug = "oauth2.expose_internal_errors" KeyOAuth2LegacyErrors = "oauth2.include_legacy_error_fields" KeyExcludeNotBeforeClaim = "oauth2.exclude_not_before_claim" + KeyAllowedTopLevelClaims = "oauth2.allowed_top_level_claims" ) const DSNMemory = "memory" @@ -130,6 +131,10 @@ func (p *Provider) IsUsingJWTAsAccessTokens() bool { return p.AccessTokenStrategy() != "opaque" } +func (p *Provider) AllowedTopLevelClaims() []string { + return stringslice.Unique(p.p.Strings(KeyAllowedTopLevelClaims)) +} + func (p *Provider) SubjectTypesSupported() []string { types := stringslice.Filter( p.p.StringsF(KeySubjectTypesSupported, []string{"public"}), diff --git a/oauth2/handler.go b/oauth2/handler.go index 512a18fe01a..8b9fd522e15 100644 --- a/oauth2/handler.go +++ b/oauth2/handler.go @@ -282,7 +282,7 @@ func (h *Handler) WellKnownHandler(w http.ResponseWriter, r *http.Request) { // 401: genericError // 500: genericError func (h *Handler) UserinfoHandler(w http.ResponseWriter, r *http.Request) { - session := NewSession("") + session := NewSessionWithCustomClaims("", h.c.AllowedTopLevelClaims()) tokenType, ar, err := h.r.OAuth2Provider().IntrospectToken(r.Context(), fosite.AccessTokenFromRequest(r), fosite.AccessToken, session) if err != nil { rfcerr := fosite.ErrorToRFC6749Error(err) @@ -415,7 +415,7 @@ func (h *Handler) RevocationHandler(w http.ResponseWriter, r *http.Request) { // 401: genericError // 500: genericError func (h *Handler) IntrospectHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) { - var session = NewSession("") + var session = NewSessionWithCustomClaims("", h.c.AllowedTopLevelClaims()) var ctx = r.Context() if r.Method != "POST" { @@ -576,7 +576,7 @@ func (h *Handler) FlushHandler(w http.ResponseWriter, r *http.Request, _ httprou // 400: genericError // 500: genericError func (h *Handler) TokenHandler(w http.ResponseWriter, r *http.Request) { - var session = NewSession("") + var session = NewSessionWithCustomClaims("", h.c.AllowedTopLevelClaims()) var ctx = r.Context() accessRequest, err := h.r.OAuth2Provider().NewAccessRequest(ctx, r, session) @@ -748,6 +748,7 @@ func (h *Handler) AuthHandler(w http.ResponseWriter, r *http.Request, _ httprout ClientID: authorizeRequest.GetClient().GetID(), ConsentChallenge: session.ID, ExcludeNotBeforeClaim: h.c.ExcludeNotBeforeClaim(), + AllowedTopLevelClaims: h.c.AllowedTopLevelClaims(), }) if err != nil { x.LogError(r, err, h.r.Logger()) diff --git a/oauth2/session.go b/oauth2/session.go index 2cae109d747..5ce4c03f1a1 100644 --- a/oauth2/session.go +++ b/oauth2/session.go @@ -28,6 +28,8 @@ import ( "github.com/ory/fosite" "github.com/ory/fosite/handler/openid" "github.com/ory/fosite/token/jwt" + + "github.com/ory/x/stringslice" ) type Session struct { @@ -37,24 +39,50 @@ type Session struct { ClientID string ConsentChallenge string ExcludeNotBeforeClaim bool + AllowedTopLevelClaims []string } func NewSession(subject string) *Session { + return NewSessionWithCustomClaims(subject, nil) +} + +func NewSessionWithCustomClaims(subject string, allowedTopLevelClaims []string) *Session { return &Session{ DefaultSession: &openid.DefaultSession{ Claims: new(jwt.IDTokenClaims), Headers: new(jwt.Headers), Subject: subject, }, - Extra: map[string]interface{}{}, + Extra: map[string]interface{}{}, + AllowedTopLevelClaims: allowedTopLevelClaims, } } func (s *Session) GetJWTClaims() jwt.JWTClaimsContainer { + //a slice of claims that are reserved and should not be overridden + var reservedClaims = []string{"iss", "sub", "aud", "exp", "nbf", "iat", "jti", "client_id", "scp", "ext"} + + //remove any reserved claims from the custom claims + allowedClaimsFromConfigWithoutReserved := stringslice.Filter(s.AllowedTopLevelClaims, func(s string) bool { + return stringslice.Has(reservedClaims, s) + }) + + //our new extra map which will be added to the jwt + var topLevelExtraWithMirrorExt = map[string]interface{}{} + + //setting every allowed claim top level in jwt with respective value + for _, allowedClaim := range allowedClaimsFromConfigWithoutReserved { + topLevelExtraWithMirrorExt[allowedClaim] = s.Extra[allowedClaim] + } + + //for every other claim that was already reserved and for mirroring, add original extra under "ext" + topLevelExtraWithMirrorExt["ext"] = s.Extra + claims := &jwt.JWTClaims{ - Subject: s.Subject, - Issuer: s.DefaultSession.Claims.Issuer, - Extra: map[string]interface{}{"ext": s.Extra}, + Subject: s.Subject, + Issuer: s.DefaultSession.Claims.Issuer, + //set our custom extra map as claims.Extra + Extra: topLevelExtraWithMirrorExt, ExpiresAt: s.GetExpiresAt(fosite.AccessToken), IssuedAt: time.Now(), diff --git a/oauth2/session_custom_claims_test.go b/oauth2/session_custom_claims_test.go new file mode 100644 index 00000000000..0c353593395 --- /dev/null +++ b/oauth2/session_custom_claims_test.go @@ -0,0 +1,217 @@ +package oauth2_test + +import ( + "testing" + + "github.com/ory/fosite/handler/openid" + "github.com/ory/fosite/token/jwt" + + "github.com/ory/hydra/driver/config" + "github.com/ory/hydra/internal" + "github.com/ory/hydra/oauth2" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func createSessionWithCustomClaims(extra map[string]interface{}, allowedTopLevelClaims []string) oauth2.Session { + session := &oauth2.Session{ + DefaultSession: &openid.DefaultSession{ + Claims: &jwt.IDTokenClaims{ + Subject: "alice", + Issuer: "hydra.localhost", + }, + Headers: new(jwt.Headers), + Subject: "alice", + }, + Extra: extra, + AllowedTopLevelClaims: allowedTopLevelClaims, + } + return *session +} + +func TestCustomClaimsInSession(t *testing.T) { + c := internal.NewConfigurationWithDefaults() + + t.Run("no_custom_claims", func(t *testing.T) { + c.MustSet(config.KeyAllowedTopLevelClaims, []string{}) + + session := createSessionWithCustomClaims(nil, c.AllowedTopLevelClaims()) + claims := session.GetJWTClaims().ToMapClaims() + + assert.EqualValues(t, "alice", claims["sub"]) + assert.NotEqual(t, "another-alice", claims["sub"]) + + require.Contains(t, claims, "iss") + assert.EqualValues(t, "hydra.localhost", claims["iss"]) + + assert.Empty(t, claims["ext"]) + }) + t.Run("custom_claim_gets_mirrored", func(t *testing.T) { + c.MustSet(config.KeyAllowedTopLevelClaims, []string{"foo"}) + extra := map[string]interface{}{"foo": "bar"} + + session := createSessionWithCustomClaims(extra, c.AllowedTopLevelClaims()) + + claims := session.GetJWTClaims().ToMapClaims() + + assert.EqualValues(t, "alice", claims["sub"]) + assert.NotEqual(t, "another-alice", claims["sub"]) + + require.Contains(t, claims, "iss") + assert.EqualValues(t, "hydra.localhost", claims["iss"]) + + require.Contains(t, claims, "foo") + assert.EqualValues(t, "bar", claims["foo"]) + + require.Contains(t, claims, "ext") + extClaims, ok := claims["ext"].(map[string]interface{}) + require.True(t, ok) + + require.Contains(t, extClaims, "foo") + assert.EqualValues(t, "bar", extClaims["foo"]) + }) + t.Run("only_non_reserved_claims_get_mirrored", func(t *testing.T) { + c.MustSet(config.KeyAllowedTopLevelClaims, []string{"foo", "iss", "sub"}) + extra := map[string]interface{}{"foo": "bar", "iss": "hydra.remote", "sub": "another-alice"} + + session := createSessionWithCustomClaims(extra, c.AllowedTopLevelClaims()) + + claims := session.GetJWTClaims().ToMapClaims() + + assert.EqualValues(t, "alice", claims["sub"]) + assert.NotEqual(t, "another-alice", claims["sub"]) + + require.Contains(t, claims, "iss") + assert.EqualValues(t, "hydra.localhost", claims["iss"]) + assert.NotEqual(t, "hydra.remote", claims["iss"]) + + require.Contains(t, claims, "foo") + assert.EqualValues(t, "bar", claims["foo"]) + + require.Contains(t, claims, "ext") + extClaims, ok := claims["ext"].(map[string]interface{}) + require.True(t, ok) + + require.Contains(t, extClaims, "foo") + assert.EqualValues(t, "bar", extClaims["foo"]) + + require.Contains(t, extClaims, "iss") + assert.EqualValues(t, "hydra.remote", extClaims["iss"]) + + require.Contains(t, extClaims, "sub") + assert.EqualValues(t, "another-alice", extClaims["sub"]) + }) + t.Run("no_custom_claims_in_config", func(t *testing.T) { + c.MustSet(config.KeyAllowedTopLevelClaims, []string{}) + extra := map[string]interface{}{"foo": "bar", "iss": "hydra.remote", "sub": "another-alice"} + + session := createSessionWithCustomClaims(extra, c.AllowedTopLevelClaims()) + + claims := session.GetJWTClaims().ToMapClaims() + + assert.EqualValues(t, "alice", claims["sub"]) + assert.NotEqual(t, "another-alice", claims["sub"]) + + require.Contains(t, claims, "iss") + assert.EqualValues(t, "hydra.localhost", claims["iss"]) + + assert.NotContains(t, claims, "foo") + + require.Contains(t, claims, "ext") + extClaims, ok := claims["ext"].(map[string]interface{}) + require.True(t, ok) + + require.Contains(t, extClaims, "foo") + assert.EqualValues(t, "bar", extClaims["foo"]) + + require.Contains(t, extClaims, "sub") + assert.EqualValues(t, "another-alice", extClaims["sub"]) + + require.Contains(t, extClaims, "iss") + assert.EqualValues(t, "hydra.remote", extClaims["iss"]) + }) + t.Run("more_config_claims_than_given", func(t *testing.T) { + c.MustSet(config.KeyAllowedTopLevelClaims, []string{"foo", "baz", "bar", "iss"}) + extra := map[string]interface{}{"foo": "foo_value", "sub": "another-alice"} + + session := createSessionWithCustomClaims(extra, c.AllowedTopLevelClaims()) + + claims := session.GetJWTClaims().ToMapClaims() + + assert.EqualValues(t, "alice", claims["sub"]) + assert.NotEqual(t, "another-alice", claims["sub"]) + + require.Contains(t, claims, "iss") + assert.EqualValues(t, "hydra.localhost", claims["iss"]) + assert.NotEqual(t, "hydra.remote", claims["iss"]) + + require.Contains(t, claims, "foo") + assert.EqualValues(t, "foo_value", claims["foo"]) + + require.Contains(t, claims, "ext") + extClaims, ok := claims["ext"].(map[string]interface{}) + require.True(t, ok) + + require.Contains(t, extClaims, "foo") + assert.EqualValues(t, "foo_value", extClaims["foo"]) + + require.Contains(t, extClaims, "sub") + assert.EqualValues(t, "another-alice", extClaims["sub"]) + }) + t.Run("less_config_claims_than_given", func(t *testing.T) { + c.MustSet(config.KeyAllowedTopLevelClaims, []string{"foo", "sub"}) + extra := map[string]interface{}{"foo": "foo_value", "bar": "bar_value", "baz": "baz_value", "sub": "another-alice"} + + session := createSessionWithCustomClaims(extra, c.AllowedTopLevelClaims()) + + claims := session.GetJWTClaims().ToMapClaims() + + assert.EqualValues(t, "alice", claims["sub"]) + assert.NotEqual(t, "another-alice", claims["sub"]) + + require.Contains(t, claims, "iss") + assert.EqualValues(t, "hydra.localhost", claims["iss"]) + + require.Contains(t, claims, "foo") + assert.EqualValues(t, "foo_value", claims["foo"]) + + assert.NotContains(t, claims, "bar") + assert.NotContains(t, claims, "baz") + + require.Contains(t, claims, "ext") + extClaims, ok := claims["ext"].(map[string]interface{}) + require.True(t, ok) + + require.Contains(t, extClaims, "foo") + assert.EqualValues(t, "foo_value", extClaims["foo"]) + + require.Contains(t, extClaims, "sub") + assert.EqualValues(t, "another-alice", extClaims["sub"]) + }) + t.Run("config_claims_contain_reserved_claims", func(t *testing.T) { + c.MustSet(config.KeyAllowedTopLevelClaims, []string{"iss", "sub"}) + extra := map[string]interface{}{"iss": "hydra.remote", "sub": "another-alice"} + + session := createSessionWithCustomClaims(extra, c.AllowedTopLevelClaims()) + + claims := session.GetJWTClaims().ToMapClaims() + + assert.EqualValues(t, "alice", claims["sub"]) + assert.NotEqual(t, "another-alice", claims["sub"]) + + require.Contains(t, claims, "iss") + assert.EqualValues(t, "hydra.localhost", claims["iss"]) + assert.NotEqualValues(t, "hydra.remote", claims["iss"]) + + require.Contains(t, claims, "ext") + extClaims, ok := claims["ext"].(map[string]interface{}) + require.True(t, ok) + + require.Contains(t, extClaims, "sub") + assert.EqualValues(t, "another-alice", extClaims["sub"]) + + require.Contains(t, extClaims, "iss") + assert.EqualValues(t, "hydra.remote", extClaims["iss"]) + }) +} diff --git a/spec/config.json b/spec/config.json index c4086d028cb..a9dbfc331e5 100644 --- a/spec/config.json +++ b/spec/config.json @@ -800,6 +800,20 @@ true ] }, + "allowed_top_level_claims": { + "type": "array", + "description": "A list of custom claims which are allowed to be added top level to the Access Token. They cannot override reserved claims.", + "items": { + "type": "string" + }, + "examples": [ + [ + "username", + "email", + "user_uuid" + ] + ] + }, "hashers": { "type": "object", "additionalProperties": false, diff --git a/x/oauth2cors/cors.go b/x/oauth2cors/cors.go index 4e76feaca8c..efde2b8cfb2 100644 --- a/x/oauth2cors/cors.go +++ b/x/oauth2cors/cors.go @@ -98,7 +98,7 @@ func Middleware(reg interface { return false } - session := oauth2.NewSession("") + session := oauth2.NewSessionWithCustomClaims("", reg.Config().AllowedTopLevelClaims()) _, ar, err := reg.OAuth2Provider().IntrospectToken(context.Background(), token, fosite.AccessToken, session) if err != nil { return false