Skip to content

Commit

Permalink
feat: add custom claims to top-level JWT payload (#2545)
Browse files Browse the repository at this point in the history
Closes #1974

Co-authored-by: hackerman <3372410+aeneasr@users.noreply.github.com>
  • Loading branch information
fl0lli and aeneasr authored Jun 11, 2021
1 parent f74fe90 commit 63402de
Show file tree
Hide file tree
Showing 7 changed files with 316 additions and 8 deletions.
43 changes: 43 additions & 0 deletions docs/docs/advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,49 @@ library (e.g. [node-jwks-rsa](https://github.com/auth0/node-jwks-rsa)) to
`http://ory-hydra-public-api/.well-known/jwks.json`. All necessary keys are
available there.

#### Adding custom claims top-level to the Access Token

Assume you want to add custom claims to the access token with the following
code:

```typescript
let session: ConsentRequestSession = {
access_token: {
foo: 'bar'
}
}
```

Then part of the resulting access token will look like this:

```json
{
"ext": {
"foo": "bar"
}
}
```

If you instead want "foo" to be added top-level in the access token, you need to
set the configuration flag `oauth2.allowed_top_level_claims` like described in
[the reference Configuration](https://www.ory.sh/hydra/docs/reference/configuration).

Note: Any user defined allowed top level claim may not override standardized
access token claim names.

Configuring Hydra to allow "foo" as a top-level claim will result in the
following access token part (allowed claims get mirrored, for backwards
compatibility):

```json
{
"foo": "bar",
"ext": {
"foo": "bar"
}
}
```

### OAuth 2.0 Client Authentication with private/public keypairs

ORY Hydra supports OAuth 2.0 Client Authentication with RSA and ECDSA
Expand Down
5 changes: 5 additions & 0 deletions driver/config/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ const (
KeyExposeOAuth2Debug = "oauth2.expose_internal_errors"
KeyOAuth2LegacyErrors = "oauth2.include_legacy_error_fields"
KeyExcludeNotBeforeClaim = "oauth2.exclude_not_before_claim"
KeyAllowedTopLevelClaims = "oauth2.allowed_top_level_claims"
)

const DSNMemory = "memory"
Expand Down Expand Up @@ -130,6 +131,10 @@ func (p *Provider) IsUsingJWTAsAccessTokens() bool {
return p.AccessTokenStrategy() != "opaque"
}

func (p *Provider) AllowedTopLevelClaims() []string {
return stringslice.Unique(p.p.Strings(KeyAllowedTopLevelClaims))
}

func (p *Provider) SubjectTypesSupported() []string {
types := stringslice.Filter(
p.p.StringsF(KeySubjectTypesSupported, []string{"public"}),
Expand Down
7 changes: 4 additions & 3 deletions oauth2/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ func (h *Handler) WellKnownHandler(w http.ResponseWriter, r *http.Request) {
// 401: genericError
// 500: genericError
func (h *Handler) UserinfoHandler(w http.ResponseWriter, r *http.Request) {
session := NewSession("")
session := NewSessionWithCustomClaims("", h.c.AllowedTopLevelClaims())
tokenType, ar, err := h.r.OAuth2Provider().IntrospectToken(r.Context(), fosite.AccessTokenFromRequest(r), fosite.AccessToken, session)
if err != nil {
rfcerr := fosite.ErrorToRFC6749Error(err)
Expand Down Expand Up @@ -415,7 +415,7 @@ func (h *Handler) RevocationHandler(w http.ResponseWriter, r *http.Request) {
// 401: genericError
// 500: genericError
func (h *Handler) IntrospectHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
var session = NewSession("")
var session = NewSessionWithCustomClaims("", h.c.AllowedTopLevelClaims())
var ctx = r.Context()

if r.Method != "POST" {
Expand Down Expand Up @@ -576,7 +576,7 @@ func (h *Handler) FlushHandler(w http.ResponseWriter, r *http.Request, _ httprou
// 400: genericError
// 500: genericError
func (h *Handler) TokenHandler(w http.ResponseWriter, r *http.Request) {
var session = NewSession("")
var session = NewSessionWithCustomClaims("", h.c.AllowedTopLevelClaims())
var ctx = r.Context()

accessRequest, err := h.r.OAuth2Provider().NewAccessRequest(ctx, r, session)
Expand Down Expand Up @@ -748,6 +748,7 @@ func (h *Handler) AuthHandler(w http.ResponseWriter, r *http.Request, _ httprout
ClientID: authorizeRequest.GetClient().GetID(),
ConsentChallenge: session.ID,
ExcludeNotBeforeClaim: h.c.ExcludeNotBeforeClaim(),
AllowedTopLevelClaims: h.c.AllowedTopLevelClaims(),
})
if err != nil {
x.LogError(r, err, h.r.Logger())
Expand Down
36 changes: 32 additions & 4 deletions oauth2/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import (
"github.com/ory/fosite"
"github.com/ory/fosite/handler/openid"
"github.com/ory/fosite/token/jwt"

"github.com/ory/x/stringslice"
)

type Session struct {
Expand All @@ -37,24 +39,50 @@ type Session struct {
ClientID string
ConsentChallenge string
ExcludeNotBeforeClaim bool
AllowedTopLevelClaims []string
}

func NewSession(subject string) *Session {
return NewSessionWithCustomClaims(subject, nil)
}

func NewSessionWithCustomClaims(subject string, allowedTopLevelClaims []string) *Session {
return &Session{
DefaultSession: &openid.DefaultSession{
Claims: new(jwt.IDTokenClaims),
Headers: new(jwt.Headers),
Subject: subject,
},
Extra: map[string]interface{}{},
Extra: map[string]interface{}{},
AllowedTopLevelClaims: allowedTopLevelClaims,
}
}

func (s *Session) GetJWTClaims() jwt.JWTClaimsContainer {
//a slice of claims that are reserved and should not be overridden
var reservedClaims = []string{"iss", "sub", "aud", "exp", "nbf", "iat", "jti", "client_id", "scp", "ext"}

//remove any reserved claims from the custom claims
allowedClaimsFromConfigWithoutReserved := stringslice.Filter(s.AllowedTopLevelClaims, func(s string) bool {
return stringslice.Has(reservedClaims, s)
})

//our new extra map which will be added to the jwt
var topLevelExtraWithMirrorExt = map[string]interface{}{}

//setting every allowed claim top level in jwt with respective value
for _, allowedClaim := range allowedClaimsFromConfigWithoutReserved {
topLevelExtraWithMirrorExt[allowedClaim] = s.Extra[allowedClaim]
}

//for every other claim that was already reserved and for mirroring, add original extra under "ext"
topLevelExtraWithMirrorExt["ext"] = s.Extra

claims := &jwt.JWTClaims{
Subject: s.Subject,
Issuer: s.DefaultSession.Claims.Issuer,
Extra: map[string]interface{}{"ext": s.Extra},
Subject: s.Subject,
Issuer: s.DefaultSession.Claims.Issuer,
//set our custom extra map as claims.Extra
Extra: topLevelExtraWithMirrorExt,
ExpiresAt: s.GetExpiresAt(fosite.AccessToken),
IssuedAt: time.Now(),

Expand Down
217 changes: 217 additions & 0 deletions oauth2/session_custom_claims_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
package oauth2_test

import (
"testing"

"github.com/ory/fosite/handler/openid"
"github.com/ory/fosite/token/jwt"

"github.com/ory/hydra/driver/config"
"github.com/ory/hydra/internal"
"github.com/ory/hydra/oauth2"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func createSessionWithCustomClaims(extra map[string]interface{}, allowedTopLevelClaims []string) oauth2.Session {
session := &oauth2.Session{
DefaultSession: &openid.DefaultSession{
Claims: &jwt.IDTokenClaims{
Subject: "alice",
Issuer: "hydra.localhost",
},
Headers: new(jwt.Headers),
Subject: "alice",
},
Extra: extra,
AllowedTopLevelClaims: allowedTopLevelClaims,
}
return *session
}

func TestCustomClaimsInSession(t *testing.T) {
c := internal.NewConfigurationWithDefaults()

t.Run("no_custom_claims", func(t *testing.T) {
c.MustSet(config.KeyAllowedTopLevelClaims, []string{})

session := createSessionWithCustomClaims(nil, c.AllowedTopLevelClaims())
claims := session.GetJWTClaims().ToMapClaims()

assert.EqualValues(t, "alice", claims["sub"])
assert.NotEqual(t, "another-alice", claims["sub"])

require.Contains(t, claims, "iss")
assert.EqualValues(t, "hydra.localhost", claims["iss"])

assert.Empty(t, claims["ext"])
})
t.Run("custom_claim_gets_mirrored", func(t *testing.T) {
c.MustSet(config.KeyAllowedTopLevelClaims, []string{"foo"})
extra := map[string]interface{}{"foo": "bar"}

session := createSessionWithCustomClaims(extra, c.AllowedTopLevelClaims())

claims := session.GetJWTClaims().ToMapClaims()

assert.EqualValues(t, "alice", claims["sub"])
assert.NotEqual(t, "another-alice", claims["sub"])

require.Contains(t, claims, "iss")
assert.EqualValues(t, "hydra.localhost", claims["iss"])

require.Contains(t, claims, "foo")
assert.EqualValues(t, "bar", claims["foo"])

require.Contains(t, claims, "ext")
extClaims, ok := claims["ext"].(map[string]interface{})
require.True(t, ok)

require.Contains(t, extClaims, "foo")
assert.EqualValues(t, "bar", extClaims["foo"])
})
t.Run("only_non_reserved_claims_get_mirrored", func(t *testing.T) {
c.MustSet(config.KeyAllowedTopLevelClaims, []string{"foo", "iss", "sub"})
extra := map[string]interface{}{"foo": "bar", "iss": "hydra.remote", "sub": "another-alice"}

session := createSessionWithCustomClaims(extra, c.AllowedTopLevelClaims())

claims := session.GetJWTClaims().ToMapClaims()

assert.EqualValues(t, "alice", claims["sub"])
assert.NotEqual(t, "another-alice", claims["sub"])

require.Contains(t, claims, "iss")
assert.EqualValues(t, "hydra.localhost", claims["iss"])
assert.NotEqual(t, "hydra.remote", claims["iss"])

require.Contains(t, claims, "foo")
assert.EqualValues(t, "bar", claims["foo"])

require.Contains(t, claims, "ext")
extClaims, ok := claims["ext"].(map[string]interface{})
require.True(t, ok)

require.Contains(t, extClaims, "foo")
assert.EqualValues(t, "bar", extClaims["foo"])

require.Contains(t, extClaims, "iss")
assert.EqualValues(t, "hydra.remote", extClaims["iss"])

require.Contains(t, extClaims, "sub")
assert.EqualValues(t, "another-alice", extClaims["sub"])
})
t.Run("no_custom_claims_in_config", func(t *testing.T) {
c.MustSet(config.KeyAllowedTopLevelClaims, []string{})
extra := map[string]interface{}{"foo": "bar", "iss": "hydra.remote", "sub": "another-alice"}

session := createSessionWithCustomClaims(extra, c.AllowedTopLevelClaims())

claims := session.GetJWTClaims().ToMapClaims()

assert.EqualValues(t, "alice", claims["sub"])
assert.NotEqual(t, "another-alice", claims["sub"])

require.Contains(t, claims, "iss")
assert.EqualValues(t, "hydra.localhost", claims["iss"])

assert.NotContains(t, claims, "foo")

require.Contains(t, claims, "ext")
extClaims, ok := claims["ext"].(map[string]interface{})
require.True(t, ok)

require.Contains(t, extClaims, "foo")
assert.EqualValues(t, "bar", extClaims["foo"])

require.Contains(t, extClaims, "sub")
assert.EqualValues(t, "another-alice", extClaims["sub"])

require.Contains(t, extClaims, "iss")
assert.EqualValues(t, "hydra.remote", extClaims["iss"])
})
t.Run("more_config_claims_than_given", func(t *testing.T) {
c.MustSet(config.KeyAllowedTopLevelClaims, []string{"foo", "baz", "bar", "iss"})
extra := map[string]interface{}{"foo": "foo_value", "sub": "another-alice"}

session := createSessionWithCustomClaims(extra, c.AllowedTopLevelClaims())

claims := session.GetJWTClaims().ToMapClaims()

assert.EqualValues(t, "alice", claims["sub"])
assert.NotEqual(t, "another-alice", claims["sub"])

require.Contains(t, claims, "iss")
assert.EqualValues(t, "hydra.localhost", claims["iss"])
assert.NotEqual(t, "hydra.remote", claims["iss"])

require.Contains(t, claims, "foo")
assert.EqualValues(t, "foo_value", claims["foo"])

require.Contains(t, claims, "ext")
extClaims, ok := claims["ext"].(map[string]interface{})
require.True(t, ok)

require.Contains(t, extClaims, "foo")
assert.EqualValues(t, "foo_value", extClaims["foo"])

require.Contains(t, extClaims, "sub")
assert.EqualValues(t, "another-alice", extClaims["sub"])
})
t.Run("less_config_claims_than_given", func(t *testing.T) {
c.MustSet(config.KeyAllowedTopLevelClaims, []string{"foo", "sub"})
extra := map[string]interface{}{"foo": "foo_value", "bar": "bar_value", "baz": "baz_value", "sub": "another-alice"}

session := createSessionWithCustomClaims(extra, c.AllowedTopLevelClaims())

claims := session.GetJWTClaims().ToMapClaims()

assert.EqualValues(t, "alice", claims["sub"])
assert.NotEqual(t, "another-alice", claims["sub"])

require.Contains(t, claims, "iss")
assert.EqualValues(t, "hydra.localhost", claims["iss"])

require.Contains(t, claims, "foo")
assert.EqualValues(t, "foo_value", claims["foo"])

assert.NotContains(t, claims, "bar")
assert.NotContains(t, claims, "baz")

require.Contains(t, claims, "ext")
extClaims, ok := claims["ext"].(map[string]interface{})
require.True(t, ok)

require.Contains(t, extClaims, "foo")
assert.EqualValues(t, "foo_value", extClaims["foo"])

require.Contains(t, extClaims, "sub")
assert.EqualValues(t, "another-alice", extClaims["sub"])
})
t.Run("config_claims_contain_reserved_claims", func(t *testing.T) {
c.MustSet(config.KeyAllowedTopLevelClaims, []string{"iss", "sub"})
extra := map[string]interface{}{"iss": "hydra.remote", "sub": "another-alice"}

session := createSessionWithCustomClaims(extra, c.AllowedTopLevelClaims())

claims := session.GetJWTClaims().ToMapClaims()

assert.EqualValues(t, "alice", claims["sub"])
assert.NotEqual(t, "another-alice", claims["sub"])

require.Contains(t, claims, "iss")
assert.EqualValues(t, "hydra.localhost", claims["iss"])
assert.NotEqualValues(t, "hydra.remote", claims["iss"])

require.Contains(t, claims, "ext")
extClaims, ok := claims["ext"].(map[string]interface{})
require.True(t, ok)

require.Contains(t, extClaims, "sub")
assert.EqualValues(t, "another-alice", extClaims["sub"])

require.Contains(t, extClaims, "iss")
assert.EqualValues(t, "hydra.remote", extClaims["iss"])
})
}
Loading

0 comments on commit 63402de

Please sign in to comment.