Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: #1974 changes for mirroring custom claims #2545

Merged
merged 10 commits into from
Jun 11, 2021
43 changes: 43 additions & 0 deletions docs/docs/advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,49 @@ library (e.g. [node-jwks-rsa](https://github.com/auth0/node-jwks-rsa)) to
`http://ory-hydra-public-api/.well-known/jwks.json`. All necessary keys are
available there.

#### Adding custom claims top-level to the Access Token

Assume you want to add custom claims to the access token with the following
code:

```typescript
let session: ConsentRequestSession = {
access_token: {
foo: 'bar'
}
}
```

Then part of the resulting access token will look like this:

```json
{
"ext": {
"foo": "bar"
}
}
```

If you instead want "foo" to be added top-level in the access token, you need to
set the configuration flag `oauth2.allowed_top_level_claims` like described in
[the reference Configuration](https://www.ory.sh/hydra/docs/reference/configuration).

Note: Any user defined allowed top level claim may not override standardized
access token claim names.

Configuring Hydra to allow "foo" as a top-level claim will result in the
following access token part (allowed claims get mirrored, for backwards
compatibility):

```json
{
"foo": "bar",
"ext": {
"foo": "bar"
}
}
```

### OAuth 2.0 Client Authentication with private/public keypairs

ORY Hydra supports OAuth 2.0 Client Authentication with RSA and ECDSA
Expand Down
5 changes: 5 additions & 0 deletions driver/config/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ const (
KeyExposeOAuth2Debug = "oauth2.expose_internal_errors"
KeyOAuth2LegacyErrors = "oauth2.include_legacy_error_fields"
KeyExcludeNotBeforeClaim = "oauth2.exclude_not_before_claim"
KeyAllowedTopLevelClaims = "oauth2.allowed_top_level_claims"
)

const DSNMemory = "memory"
Expand Down Expand Up @@ -130,6 +131,10 @@ func (p *Provider) IsUsingJWTAsAccessTokens() bool {
return p.AccessTokenStrategy() != "opaque"
}

func (p *Provider) AllowedTopLevelClaims() []string {
return stringslice.Unique(p.p.Strings(KeyAllowedTopLevelClaims))
}

func (p *Provider) SubjectTypesSupported() []string {
types := stringslice.Filter(
p.p.StringsF(KeySubjectTypesSupported, []string{"public"}),
Expand Down
7 changes: 4 additions & 3 deletions oauth2/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ func (h *Handler) WellKnownHandler(w http.ResponseWriter, r *http.Request) {
// 401: genericError
// 500: genericError
func (h *Handler) UserinfoHandler(w http.ResponseWriter, r *http.Request) {
session := NewSession("")
session := NewSessionWithCustomClaims("", h.c.AllowedTopLevelClaims())
tokenType, ar, err := h.r.OAuth2Provider().IntrospectToken(r.Context(), fosite.AccessTokenFromRequest(r), fosite.AccessToken, session)
if err != nil {
rfcerr := fosite.ErrorToRFC6749Error(err)
Expand Down Expand Up @@ -415,7 +415,7 @@ func (h *Handler) RevocationHandler(w http.ResponseWriter, r *http.Request) {
// 401: genericError
// 500: genericError
func (h *Handler) IntrospectHandler(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
var session = NewSession("")
var session = NewSessionWithCustomClaims("", h.c.AllowedTopLevelClaims())
var ctx = r.Context()

if r.Method != "POST" {
Expand Down Expand Up @@ -576,7 +576,7 @@ func (h *Handler) FlushHandler(w http.ResponseWriter, r *http.Request, _ httprou
// 400: genericError
// 500: genericError
func (h *Handler) TokenHandler(w http.ResponseWriter, r *http.Request) {
var session = NewSession("")
var session = NewSessionWithCustomClaims("", h.c.AllowedTopLevelClaims())
var ctx = r.Context()

accessRequest, err := h.r.OAuth2Provider().NewAccessRequest(ctx, r, session)
Expand Down Expand Up @@ -748,6 +748,7 @@ func (h *Handler) AuthHandler(w http.ResponseWriter, r *http.Request, _ httprout
ClientID: authorizeRequest.GetClient().GetID(),
ConsentChallenge: session.ID,
ExcludeNotBeforeClaim: h.c.ExcludeNotBeforeClaim(),
AllowedTopLevelClaims: h.c.AllowedTopLevelClaims(),
})
if err != nil {
x.LogError(r, err, h.r.Logger())
Expand Down
36 changes: 32 additions & 4 deletions oauth2/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ import (
"github.com/ory/fosite"
"github.com/ory/fosite/handler/openid"
"github.com/ory/fosite/token/jwt"

"github.com/ory/x/stringslice"
)

type Session struct {
Expand All @@ -37,24 +39,50 @@ type Session struct {
ClientID string
ConsentChallenge string
ExcludeNotBeforeClaim bool
AllowedTopLevelClaims []string
}

func NewSession(subject string) *Session {
return NewSessionWithCustomClaims(subject, nil)
}

func NewSessionWithCustomClaims(subject string, allowedTopLevelClaims []string) *Session {
return &Session{
DefaultSession: &openid.DefaultSession{
Claims: new(jwt.IDTokenClaims),
Headers: new(jwt.Headers),
Subject: subject,
},
Extra: map[string]interface{}{},
Extra: map[string]interface{}{},
AllowedTopLevelClaims: allowedTopLevelClaims,
}
}

func (s *Session) GetJWTClaims() jwt.JWTClaimsContainer {
//a slice of claims that are reserved and should not be overridden
var reservedClaims = []string{"iss", "sub", "aud", "exp", "nbf", "iat", "jti", "client_id", "scp", "ext"}

//remove any reserved claims from the custom claims
allowedClaimsFromConfigWithoutReserved := stringslice.Filter(s.AllowedTopLevelClaims, func(s string) bool {
return stringslice.Has(reservedClaims, s)
})

//our new extra map which will be added to the jwt
var topLevelExtraWithMirrorExt = map[string]interface{}{}

//setting every allowed claim top level in jwt with respective value
for _, allowedClaim := range allowedClaimsFromConfigWithoutReserved {
topLevelExtraWithMirrorExt[allowedClaim] = s.Extra[allowedClaim]
}

//for every other claim that was already reserved and for mirroring, add original extra under "ext"
topLevelExtraWithMirrorExt["ext"] = s.Extra

claims := &jwt.JWTClaims{
Subject: s.Subject,
Issuer: s.DefaultSession.Claims.Issuer,
Extra: map[string]interface{}{"ext": s.Extra},
Subject: s.Subject,
Issuer: s.DefaultSession.Claims.Issuer,
//set our custom extra map as claims.Extra
Extra: topLevelExtraWithMirrorExt,
ExpiresAt: s.GetExpiresAt(fosite.AccessToken),
IssuedAt: time.Now(),

Expand Down
217 changes: 217 additions & 0 deletions oauth2/session_custom_claims_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,217 @@
package oauth2_test

import (
"testing"

"github.com/ory/fosite/handler/openid"
"github.com/ory/fosite/token/jwt"

"github.com/ory/hydra/driver/config"
"github.com/ory/hydra/internal"
"github.com/ory/hydra/oauth2"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func createSessionWithCustomClaims(extra map[string]interface{}, allowedTopLevelClaims []string) oauth2.Session {
session := &oauth2.Session{
DefaultSession: &openid.DefaultSession{
Claims: &jwt.IDTokenClaims{
Subject: "alice",
Issuer: "hydra.localhost",
},
Headers: new(jwt.Headers),
Subject: "alice",
},
Extra: extra,
AllowedTopLevelClaims: allowedTopLevelClaims,
}
return *session
}

func TestCustomClaimsInSession(t *testing.T) {
fl0lli marked this conversation as resolved.
Show resolved Hide resolved
c := internal.NewConfigurationWithDefaults()

t.Run("no_custom_claims", func(t *testing.T) {
c.MustSet(config.KeyAllowedTopLevelClaims, []string{})

session := createSessionWithCustomClaims(nil, c.AllowedTopLevelClaims())
claims := session.GetJWTClaims().ToMapClaims()

assert.EqualValues(t, "alice", claims["sub"])
assert.NotEqual(t, "another-alice", claims["sub"])

require.Contains(t, claims, "iss")
assert.EqualValues(t, "hydra.localhost", claims["iss"])

assert.Empty(t, claims["ext"])
})
t.Run("custom_claim_gets_mirrored", func(t *testing.T) {
c.MustSet(config.KeyAllowedTopLevelClaims, []string{"foo"})
extra := map[string]interface{}{"foo": "bar"}

session := createSessionWithCustomClaims(extra, c.AllowedTopLevelClaims())

claims := session.GetJWTClaims().ToMapClaims()

assert.EqualValues(t, "alice", claims["sub"])
assert.NotEqual(t, "another-alice", claims["sub"])

require.Contains(t, claims, "iss")
assert.EqualValues(t, "hydra.localhost", claims["iss"])

require.Contains(t, claims, "foo")
assert.EqualValues(t, "bar", claims["foo"])

require.Contains(t, claims, "ext")
extClaims, ok := claims["ext"].(map[string]interface{})
require.True(t, ok)

require.Contains(t, extClaims, "foo")
assert.EqualValues(t, "bar", extClaims["foo"])
})
t.Run("only_non_reserved_claims_get_mirrored", func(t *testing.T) {
c.MustSet(config.KeyAllowedTopLevelClaims, []string{"foo", "iss", "sub"})
extra := map[string]interface{}{"foo": "bar", "iss": "hydra.remote", "sub": "another-alice"}

session := createSessionWithCustomClaims(extra, c.AllowedTopLevelClaims())

claims := session.GetJWTClaims().ToMapClaims()

assert.EqualValues(t, "alice", claims["sub"])
assert.NotEqual(t, "another-alice", claims["sub"])

require.Contains(t, claims, "iss")
assert.EqualValues(t, "hydra.localhost", claims["iss"])
assert.NotEqual(t, "hydra.remote", claims["iss"])

require.Contains(t, claims, "foo")
assert.EqualValues(t, "bar", claims["foo"])

require.Contains(t, claims, "ext")
extClaims, ok := claims["ext"].(map[string]interface{})
require.True(t, ok)

require.Contains(t, extClaims, "foo")
assert.EqualValues(t, "bar", extClaims["foo"])

require.Contains(t, extClaims, "iss")
assert.EqualValues(t, "hydra.remote", extClaims["iss"])

require.Contains(t, extClaims, "sub")
assert.EqualValues(t, "another-alice", extClaims["sub"])
})
t.Run("no_custom_claims_in_config", func(t *testing.T) {
c.MustSet(config.KeyAllowedTopLevelClaims, []string{})
extra := map[string]interface{}{"foo": "bar", "iss": "hydra.remote", "sub": "another-alice"}

session := createSessionWithCustomClaims(extra, c.AllowedTopLevelClaims())

claims := session.GetJWTClaims().ToMapClaims()

assert.EqualValues(t, "alice", claims["sub"])
assert.NotEqual(t, "another-alice", claims["sub"])

require.Contains(t, claims, "iss")
assert.EqualValues(t, "hydra.localhost", claims["iss"])

assert.NotContains(t, claims, "foo")

require.Contains(t, claims, "ext")
extClaims, ok := claims["ext"].(map[string]interface{})
require.True(t, ok)

require.Contains(t, extClaims, "foo")
assert.EqualValues(t, "bar", extClaims["foo"])

require.Contains(t, extClaims, "sub")
assert.EqualValues(t, "another-alice", extClaims["sub"])

require.Contains(t, extClaims, "iss")
assert.EqualValues(t, "hydra.remote", extClaims["iss"])
})
t.Run("more_config_claims_than_given", func(t *testing.T) {
c.MustSet(config.KeyAllowedTopLevelClaims, []string{"foo", "baz", "bar", "iss"})
extra := map[string]interface{}{"foo": "foo_value", "sub": "another-alice"}

session := createSessionWithCustomClaims(extra, c.AllowedTopLevelClaims())

claims := session.GetJWTClaims().ToMapClaims()

assert.EqualValues(t, "alice", claims["sub"])
assert.NotEqual(t, "another-alice", claims["sub"])

require.Contains(t, claims, "iss")
assert.EqualValues(t, "hydra.localhost", claims["iss"])
assert.NotEqual(t, "hydra.remote", claims["iss"])

require.Contains(t, claims, "foo")
assert.EqualValues(t, "foo_value", claims["foo"])

require.Contains(t, claims, "ext")
extClaims, ok := claims["ext"].(map[string]interface{})
require.True(t, ok)

require.Contains(t, extClaims, "foo")
assert.EqualValues(t, "foo_value", extClaims["foo"])

require.Contains(t, extClaims, "sub")
assert.EqualValues(t, "another-alice", extClaims["sub"])
})
t.Run("less_config_claims_than_given", func(t *testing.T) {
c.MustSet(config.KeyAllowedTopLevelClaims, []string{"foo", "sub"})
extra := map[string]interface{}{"foo": "foo_value", "bar": "bar_value", "baz": "baz_value", "sub": "another-alice"}

session := createSessionWithCustomClaims(extra, c.AllowedTopLevelClaims())

claims := session.GetJWTClaims().ToMapClaims()

assert.EqualValues(t, "alice", claims["sub"])
assert.NotEqual(t, "another-alice", claims["sub"])

require.Contains(t, claims, "iss")
assert.EqualValues(t, "hydra.localhost", claims["iss"])

require.Contains(t, claims, "foo")
assert.EqualValues(t, "foo_value", claims["foo"])

assert.NotContains(t, claims, "bar")
assert.NotContains(t, claims, "baz")

require.Contains(t, claims, "ext")
extClaims, ok := claims["ext"].(map[string]interface{})
require.True(t, ok)

require.Contains(t, extClaims, "foo")
assert.EqualValues(t, "foo_value", extClaims["foo"])

require.Contains(t, extClaims, "sub")
assert.EqualValues(t, "another-alice", extClaims["sub"])
})
t.Run("config_claims_contain_reserved_claims", func(t *testing.T) {
c.MustSet(config.KeyAllowedTopLevelClaims, []string{"iss", "sub"})
extra := map[string]interface{}{"iss": "hydra.remote", "sub": "another-alice"}

session := createSessionWithCustomClaims(extra, c.AllowedTopLevelClaims())

claims := session.GetJWTClaims().ToMapClaims()

assert.EqualValues(t, "alice", claims["sub"])
assert.NotEqual(t, "another-alice", claims["sub"])

require.Contains(t, claims, "iss")
assert.EqualValues(t, "hydra.localhost", claims["iss"])
assert.NotEqualValues(t, "hydra.remote", claims["iss"])

require.Contains(t, claims, "ext")
extClaims, ok := claims["ext"].(map[string]interface{})
require.True(t, ok)

require.Contains(t, extClaims, "sub")
assert.EqualValues(t, "another-alice", extClaims["sub"])

require.Contains(t, extClaims, "iss")
assert.EqualValues(t, "hydra.remote", extClaims["iss"])
})
}
Loading