From c6dad5999f97cd89d538b210521e477b72e17dbb Mon Sep 17 00:00:00 2001 From: Trong Huu Nguyen Date: Tue, 15 Aug 2023 16:50:55 +0200 Subject: [PATCH] feat(openid): harden id_token validation --- docs/configuration.md | 1 + pkg/config/openid.go | 3 + pkg/mock/client.go | 4 + pkg/openid/config/client.go | 5 + pkg/openid/tokens.go | 41 +++++- pkg/openid/tokens_test.go | 245 ++++++++++++++++++++++++++++++++++++ 6 files changed, 298 insertions(+), 1 deletion(-) create mode 100644 pkg/openid/tokens_test.go diff --git a/docs/configuration.md b/docs/configuration.md index 488aa754..f13a0a26 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -21,6 +21,7 @@ The following flags are available: | `log-level` | string | Logging verbosity level. | `info` | | `metrics-bind-address` | string | Listen address for metrics only. | `127.0.0.1:3001` | | `openid.acr-values` | string | Space separated string that configures the default security level (`acr_values`) parameter for authorization requests. | | +| `openid.audiences` | strings | List of additional trusted audiences (other than the client_id) for OpenID Connect id_token validation. | | | `openid.client-id` | string | Client ID for the OpenID client. | | | `openid.client-jwk` | string | JWK containing the private key for the OpenID client in string format. | | | `openid.post-logout-redirect-uri` | string | URI for redirecting the user after successful logout at the Identity Provider. | | diff --git a/pkg/config/openid.go b/pkg/config/openid.go index 8166f962..e3144ab8 100644 --- a/pkg/config/openid.go +++ b/pkg/config/openid.go @@ -6,6 +6,7 @@ import ( const ( OpenIDACRValues = "openid.acr-values" + OpenIDAudiences = "openid.audiences" OpenIDClientID = "openid.client-id" OpenIDClientJWK = "openid.client-jwk" OpenIDPostLogoutRedirectURI = "openid.post-logout-redirect-uri" @@ -18,6 +19,7 @@ const ( type OpenID struct { ACRValues string `json:"acr-values"` + Audiences []string `json:"audiences"` ClientID string `json:"client-id"` ClientJWK string `json:"client-jwk"` PostLogoutRedirectURI string `json:"post-logout-redirect-uri"` @@ -38,6 +40,7 @@ const ( func openIDFlags() { flag.String(OpenIDACRValues, "", "Space separated string that configures the default security level (acr_values) parameter for authorization requests.") + flag.StringSlice(OpenIDAudiences, []string{}, "List of additional trusted audiences (other than the client_id) for OpenID Connect id_token validation.") flag.String(OpenIDClientID, "", "Client ID for the OpenID client.") flag.String(OpenIDClientJWK, "", "JWK containing the private key for the OpenID client in string format.") flag.String(OpenIDPostLogoutRedirectURI, "", "URI for redirecting the user after successful logout at the Identity Provider.") diff --git a/pkg/mock/client.go b/pkg/mock/client.go index 3d6df33a..048cc6ad 100644 --- a/pkg/mock/client.go +++ b/pkg/mock/client.go @@ -17,6 +17,10 @@ func (c *TestClientConfiguration) ACRValues() string { return c.Config.OpenID.ACRValues } +func (c *TestClientConfiguration) Audiences() []string { + return c.Config.OpenID.Audiences +} + func (c *TestClientConfiguration) ClientID() string { return c.Config.OpenID.ClientID } diff --git a/pkg/openid/config/client.go b/pkg/openid/config/client.go index e64e00ef..d37db283 100644 --- a/pkg/openid/config/client.go +++ b/pkg/openid/config/client.go @@ -12,6 +12,7 @@ import ( type Client interface { ACRValues() string + Audiences() []string ClientID() string ClientJWK() jwk.Key PostLogoutRedirectURI() string @@ -32,6 +33,10 @@ func (in *client) ACRValues() string { return in.OpenID.ACRValues } +func (in *client) Audiences() []string { + return in.OpenID.Audiences +} + func (in *client) ClientID() string { return in.OpenID.ClientID } diff --git a/pkg/openid/tokens.go b/pkg/openid/tokens.go index 82589dac..a03d0fa6 100644 --- a/pkg/openid/tokens.go +++ b/pkg/openid/tokens.go @@ -2,6 +2,7 @@ package openid import ( "fmt" + "slices" "strings" "time" @@ -62,8 +63,20 @@ func (in *IDToken) Validate(cfg openidconfig.Config, cookie *LoginCookie) error clientConfig := cfg.Client() opts := []jwtlib.ValidateOption{ + jwtlib.WithRequiredClaim("iss"), + jwtlib.WithRequiredClaim("sub"), + jwtlib.WithRequiredClaim("aud"), + jwtlib.WithRequiredClaim("exp"), + jwtlib.WithRequiredClaim("iat"), + // OpenID Connect Core 3.1.3.7, step 3. + // The Client MUST validate that the `aud` (audience) Claim contains its `client_id` value registered at the Issuer identified by the `iss` (issuer) Claim as an audience. + // The ID Token MUST be rejected if the ID Token does not list the Client as a valid audience jwtlib.WithAudience(clientConfig.ClientID()), + // OpenID Connect Core 3.1.3.7, step 11. + // If a nonce value was sent in the Authentication Request, a `nonce` Claim MUST be present and its value checked to verify that it is the same value as the one that was sent in the Authentication Request. jwtlib.WithClaimValue("nonce", cookie.Nonce), + // OpenID Connect Core 3.1.3.7, step 2. + // The Issuer Identifier for the OpenID Provider (which is typically obtained during Discovery) MUST exactly match the value of the `iss` (issuer) Claim. jwtlib.WithIssuer(openIDconfig.Issuer()), jwtlib.WithAcceptableSkew(jwt.AcceptableClockSkew), } @@ -72,6 +85,8 @@ func (in *IDToken) Validate(cfg openidconfig.Config, cookie *LoginCookie) error opts = append(opts, jwtlib.WithRequiredClaim(jwt.SidClaim)) } + // OpenID Connect Core 3.1.3.7, step 12. + // If the `acr` Claim was requested, the Client SHOULD check that the asserted Claim Value is appropriate. if len(clientConfig.ACRValues()) > 0 { opts = append(opts, jwtlib.WithRequiredClaim(jwt.AcrClaim)) @@ -86,7 +101,31 @@ func (in *IDToken) Validate(cfg openidconfig.Config, cookie *LoginCookie) error } } - return jwtlib.Validate(in.GetToken(), opts...) + err := jwtlib.Validate(in.GetToken(), opts...) + if err != nil { + return err + } + + // OpenID Connect Core 3.1.3.7, step 3. + // The `aud` (audience) Claim MAY contain an array with more than one element. + // The ID Token MUST be rejected if the ID Token [...] contains additional audiences not trusted by the Client. + audiences := in.GetToken().Audience() + if len(audiences) > 1 { + // We only trust a single audience. + + untrusted := make([]string, 0) + for _, audience := range audiences { + if audience != clientConfig.ClientID() && !slices.Contains(clientConfig.Audiences(), audience) { + untrusted = append(untrusted, audience) + } + } + + if len(untrusted) > 0 { + return fmt.Errorf("untrusted audience(s) found: %q", untrusted) + } + } + + return nil } func NewIDToken(raw string, jwtToken jwtlib.Token) *IDToken { diff --git a/pkg/openid/tokens_test.go b/pkg/openid/tokens_test.go new file mode 100644 index 00000000..bcb0b80c --- /dev/null +++ b/pkg/openid/tokens_test.go @@ -0,0 +1,245 @@ +package openid_test + +import ( + "fmt" + "testing" + "time" + + "github.com/google/uuid" + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/nais/wonderwall/pkg/crypto" + "github.com/nais/wonderwall/pkg/mock" + "github.com/nais/wonderwall/pkg/openid" +) + +func TestParseIDToken(t *testing.T) { + iat := time.Now().Truncate(time.Second).UTC() + exp := iat.Add(5 * time.Second) + sub := uuid.New().String() + + parsed, err := makeIDToken(claims{ + setClaims: map[string]any{ + "sub": sub, + "iat": iat.Unix(), + "exp": exp.Unix(), + }}) + require.NoError(t, err) + + assert.Equal(t, sub, parsed.GetToken().Subject()) + assert.Equal(t, "some-issuer", parsed.GetToken().Issuer()) + assert.Equal(t, []string{"some-client-id"}, parsed.GetToken().Audience()) + assert.Equal(t, "some-nonce", parsed.GetStringClaimOrEmpty("nonce")) + assert.Equal(t, "some-acr", parsed.GetStringClaimOrEmpty("acr")) + assert.Equal(t, iat, parsed.GetToken().IssuedAt()) + assert.Equal(t, exp, parsed.GetToken().Expiration()) + assert.NotEmpty(t, parsed.GetToken().JwtID()) +} + +func TestIDToken_GetAcrClaim(t *testing.T) { + idToken, err := makeIDToken() + require.NoError(t, err) + + assert.Equal(t, "some-acr", idToken.GetAcrClaim()) +} + +func TestIDToken_GetAmrClaim(t *testing.T) { + t.Run("amr is a string", func(t *testing.T) { + idToken, err := makeIDToken(claims{ + setClaims: map[string]any{ + "amr": "some-amr", + }}) + require.NoError(t, err) + + assert.Equal(t, "some-amr", idToken.GetAmrClaim()) + }) + + t.Run("amr is a string array", func(t *testing.T) { + idToken, err := makeIDToken(claims{ + setClaims: map[string]any{ + "amr": []string{"some-amr-array"}, + }, + }) + require.NoError(t, err) + + assert.Equal(t, "some-amr-array", idToken.GetAmrClaim()) + }) + + t.Run("amr is a string array with multiple values", func(t *testing.T) { + idToken, err := makeIDToken(claims{ + setClaims: map[string]any{ + "amr": []string{"some-amr-1", "some-amr-2"}, + }, + }) + require.NoError(t, err) + + assert.Equal(t, "some-amr-1,some-amr-2", idToken.GetAmrClaim()) + }) +} + +func TestIDToken_GetSidClaim(t *testing.T) { + idToken, err := makeIDToken(claims{ + setClaims: map[string]any{ + "sid": "some-sid", + }, + }) + require.NoError(t, err) + + sid, err := idToken.GetSidClaim() + assert.NoError(t, err) + assert.Equal(t, "some-sid", sid) +} + +func TestIDToken_Validate(t *testing.T) { + openidcfg := mock.NewTestConfiguration(mock.Config()) + openidcfg.TestProvider.WithFrontChannelLogoutSupport() + cookie := &openid.LoginCookie{ + Acr: "some-acr", + Nonce: "some-nonce", + } + + for _, tt := range []struct { + name string + claims claims + assertion assert.ErrorAssertionFunc + }{ + { + name: "happy path", + claims: claims{ + setClaims: map[string]any{ + "sid": "some-sid", + "aud": openidcfg.Client().ClientID(), + "iss": openidcfg.Provider().Issuer(), + }, + }, + assertion: assert.NoError, + }, + } { + t.Run(tt.name, func(t *testing.T) { + idToken, err := makeIDToken(tt.claims) + require.NoError(t, err) + + tt.assertion(t, idToken.Validate(openidcfg, cookie)) + }) + } + + // TODO + t.Run("required claims", func(t *testing.T) { + t.Run("missing sub", func(t *testing.T) { + + }) + + t.Run("missing exp", func(t *testing.T) { + + }) + + t.Run("missing iat", func(t *testing.T) { + + }) + }) + + t.Run("nonce validation", func(t *testing.T) { + t.Run("missing nonce", func(t *testing.T) { + + }) + + t.Run("nonce does not match nonce in authorization request", func(t *testing.T) { + + }) + }) + + t.Run("issuer validation", func(t *testing.T) { + t.Run("missing iss", func(t *testing.T) { + + }) + + t.Run("issuer does not match provider issuer", func(t *testing.T) { + + }) + + }) + + t.Run("sid claim is required", func(t *testing.T) { + // if provider has frontchannel_logout_supported and frontchannel_logout_session_supported + + }) + + t.Run("acr value was requested", func(t *testing.T) { + t.Run("acr value matches", func(t *testing.T) { + + }) + + t.Run("acr value does not match", func(t *testing.T) { + + }) + }) + + t.Run("audience validation", func(t *testing.T) { + t.Run("missing aud", func(t *testing.T) { + + }) + + t.Run("audience does not include client_id", func(t *testing.T) { + + }) + + t.Run("multiple trusted audiences", func(t *testing.T) { + + }) + + t.Run("untrusted audiences", func(t *testing.T) { + + }) + }) +} + +type claims struct { + setClaims map[string]any + removeClaims []string +} + +func makeIDToken(claims ...claims) (*openid.IDToken, error) { + jwks, err := crypto.NewJwkSet() + if err != nil { + return nil, fmt.Errorf("creating jwk set") + } + + iat := time.Now().Truncate(time.Second).UTC() + exp := iat.Add(5 * time.Second) + sub := uuid.New().String() + + idToken := jwt.New() + idToken.Set("sub", sub) + idToken.Set("iss", "some-issuer") + idToken.Set("aud", "some-client-id") + idToken.Set("nonce", "some-nonce") + idToken.Set("acr", "some-acr") + idToken.Set("iat", iat.Unix()) + idToken.Set("exp", exp.Unix()) + idToken.Set("jti", uuid.NewString()) + + if len(claims) > 0 { + for claim, claimValue := range claims[0].setClaims { + idToken.Set(claim, claimValue) + } + + for _, claim := range claims[0].removeClaims { + idToken.Remove(claim) + } + } + + key, ok := jwks.Private.Key(0) + if !ok { + return nil, fmt.Errorf("no private key found at index 0") + } + + jws, err := jwt.Sign(idToken, jwt.WithKey(jwa.RS256, key)) + if err != nil { + return nil, fmt.Errorf("signing token: %w", err) + } + + return openid.ParseIDToken(string(jws), jwks.Public) +}