Skip to content

Commit

Permalink
feat(openid): harden id_token validation
Browse files Browse the repository at this point in the history
  • Loading branch information
tronghn committed Aug 15, 2023
1 parent f8d6633 commit c6dad59
Show file tree
Hide file tree
Showing 6 changed files with 298 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ The following flags are available:
| `log-level` | string | Logging verbosity level. | `info` |
| `metrics-bind-address` | string | Listen address for metrics only. | `127.0.0.1:3001` |
| `openid.acr-values` | string | Space separated string that configures the default security level (`acr_values`) parameter for authorization requests. | |
| `openid.audiences` | strings | List of additional trusted audiences (other than the client_id) for OpenID Connect id_token validation. | |
| `openid.client-id` | string | Client ID for the OpenID client. | |
| `openid.client-jwk` | string | JWK containing the private key for the OpenID client in string format. | |
| `openid.post-logout-redirect-uri` | string | URI for redirecting the user after successful logout at the Identity Provider. | |
Expand Down
3 changes: 3 additions & 0 deletions pkg/config/openid.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (

const (
OpenIDACRValues = "openid.acr-values"
OpenIDAudiences = "openid.audiences"
OpenIDClientID = "openid.client-id"
OpenIDClientJWK = "openid.client-jwk"
OpenIDPostLogoutRedirectURI = "openid.post-logout-redirect-uri"
Expand All @@ -18,6 +19,7 @@ const (

type OpenID struct {
ACRValues string `json:"acr-values"`
Audiences []string `json:"audiences"`
ClientID string `json:"client-id"`
ClientJWK string `json:"client-jwk"`
PostLogoutRedirectURI string `json:"post-logout-redirect-uri"`
Expand All @@ -38,6 +40,7 @@ const (

func openIDFlags() {
flag.String(OpenIDACRValues, "", "Space separated string that configures the default security level (acr_values) parameter for authorization requests.")
flag.StringSlice(OpenIDAudiences, []string{}, "List of additional trusted audiences (other than the client_id) for OpenID Connect id_token validation.")
flag.String(OpenIDClientID, "", "Client ID for the OpenID client.")
flag.String(OpenIDClientJWK, "", "JWK containing the private key for the OpenID client in string format.")
flag.String(OpenIDPostLogoutRedirectURI, "", "URI for redirecting the user after successful logout at the Identity Provider.")
Expand Down
4 changes: 4 additions & 0 deletions pkg/mock/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ func (c *TestClientConfiguration) ACRValues() string {
return c.Config.OpenID.ACRValues
}

func (c *TestClientConfiguration) Audiences() []string {
return c.Config.OpenID.Audiences
}

func (c *TestClientConfiguration) ClientID() string {
return c.Config.OpenID.ClientID
}
Expand Down
5 changes: 5 additions & 0 deletions pkg/openid/config/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

type Client interface {
ACRValues() string
Audiences() []string
ClientID() string
ClientJWK() jwk.Key
PostLogoutRedirectURI() string
Expand All @@ -32,6 +33,10 @@ func (in *client) ACRValues() string {
return in.OpenID.ACRValues
}

func (in *client) Audiences() []string {
return in.OpenID.Audiences
}

func (in *client) ClientID() string {
return in.OpenID.ClientID
}
Expand Down
41 changes: 40 additions & 1 deletion pkg/openid/tokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package openid

import (
"fmt"
"slices"
"strings"
"time"

Expand Down Expand Up @@ -62,8 +63,20 @@ func (in *IDToken) Validate(cfg openidconfig.Config, cookie *LoginCookie) error
clientConfig := cfg.Client()

opts := []jwtlib.ValidateOption{
jwtlib.WithRequiredClaim("iss"),
jwtlib.WithRequiredClaim("sub"),
jwtlib.WithRequiredClaim("aud"),
jwtlib.WithRequiredClaim("exp"),
jwtlib.WithRequiredClaim("iat"),
// OpenID Connect Core 3.1.3.7, step 3.
// The Client MUST validate that the `aud` (audience) Claim contains its `client_id` value registered at the Issuer identified by the `iss` (issuer) Claim as an audience.
// The ID Token MUST be rejected if the ID Token does not list the Client as a valid audience
jwtlib.WithAudience(clientConfig.ClientID()),
// OpenID Connect Core 3.1.3.7, step 11.
// If a nonce value was sent in the Authentication Request, a `nonce` Claim MUST be present and its value checked to verify that it is the same value as the one that was sent in the Authentication Request.
jwtlib.WithClaimValue("nonce", cookie.Nonce),
// OpenID Connect Core 3.1.3.7, step 2.
// The Issuer Identifier for the OpenID Provider (which is typically obtained during Discovery) MUST exactly match the value of the `iss` (issuer) Claim.
jwtlib.WithIssuer(openIDconfig.Issuer()),
jwtlib.WithAcceptableSkew(jwt.AcceptableClockSkew),
}
Expand All @@ -72,6 +85,8 @@ func (in *IDToken) Validate(cfg openidconfig.Config, cookie *LoginCookie) error
opts = append(opts, jwtlib.WithRequiredClaim(jwt.SidClaim))
}

// OpenID Connect Core 3.1.3.7, step 12.
// If the `acr` Claim was requested, the Client SHOULD check that the asserted Claim Value is appropriate.
if len(clientConfig.ACRValues()) > 0 {
opts = append(opts, jwtlib.WithRequiredClaim(jwt.AcrClaim))

Expand All @@ -86,7 +101,31 @@ func (in *IDToken) Validate(cfg openidconfig.Config, cookie *LoginCookie) error
}
}

return jwtlib.Validate(in.GetToken(), opts...)
err := jwtlib.Validate(in.GetToken(), opts...)
if err != nil {
return err
}

// OpenID Connect Core 3.1.3.7, step 3.
// The `aud` (audience) Claim MAY contain an array with more than one element.
// The ID Token MUST be rejected if the ID Token [...] contains additional audiences not trusted by the Client.
audiences := in.GetToken().Audience()
if len(audiences) > 1 {
// We only trust a single audience.

untrusted := make([]string, 0)
for _, audience := range audiences {
if audience != clientConfig.ClientID() && !slices.Contains(clientConfig.Audiences(), audience) {
untrusted = append(untrusted, audience)
}
}

if len(untrusted) > 0 {
return fmt.Errorf("untrusted audience(s) found: %q", untrusted)
}
}

return nil
}

func NewIDToken(raw string, jwtToken jwtlib.Token) *IDToken {
Expand Down
245 changes: 245 additions & 0 deletions pkg/openid/tokens_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
package openid_test

import (
"fmt"
"testing"
"time"

"github.com/google/uuid"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/nais/wonderwall/pkg/crypto"
"github.com/nais/wonderwall/pkg/mock"
"github.com/nais/wonderwall/pkg/openid"
)

func TestParseIDToken(t *testing.T) {
iat := time.Now().Truncate(time.Second).UTC()
exp := iat.Add(5 * time.Second)
sub := uuid.New().String()

parsed, err := makeIDToken(claims{
setClaims: map[string]any{
"sub": sub,
"iat": iat.Unix(),
"exp": exp.Unix(),
}})
require.NoError(t, err)

assert.Equal(t, sub, parsed.GetToken().Subject())
assert.Equal(t, "some-issuer", parsed.GetToken().Issuer())
assert.Equal(t, []string{"some-client-id"}, parsed.GetToken().Audience())
assert.Equal(t, "some-nonce", parsed.GetStringClaimOrEmpty("nonce"))
assert.Equal(t, "some-acr", parsed.GetStringClaimOrEmpty("acr"))
assert.Equal(t, iat, parsed.GetToken().IssuedAt())
assert.Equal(t, exp, parsed.GetToken().Expiration())
assert.NotEmpty(t, parsed.GetToken().JwtID())
}

func TestIDToken_GetAcrClaim(t *testing.T) {
idToken, err := makeIDToken()
require.NoError(t, err)

assert.Equal(t, "some-acr", idToken.GetAcrClaim())
}

func TestIDToken_GetAmrClaim(t *testing.T) {
t.Run("amr is a string", func(t *testing.T) {
idToken, err := makeIDToken(claims{
setClaims: map[string]any{
"amr": "some-amr",
}})
require.NoError(t, err)

assert.Equal(t, "some-amr", idToken.GetAmrClaim())
})

t.Run("amr is a string array", func(t *testing.T) {
idToken, err := makeIDToken(claims{
setClaims: map[string]any{
"amr": []string{"some-amr-array"},
},
})
require.NoError(t, err)

assert.Equal(t, "some-amr-array", idToken.GetAmrClaim())
})

t.Run("amr is a string array with multiple values", func(t *testing.T) {
idToken, err := makeIDToken(claims{
setClaims: map[string]any{
"amr": []string{"some-amr-1", "some-amr-2"},
},
})
require.NoError(t, err)

assert.Equal(t, "some-amr-1,some-amr-2", idToken.GetAmrClaim())
})
}

func TestIDToken_GetSidClaim(t *testing.T) {
idToken, err := makeIDToken(claims{
setClaims: map[string]any{
"sid": "some-sid",
},
})
require.NoError(t, err)

sid, err := idToken.GetSidClaim()
assert.NoError(t, err)
assert.Equal(t, "some-sid", sid)
}

func TestIDToken_Validate(t *testing.T) {
openidcfg := mock.NewTestConfiguration(mock.Config())
openidcfg.TestProvider.WithFrontChannelLogoutSupport()
cookie := &openid.LoginCookie{
Acr: "some-acr",
Nonce: "some-nonce",
}

for _, tt := range []struct {
name string
claims claims
assertion assert.ErrorAssertionFunc
}{
{
name: "happy path",
claims: claims{
setClaims: map[string]any{
"sid": "some-sid",
"aud": openidcfg.Client().ClientID(),
"iss": openidcfg.Provider().Issuer(),
},
},
assertion: assert.NoError,
},
} {
t.Run(tt.name, func(t *testing.T) {
idToken, err := makeIDToken(tt.claims)
require.NoError(t, err)

tt.assertion(t, idToken.Validate(openidcfg, cookie))
})
}

// TODO
t.Run("required claims", func(t *testing.T) {
t.Run("missing sub", func(t *testing.T) {

})

t.Run("missing exp", func(t *testing.T) {

})

t.Run("missing iat", func(t *testing.T) {

})
})

t.Run("nonce validation", func(t *testing.T) {
t.Run("missing nonce", func(t *testing.T) {

})

t.Run("nonce does not match nonce in authorization request", func(t *testing.T) {

})
})

t.Run("issuer validation", func(t *testing.T) {
t.Run("missing iss", func(t *testing.T) {

})

t.Run("issuer does not match provider issuer", func(t *testing.T) {

})

})

t.Run("sid claim is required", func(t *testing.T) {
// if provider has frontchannel_logout_supported and frontchannel_logout_session_supported

})

t.Run("acr value was requested", func(t *testing.T) {
t.Run("acr value matches", func(t *testing.T) {

})

t.Run("acr value does not match", func(t *testing.T) {

})
})

t.Run("audience validation", func(t *testing.T) {
t.Run("missing aud", func(t *testing.T) {

})

t.Run("audience does not include client_id", func(t *testing.T) {

})

t.Run("multiple trusted audiences", func(t *testing.T) {

})

t.Run("untrusted audiences", func(t *testing.T) {

})
})
}

type claims struct {
setClaims map[string]any
removeClaims []string
}

func makeIDToken(claims ...claims) (*openid.IDToken, error) {
jwks, err := crypto.NewJwkSet()
if err != nil {
return nil, fmt.Errorf("creating jwk set")
}

iat := time.Now().Truncate(time.Second).UTC()
exp := iat.Add(5 * time.Second)
sub := uuid.New().String()

idToken := jwt.New()
idToken.Set("sub", sub)
idToken.Set("iss", "some-issuer")
idToken.Set("aud", "some-client-id")
idToken.Set("nonce", "some-nonce")
idToken.Set("acr", "some-acr")
idToken.Set("iat", iat.Unix())
idToken.Set("exp", exp.Unix())
idToken.Set("jti", uuid.NewString())

if len(claims) > 0 {
for claim, claimValue := range claims[0].setClaims {
idToken.Set(claim, claimValue)
}

for _, claim := range claims[0].removeClaims {
idToken.Remove(claim)
}
}

key, ok := jwks.Private.Key(0)
if !ok {
return nil, fmt.Errorf("no private key found at index 0")
}

jws, err := jwt.Sign(idToken, jwt.WithKey(jwa.RS256, key))
if err != nil {
return nil, fmt.Errorf("signing token: %w", err)
}

return openid.ParseIDToken(string(jws), jwks.Public)
}

0 comments on commit c6dad59

Please sign in to comment.