From 03b22f6639129e92754652baa4538cad9caed575 Mon Sep 17 00:00:00 2001 From: Nestor Date: Fri, 21 May 2021 20:38:30 +0200 Subject: [PATCH] feat: transit from jwt-go to go-jose (#593) Closes #514 Co-authored-by: hackerman <3372410+aeneasr@users.noreply.github.com> --- authorize_request_handler.go | 21 +- ...orize_request_handler_oidc_request_test.go | 15 +- client_authentication.go | 37 +- client_authentication_test.go | 9 +- go.mod | 2 - go.sum | 3 - handler/oauth2/introspector_jwt.go | 6 +- handler/oauth2/strategy_jwt.go | 25 +- handler/openid/flow_explicit_token_test.go | 6 +- handler/openid/flow_refresh_token_test.go | 6 +- handler/openid/strategy_jwt.go | 9 +- handler/openid/validator.go | 9 +- token/jwt/claims_id_token.go | 3 +- token/jwt/claims_jwt.go | 7 +- token/jwt/header.go | 4 +- token/jwt/jwt.go | 164 +++--- token/jwt/jwt_test.go | 19 +- token/jwt/key_test.go | 48 ++ token/jwt/map_claims.go | 162 ++++++ token/jwt/map_claims_test.go | 95 ++++ token/jwt/token.go | 232 ++++++++ token/jwt/token_test.go | 495 ++++++++++++++++++ token/jwt/validation_error.go | 44 ++ 23 files changed, 1216 insertions(+), 205 deletions(-) create mode 100644 token/jwt/key_test.go create mode 100644 token/jwt/map_claims.go create mode 100644 token/jwt/map_claims_test.go create mode 100644 token/jwt/token.go create mode 100644 token/jwt/token_test.go create mode 100644 token/jwt/validation_error.go diff --git a/authorize_request_handler.go b/authorize_request_handler.go index 271a644d6..dbb1349ce 100644 --- a/authorize_request_handler.go +++ b/authorize_request_handler.go @@ -28,9 +28,10 @@ import ( "net/http" "strings" + "github.com/ory/fosite/token/jwt" "github.com/ory/x/errorsx" + "gopkg.in/square/go-jose.v2" - jwt "github.com/dgrijalva/jwt-go" "github.com/pkg/errors" "github.com/ory/go-convenience/stringslice" @@ -101,7 +102,7 @@ func (f *Fosite) authorizeRequestParametersFromOpenIDConnectRequest(request *Aut assertion = string(body) } - token, err := jwt.ParseWithClaims(assertion, new(jwt.MapClaims), func(t *jwt.Token) (interface{}, error) { + token, err := jwt.ParseWithClaims(assertion, jwt.MapClaims{}, func(t *jwt.Token) (interface{}, error) { // request_object_signing_alg - OPTIONAL. // JWS [JWS] alg algorithm [JWA] that MUST be used for signing Request Objects sent to the OP. All Request Objects from this Client MUST be rejected, // if not signed with this algorithm. Request Objects are described in Section 6.1 of OpenID Connect Core 1.0 [OpenID.Core]. This algorithm MUST @@ -115,22 +116,22 @@ func (f *Fosite) authorizeRequestParametersFromOpenIDConnectRequest(request *Aut return jwt.UnsafeAllowNoneSignatureType, nil } - switch t.Method.(type) { - case *jwt.SigningMethodRSA: + switch t.Method { + case jose.RS256, jose.RS384, jose.RS512: key, err := f.findClientPublicJWK(oidcClient, t, true) if err != nil { return nil, wrapSigningKeyFailure( ErrInvalidRequestObject.WithHint("Unable to retrieve RSA signing key from OAuth 2.0 Client."), err) } return key, nil - case *jwt.SigningMethodECDSA: + case jose.ES256, jose.ES384, jose.ES512: key, err := f.findClientPublicJWK(oidcClient, t, false) if err != nil { return nil, wrapSigningKeyFailure( ErrInvalidRequestObject.WithHint("Unable to retrieve ECDSA signing key from OAuth 2.0 Client."), err) } return key, nil - case *jwt.SigningMethodRSAPSS: + case jose.PS256, jose.PS384, jose.PS512: key, err := f.findClientPublicJWK(oidcClient, t, true) if err != nil { return nil, wrapSigningKeyFailure( @@ -155,12 +156,8 @@ func (f *Fosite) authorizeRequestParametersFromOpenIDConnectRequest(request *Aut return errorsx.WithStack(ErrInvalidRequestObject.WithHint("Unable to verify the request object because its claims could not be validated, check if the expiry time is set correctly.").WithWrap(err).WithDebug(err.Error())) } - claims, ok := token.Claims.(*jwt.MapClaims) - if !ok { - return errorsx.WithStack(ErrInvalidRequestObject.WithHint("Unable to type assert claims from request object.").WithDebugf(`Got claims of type %T but expected type '*jwt.MapClaims'.`, token.Claims)) - } - - for k, v := range *claims { + claims := token.Claims + for k, v := range claims { request.Form.Set(k, fmt.Sprintf("%s", v)) } diff --git a/authorize_request_handler_oidc_request_test.go b/authorize_request_handler_oidc_request_test.go index 2f59670ac..6cc136549 100644 --- a/authorize_request_handler_oidc_request_test.go +++ b/authorize_request_handler_oidc_request_test.go @@ -33,14 +33,14 @@ import ( "github.com/pkg/errors" - jwt "github.com/dgrijalva/jwt-go" + "github.com/ory/fosite/token/jwt" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" jose "gopkg.in/square/go-jose.v2" ) func mustGenerateAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.PrivateKey, kid string) string { - token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token := jwt.NewWithClaims(jose.RS256, claims) if kid != "" { token.Header["kid"] = kid } @@ -50,7 +50,7 @@ func mustGenerateAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.PrivateK } func mustGenerateHSAssertion(t *testing.T, claims jwt.MapClaims) string { - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + token := jwt.NewWithClaims(jose.HS256, claims) tokenString, err := token.SignedString([]byte("aaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbcccccccccccccccccccccddddddddddddddddddddddd")) require.NoError(t, err) return tokenString @@ -217,10 +217,15 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequest(t *testing.T) { if tc.expectErrReason != "" { real := new(RFC6749Error) require.True(t, errors.As(err, &real)) - assert.EqualValues(t, real.Reason(), tc.expectErrReason) + assert.EqualValues(t, tc.expectErrReason, real.Reason()) } } else { - require.NoError(t, err) + if err != nil { + real := new(RFC6749Error) + errors.As(err, &real) + require.NoErrorf(t, err, "Hint: %v\nDebug:%v", real.HintField, real.DebugField) + } + require.NoErrorf(t, err, "%+v", err) require.Equal(t, len(tc.expectForm), len(req.Form)) for k, v := range tc.expectForm { assert.EqualValues(t, v, req.Form[k]) diff --git a/client_authentication.go b/client_authentication.go index 40d248b0b..0ebf16080 100644 --- a/client_authentication.go +++ b/client_authentication.go @@ -33,7 +33,7 @@ import ( "github.com/ory/x/errorsx" - jwt "github.com/dgrijalva/jwt-go" + "github.com/ory/fosite/token/jwt" "github.com/pkg/errors" jose "gopkg.in/square/go-jose.v2" ) @@ -90,7 +90,7 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt var clientID string var client Client - token, err := jwt.ParseWithClaims(assertion, new(jwt.MapClaims), func(t *jwt.Token) (interface{}, error) { + token, err := jwt.ParseWithClaims(assertion, jwt.MapClaims{}, func(t *jwt.Token) (interface{}, error) { var err error clientID, _, err = clientCredentialsFromRequestBody(form, false) if err != nil { @@ -98,9 +98,8 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt } if clientID == "" { - if claims, ok := t.Claims.(*jwt.MapClaims); !ok { - return nil, errorsx.WithStack(ErrRequestUnauthorized.WithHint("Unable to type assert claims from client_assertion.").WithDebugf(`Expected claims to be of type '*jwt.MapClaims' but got '%T'.`, t.Claims)) - } else if sub, ok := (*claims)["sub"].(string); !ok { + claims := t.Claims + if sub, ok := claims["sub"].(string); !ok { return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The claim 'sub' from the client_assertion JSON Web Token is undefined.")) } else { clientID = sub @@ -135,18 +134,18 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt if oidcClient.GetTokenEndpointAuthSigningAlgorithm() != fmt.Sprintf("%s", t.Header["alg"]) { return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("The 'client_assertion' uses signing algorithm '%s' but the requested OAuth 2.0 Client enforces signing algorithm '%s'.", t.Header["alg"], oidcClient.GetTokenEndpointAuthSigningAlgorithm())) } - - if _, ok := t.Method.(*jwt.SigningMethodRSA); ok { + switch t.Method { + case jose.RS256, jose.RS384, jose.RS512: return f.findClientPublicJWK(oidcClient, t, true) - } else if _, ok := t.Method.(*jwt.SigningMethodECDSA); ok { + case jose.ES256, jose.ES384, jose.ES512: return f.findClientPublicJWK(oidcClient, t, false) - } else if _, ok := t.Method.(*jwt.SigningMethodRSAPSS); ok { + case jose.PS256, jose.PS384, jose.PS512: return f.findClientPublicJWK(oidcClient, t, true) - } else if _, ok := t.Method.(*jwt.SigningMethodHMAC); ok { + case jose.HS256, jose.HS384, jose.HS512: return nil, errorsx.WithStack(ErrInvalidClient.WithHint("This authorization server does not support client authentication method 'client_secret_jwt'.")) + default: + return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("The 'client_assertion' request parameter uses unsupported signing algorithm '%s'.", t.Header["alg"])) } - - return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("The 'client_assertion' request parameter uses unsupported signing algorithm '%s'.", t.Header["alg"])) }) if err != nil { // Do not re-process already enhanced errors @@ -162,19 +161,15 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Unable to verify the request object because its claims could not be validated, check if the expiry time is set correctly.").WithWrap(err).WithDebug(err.Error())) } - claims, ok := token.Claims.(*jwt.MapClaims) - if !ok { - return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Unable to type assert claims from request parameter 'client_assertion'.").WithDebugf("Got claims of type %T but expected type '*jwt.MapClaims'.", token.Claims)) - } - + claims := token.Claims var jti string if !claims.VerifyIssuer(clientID, true) { return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Claim 'iss' from 'client_assertion' must match the 'client_id' of the OAuth 2.0 Client.")) } else if f.TokenURL == "" { return nil, errorsx.WithStack(ErrMisconfiguration.WithHint("The authorization server's token endpoint URL has not been set.")) - } else if sub, ok := (*claims)["sub"].(string); !ok || sub != clientID { + } else if sub, ok := claims["sub"].(string); !ok || sub != clientID { return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Claim 'sub' from 'client_assertion' must match the 'client_id' of the OAuth 2.0 Client.")) - } else if jti, ok = (*claims)["jti"].(string); !ok || len(jti) == 0 { + } else if jti, ok = claims["jti"].(string); !ok || len(jti) == 0 { return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Claim 'jti' from 'client_assertion' must be set but is not.")) } else if f.Store.ClientAssertionJWTValid(ctx, jti) != nil { return nil, errorsx.WithStack(ErrJTIKnown.WithHint("Claim 'jti' from 'client_assertion' MUST only be used once.")) @@ -183,7 +178,7 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt // type conversion according to jwt.MapClaims.VerifyExpiresAt var expiry int64 err = nil - switch exp := (*claims)["exp"].(type) { + switch exp := claims["exp"].(type) { case float64: expiry = int64(exp) case json.Number: @@ -199,7 +194,7 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt return nil, err } - if auds, ok := (*claims)["aud"].([]interface{}); !ok { + if auds, ok := claims["aud"].([]interface{}); !ok { if !claims.VerifyAudience(f.TokenURL, true) { return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("Claim 'audience' from 'client_assertion' must match the authorization server's token endpoint '%s'.", f.TokenURL)) } diff --git a/client_authentication_test.go b/client_authentication_test.go index bde8a4819..bf232ad66 100644 --- a/client_authentication_test.go +++ b/client_authentication_test.go @@ -34,7 +34,7 @@ import ( "testing" "time" - jwt "github.com/dgrijalva/jwt-go" + "github.com/ory/fosite/token/jwt" "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -46,7 +46,7 @@ import ( ) func mustGenerateRSAAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.PrivateKey, kid string) string { - token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + token := jwt.NewWithClaims(jose.RS256, claims) token.Header["kid"] = kid tokenString, err := token.SignedString(key) require.NoError(t, err) @@ -54,7 +54,7 @@ func mustGenerateRSAAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.Priva } func mustGenerateECDSAAssertion(t *testing.T, claims jwt.MapClaims, key *ecdsa.PrivateKey, kid string) string { - token := jwt.NewWithClaims(jwt.SigningMethodES256, claims) + token := jwt.NewWithClaims(jose.ES256, claims) token.Header["kid"] = kid tokenString, err := token.SignedString(key) require.NoError(t, err) @@ -62,7 +62,7 @@ func mustGenerateECDSAAssertion(t *testing.T, claims jwt.MapClaims, key *ecdsa.P } func mustGenerateHSAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.PrivateKey, kid string) string { - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + token := jwt.NewWithClaims(jose.HS256, claims) tokenString, err := token.SignedString([]byte("aaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbcccccccccccccccccccccddddddddddddddddddddddd")) require.NoError(t, err) return tokenString @@ -503,6 +503,7 @@ func TestAuthenticateClient(t *testing.T) { t.Logf("Error is: %s", validationError.Inner) } else if errors.As(err, &rfcError) { t.Logf("DebugField is: %s", rfcError.DebugField) + t.Logf("HintField is: %s", rfcError.HintField) } } require.NoError(t, err) diff --git a/go.mod b/go.mod index 253112320..b7bd4a952 100644 --- a/go.mod +++ b/go.mod @@ -7,8 +7,6 @@ replace github.com/gogo/protobuf => github.com/gogo/protobuf v1.3.2 require ( github.com/asaskevich/govalidator v0.0.0-20200428143746-21a406dcc535 github.com/dgraph-io/ristretto v0.0.3 // indirect - github.com/dgrijalva/jwt-go v3.2.0+incompatible - github.com/form3tech-oss/jwt-go v3.2.2+incompatible // indirect github.com/golang/mock v1.4.4 github.com/gorilla/mux v1.7.3 github.com/gorilla/websocket v1.4.2 diff --git a/go.sum b/go.sum index 6bbf7a7ff..9ad9eb7c1 100644 --- a/go.sum +++ b/go.sum @@ -113,10 +113,7 @@ github.com/fatih/structs v1.0.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= github.com/fogleman/gg v1.3.0/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k= -github.com/form3tech-oss/jwt-go v3.2.1+incompatible h1:xdtqez379uWVJ9P3qQMX8W+F/nqsTdUvyMZB36tnacA= github.com/form3tech-oss/jwt-go v3.2.1+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k= -github.com/form3tech-oss/jwt-go v3.2.2+incompatible h1:TcekIExNqud5crz4xD2pavyTgWiPvpYe4Xau31I0PRk= -github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= diff --git a/handler/oauth2/introspector_jwt.go b/handler/oauth2/introspector_jwt.go index b90ccfa6e..cf6f46f37 100644 --- a/handler/oauth2/introspector_jwt.go +++ b/handler/oauth2/introspector_jwt.go @@ -25,8 +25,6 @@ import ( "context" "time" - jwtx "github.com/dgrijalva/jwt-go" - "github.com/ory/fosite" "github.com/ory/fosite/token/jwt" ) @@ -37,8 +35,8 @@ type StatelessJWTValidator struct { } // AccessTokenJWTToRequest tries to reconstruct fosite.Request from a JWT. -func AccessTokenJWTToRequest(token *jwtx.Token) fosite.Requester { - mapClaims := token.Claims.(jwtx.MapClaims) +func AccessTokenJWTToRequest(token *jwt.Token) fosite.Requester { + mapClaims := token.Claims claims := jwt.JWTClaims{} claims.FromMapClaims(mapClaims) diff --git a/handler/oauth2/strategy_jwt.go b/handler/oauth2/strategy_jwt.go index 2234fdac0..f0b882e39 100644 --- a/handler/oauth2/strategy_jwt.go +++ b/handler/oauth2/strategy_jwt.go @@ -28,7 +28,6 @@ import ( "github.com/ory/x/errorsx" - jwtx "github.com/dgrijalva/jwt-go" "github.com/pkg/errors" "github.com/ory/fosite" @@ -99,7 +98,7 @@ func (h *DefaultJWTStrategy) ValidateAuthorizeCode(ctx context.Context, req fosi return h.HMACSHAStrategy.ValidateAuthorizeCode(ctx, req, token) } -func validate(ctx context.Context, jwtStrategy jwt.JWTStrategy, token string) (t *jwtx.Token, err error) { +func validate(ctx context.Context, jwtStrategy jwt.JWTStrategy, token string) (t *jwt.Token, err error) { t, err = jwtStrategy.Decode(ctx, token) if err == nil { @@ -107,28 +106,28 @@ func validate(ctx context.Context, jwtStrategy jwt.JWTStrategy, token string) (t } if err != nil { - var e *jwtx.ValidationError + var e *jwt.ValidationError if errors.As(err, &e) { switch e.Errors { - case jwtx.ValidationErrorMalformed: + case jwt.ValidationErrorMalformed: err = errorsx.WithStack(fosite.ErrInvalidTokenFormat.WithWrap(err).WithDebug(err.Error())) - case jwtx.ValidationErrorUnverifiable: + case jwt.ValidationErrorUnverifiable: err = errorsx.WithStack(fosite.ErrTokenSignatureMismatch.WithWrap(err).WithDebug(err.Error())) - case jwtx.ValidationErrorSignatureInvalid: + case jwt.ValidationErrorSignatureInvalid: err = errorsx.WithStack(fosite.ErrTokenSignatureMismatch.WithWrap(err).WithDebug(err.Error())) - case jwtx.ValidationErrorAudience: + case jwt.ValidationErrorAudience: err = errorsx.WithStack(fosite.ErrTokenClaim.WithWrap(err).WithDebug(err.Error())) - case jwtx.ValidationErrorExpired: + case jwt.ValidationErrorExpired: err = errorsx.WithStack(fosite.ErrTokenExpired.WithWrap(err).WithDebug(err.Error())) - case jwtx.ValidationErrorIssuedAt: + case jwt.ValidationErrorIssuedAt: err = errorsx.WithStack(fosite.ErrTokenClaim.WithWrap(err).WithDebug(err.Error())) - case jwtx.ValidationErrorIssuer: + case jwt.ValidationErrorIssuer: err = errorsx.WithStack(fosite.ErrTokenClaim.WithWrap(err).WithDebug(err.Error())) - case jwtx.ValidationErrorNotValidYet: + case jwt.ValidationErrorNotValidYet: err = errorsx.WithStack(fosite.ErrTokenClaim.WithWrap(err).WithDebug(err.Error())) - case jwtx.ValidationErrorId: + case jwt.ValidationErrorId: err = errorsx.WithStack(fosite.ErrTokenClaim.WithWrap(err).WithDebug(err.Error())) - case jwtx.ValidationErrorClaimsInvalid: + case jwt.ValidationErrorClaimsInvalid: err = errorsx.WithStack(fosite.ErrTokenClaim.WithWrap(err).WithDebug(err.Error())) default: err = errorsx.WithStack(fosite.ErrRequestUnauthorized.WithWrap(err).WithDebug(err.Error())) diff --git a/handler/openid/flow_explicit_token_test.go b/handler/openid/flow_explicit_token_test.go index 1ae9f7318..b4dd4f556 100644 --- a/handler/openid/flow_explicit_token_test.go +++ b/handler/openid/flow_explicit_token_test.go @@ -25,7 +25,6 @@ import ( "fmt" "testing" - jwtgo "github.com/dgrijalva/jwt-go" "github.com/golang/mock/gomock" "github.com/pkg/errors" "github.com/stretchr/testify/assert" @@ -124,10 +123,11 @@ func TestExplicit_PopulateTokenEndpointResponse(t *testing.T) { check: func(t *testing.T, aresp *fosite.AccessResponse) { assert.NotEmpty(t, aresp.GetExtra("id_token")) idToken, _ := aresp.GetExtra("id_token").(string) - decodedIdToken, _ := jwtgo.Parse(idToken, func(token *jwtgo.Token) (interface{}, error) { + decodedIdToken, err := jwt.Parse(idToken, func(token *jwt.Token) (interface{}, error) { return key.PublicKey, nil }) - claims, _ := decodedIdToken.Claims.(jwtgo.MapClaims) + require.NoError(t, err) + claims := decodedIdToken.Claims assert.NotEmpty(t, claims["at_hash"]) }, }, diff --git a/handler/openid/flow_refresh_token_test.go b/handler/openid/flow_refresh_token_test.go index 10c41390a..eada6664c 100644 --- a/handler/openid/flow_refresh_token_test.go +++ b/handler/openid/flow_refresh_token_test.go @@ -24,7 +24,6 @@ package openid import ( "testing" - jwtgo "github.com/dgrijalva/jwt-go" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -164,10 +163,11 @@ func TestOpenIDConnectRefreshHandler_PopulateTokenEndpointResponse(t *testing.T) check: func(t *testing.T, aresp *fosite.AccessResponse) { assert.NotEmpty(t, aresp.GetExtra("id_token")) idToken, _ := aresp.GetExtra("id_token").(string) - decodedIdToken, _ := jwtgo.Parse(idToken, func(token *jwtgo.Token) (interface{}, error) { + decodedIdToken, err := jwt.Parse(idToken, func(token *jwt.Token) (interface{}, error) { return key.PublicKey, nil }) - claims, _ := decodedIdToken.Claims.(jwtgo.MapClaims) + require.NoError(t, err) + claims := decodedIdToken.Claims assert.NotEmpty(t, claims["at_hash"]) }, }, diff --git a/handler/openid/strategy_jwt.go b/handler/openid/strategy_jwt.go index 77f0f0a81..47e56e4a8 100644 --- a/handler/openid/strategy_jwt.go +++ b/handler/openid/strategy_jwt.go @@ -28,7 +28,6 @@ import ( "github.com/ory/x/errorsx" - jwtgo "github.com/dgrijalva/jwt-go" "github.com/mohae/deepcopy" "github.com/pkg/errors" @@ -200,16 +199,14 @@ func (h DefaultStrategy) GenerateIDToken(ctx context.Context, requester fosite.R if tokenHintString := requester.GetRequestForm().Get("id_token_hint"); tokenHintString != "" { tokenHint, err := h.JWTStrategy.Decode(ctx, tokenHintString) - var ve *jwtgo.ValidationError - if errors.As(err, &ve) && ve.Errors == jwtgo.ValidationErrorExpired { + var ve *jwt.ValidationError + if errors.As(err, &ve) && ve.Errors == jwt.ValidationErrorExpired { // Expired ID Tokens are allowed as values to id_token_hint } else if err != nil { return "", errorsx.WithStack(fosite.ErrServerError.WithWrap(err).WithDebugf("Unable to decode id token from 'id_token_hint' parameter because %s.", err.Error())) } - if hintClaims, ok := tokenHint.Claims.(jwtgo.MapClaims); !ok { - return "", errorsx.WithStack(fosite.ErrServerError.WithDebug("Unable to decode id token from 'id_token_hint' to *jwt.StandardClaims.")) - } else if hintSub, _ := hintClaims["sub"].(string); hintSub == "" { + if hintSub, _ := tokenHint.Claims["sub"].(string); hintSub == "" { return "", errorsx.WithStack(fosite.ErrServerError.WithDebug("Provided id token from 'id_token_hint' does not have a subject.")) } else if hintSub != claims.Subject { return "", errorsx.WithStack(fosite.ErrServerError.WithDebug("Subject from authorization mismatches id token subject from 'id_token_hint'.")) diff --git a/handler/openid/validator.go b/handler/openid/validator.go index 6a079790b..d12598cff 100644 --- a/handler/openid/validator.go +++ b/handler/openid/validator.go @@ -30,7 +30,6 @@ import ( "github.com/ory/x/errorsx" - jwtgo "github.com/dgrijalva/jwt-go" "github.com/pkg/errors" "github.com/ory/fosite" @@ -158,16 +157,14 @@ func (v *OpenIDConnectRequestValidator) ValidatePrompt(ctx context.Context, req } tokenHint, err := v.Strategy.Decode(ctx, idTokenHint) - var ve *jwtgo.ValidationError - if errors.As(err, &ve) && ve.Errors == jwtgo.ValidationErrorExpired { + var ve *jwt.ValidationError + if errors.As(err, &ve) && ve.Errors == jwt.ValidationErrorExpired { // Expired tokens are ok } else if err != nil { return errorsx.WithStack(fosite.ErrInvalidRequest.WithHint("Failed to validate OpenID Connect request as decoding id token from id_token_hint parameter failed.").WithWrap(err).WithDebug(err.Error())) } - if hintClaims, ok := tokenHint.Claims.(jwtgo.MapClaims); !ok { - return errorsx.WithStack(fosite.ErrInvalidRequest.WithHint("Failed to validate OpenID Connect request as decoding id token from id_token_hint to jwtgo.MapClaims failed.")) - } else if hintSub, _ := hintClaims["sub"].(string); hintSub == "" { + if hintSub, _ := tokenHint.Claims["sub"].(string); hintSub == "" { return errorsx.WithStack(fosite.ErrInvalidRequest.WithHint("Failed to validate OpenID Connect request because provided id token from id_token_hint does not have a subject.")) } else if hintSub != claims.Subject { return errorsx.WithStack(fosite.ErrLoginRequired.WithHint("Failed to validate OpenID Connect request because the subject from provided id token from id_token_hint does not match the current session's subject.")) diff --git a/token/jwt/claims_id_token.go b/token/jwt/claims_id_token.go index ab73f29c8..ab73ba64f 100644 --- a/token/jwt/claims_id_token.go +++ b/token/jwt/claims_id_token.go @@ -24,7 +24,6 @@ package jwt import ( "time" - jwt "github.com/dgrijalva/jwt-go" "github.com/pborman/uuid" ) @@ -108,6 +107,6 @@ func (c *IDTokenClaims) Get(key string) interface{} { } // ToMapClaims will return a jwt-go MapClaims representation -func (c IDTokenClaims) ToMapClaims() jwt.MapClaims { +func (c IDTokenClaims) ToMapClaims() MapClaims { return c.ToMap() } diff --git a/token/jwt/claims_jwt.go b/token/jwt/claims_jwt.go index fb5986a65..88e04758e 100644 --- a/token/jwt/claims_jwt.go +++ b/token/jwt/claims_jwt.go @@ -25,7 +25,6 @@ import ( "strings" "time" - jwt "github.com/dgrijalva/jwt-go" "github.com/pborman/uuid" ) @@ -58,7 +57,7 @@ type JWTClaimsContainer interface { WithScopeField(scopeField JWTScopeFieldEnum) JWTClaimsContainer // ToMapClaims returns the claims as a github.com/dgrijalva/jwt-go.MapClaims type. - ToMapClaims() jwt.MapClaims + ToMapClaims() MapClaims } // JWTClaims represent a token's claims. @@ -229,11 +228,11 @@ func (c JWTClaims) Get(key string) interface{} { } // ToMapClaims will return a jwt-go MapClaims representation -func (c JWTClaims) ToMapClaims() jwt.MapClaims { +func (c JWTClaims) ToMapClaims() MapClaims { return c.ToMap() } // FromMapClaims will populate claims from a jwt-go MapClaims representation -func (c *JWTClaims) FromMapClaims(mc jwt.MapClaims) { +func (c *JWTClaims) FromMapClaims(mc MapClaims) { c.FromMap(mc) } diff --git a/token/jwt/header.go b/token/jwt/header.go index f847bf873..f790a2300 100644 --- a/token/jwt/header.go +++ b/token/jwt/header.go @@ -21,8 +21,6 @@ package jwt -import jwt "github.com/dgrijalva/jwt-go" - // Headers is the jwt headers type Headers struct { Extra map[string]interface{} @@ -61,6 +59,6 @@ func (h *Headers) Get(key string) interface{} { } // ToMapClaims will return a jwt-go MapClaims representation -func (h Headers) ToMapClaims() jwt.MapClaims { +func (h Headers) ToMapClaims() MapClaims { return h.ToMap() } diff --git a/token/jwt/jwt.go b/token/jwt/jwt.go index f052b602b..52b91e44d 100644 --- a/token/jwt/jwt.go +++ b/token/jwt/jwt.go @@ -26,107 +26,62 @@ package jwt import ( "context" + "crypto" "crypto/ecdsa" "crypto/rsa" "crypto/sha256" - "fmt" "strings" "github.com/ory/x/errorsx" + "gopkg.in/square/go-jose.v2" - jwt "github.com/dgrijalva/jwt-go" "github.com/pkg/errors" - - "github.com/ory/fosite" ) type JWTStrategy interface { - Generate(ctx context.Context, claims jwt.Claims, header Mapper) (string, string, error) + Generate(ctx context.Context, claims MapClaims, header Mapper) (string, string, error) Validate(ctx context.Context, token string) (string, error) Hash(ctx context.Context, in []byte) ([]byte, error) - Decode(ctx context.Context, token string) (*jwt.Token, error) + Decode(ctx context.Context, token string) (*Token, error) GetSignature(ctx context.Context, token string) (string, error) GetSigningMethodLength() int } +var SHA256HashSize = crypto.SHA256.Size() + // RS256JWTStrategy is responsible for generating and validating JWT challenges type RS256JWTStrategy struct { PrivateKey *rsa.PrivateKey } // Generate generates a new authorize code or returns an error. set secret -func (j *RS256JWTStrategy) Generate(ctx context.Context, claims jwt.Claims, header Mapper) (string, string, error) { - if header == nil || claims == nil { - return "", "", errors.New("Either claims or header is nil.") - } - - token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) - token.Header = assign(token.Header, header.ToMap()) - - var sig, sstr string - var err error - if sstr, err = token.SigningString(); err != nil { - return "", "", errorsx.WithStack(err) - } - - if sig, err = token.Method.Sign(sstr, j.PrivateKey); err != nil { - return "", "", errorsx.WithStack(err) - } - - return fmt.Sprintf("%s.%s", sstr, sig), sig, nil +func (j *RS256JWTStrategy) Generate(ctx context.Context, claims MapClaims, header Mapper) (string, string, error) { + return generateToken(claims, header, jose.RS256, j.PrivateKey) } // Validate validates a token and returns its signature or an error if the token is not valid. func (j *RS256JWTStrategy) Validate(ctx context.Context, token string) (string, error) { - if _, err := j.Decode(ctx, token); err != nil { - return "", errorsx.WithStack(err) - } - - return j.GetSignature(ctx, token) + return validateToken(token, &j.PrivateKey.PublicKey) } // Decode will decode a JWT token -func (j *RS256JWTStrategy) Decode(ctx context.Context, token string) (*jwt.Token, error) { - // Parse the token. - parsedToken, err := jwt.Parse(token, func(t *jwt.Token) (interface{}, error) { - if _, ok := t.Method.(*jwt.SigningMethodRSA); !ok { - return nil, errors.Errorf("Unexpected signing method: %v", t.Header["alg"]) - } - return &j.PrivateKey.PublicKey, nil - }) - - if err != nil { - return parsedToken, errorsx.WithStack(err) - } else if !parsedToken.Valid { - return parsedToken, errorsx.WithStack(fosite.ErrInactiveToken) - } - - return parsedToken, err +func (j *RS256JWTStrategy) Decode(ctx context.Context, token string) (*Token, error) { + return decodeToken(token, &j.PrivateKey.PublicKey) } // GetSignature will return the signature of a token func (j *RS256JWTStrategy) GetSignature(ctx context.Context, token string) (string, error) { - split := strings.Split(token, ".") - if len(split) != 3 { - return "", errors.New("Header, body and signature must all be set") - } - return split[2], nil + return getTokenSignature(token) } // Hash will return a given hash based on the byte input or an error upon fail func (j *RS256JWTStrategy) Hash(ctx context.Context, in []byte) ([]byte, error) { - // SigningMethodRS256 - hash := sha256.New() - _, err := hash.Write(in) - if err != nil { - return []byte{}, errorsx.WithStack(err) - } - return hash.Sum([]byte{}), nil + return hashSHA256(in) } // GetSigningMethodLength will return the length of the signing method func (j *RS256JWTStrategy) GetSigningMethodLength() int { - return jwt.SigningMethodRS256.Hash.Size() + return SHA256HashSize } // ES256JWTStrategy is responsible for generating and validating JWT challenges @@ -135,57 +90,67 @@ type ES256JWTStrategy struct { } // Generate generates a new authorize code or returns an error. set secret -func (j *ES256JWTStrategy) Generate(ctx context.Context, claims jwt.Claims, header Mapper) (string, string, error) { - if header == nil || claims == nil { - return "", "", errors.New("Either claims or header is nil.") - } +func (j *ES256JWTStrategy) Generate(ctx context.Context, claims MapClaims, header Mapper) (string, string, error) { + return generateToken(claims, header, jose.ES256, j.PrivateKey) +} - token := jwt.NewWithClaims(jwt.SigningMethodES256, claims) - token.Header = assign(token.Header, header.ToMap()) +// Validate validates a token and returns its signature or an error if the token is not valid. +func (j *ES256JWTStrategy) Validate(ctx context.Context, token string) (string, error) { + return validateToken(token, &j.PrivateKey.PublicKey) +} - var sig, sstr string - var err error - if sstr, err = token.SigningString(); err != nil { - return "", "", errorsx.WithStack(err) - } +// Decode will decode a JWT token +func (j *ES256JWTStrategy) Decode(ctx context.Context, token string) (*Token, error) { + return decodeToken(token, &j.PrivateKey.PublicKey) +} - if sig, err = token.Method.Sign(sstr, j.PrivateKey); err != nil { - return "", "", errorsx.WithStack(err) - } +// GetSignature will return the signature of a token +func (j *ES256JWTStrategy) GetSignature(ctx context.Context, token string) (string, error) { + return getTokenSignature(token) +} - return fmt.Sprintf("%s.%s", sstr, sig), sig, nil +// Hash will return a given hash based on the byte input or an error upon fail +func (j *ES256JWTStrategy) Hash(ctx context.Context, in []byte) ([]byte, error) { + return hashSHA256(in) } -// Validate validates a token and returns its signature or an error if the token is not valid. -func (j *ES256JWTStrategy) Validate(ctx context.Context, token string) (string, error) { - if _, err := j.Decode(ctx, token); err != nil { - return "", errorsx.WithStack(err) +// GetSigningMethodLength will return the length of the signing method +func (j *ES256JWTStrategy) GetSigningMethodLength() int { + return SHA256HashSize +} + +func generateToken(claims MapClaims, header Mapper, signingMethod jose.SignatureAlgorithm, privateKey interface{}) (rawToken string, sig string, err error) { + if header == nil || claims == nil { + err = errors.New("Either claims or header is nil.") + return } - return j.GetSignature(ctx, token) + token := NewWithClaims(signingMethod, claims) + token.Header = assign(token.Header, header.ToMap()) + + rawToken, err = token.SignedString(privateKey) + if err != nil { + return + } + + sig, err = getTokenSignature(rawToken) + return } -// Decode will decode a JWT token -func (j *ES256JWTStrategy) Decode(ctx context.Context, token string) (*jwt.Token, error) { - // Parse the token. - parsedToken, err := jwt.Parse(token, func(t *jwt.Token) (interface{}, error) { - if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok { - return nil, errors.Errorf("Unexpected signing method: %v", t.Header["alg"]) - } - return &j.PrivateKey.PublicKey, nil - }) +func decodeToken(token string, verificationKey interface{}) (*Token, error) { + keyFunc := func(*Token) (interface{}, error) { return verificationKey, nil } + return ParseWithClaims(token, MapClaims{}, keyFunc) +} +func validateToken(tokenStr string, verificationKey interface{}) (string, error) { + _, err := decodeToken(tokenStr, verificationKey) if err != nil { - return parsedToken, errorsx.WithStack(err) - } else if !parsedToken.Valid { - return parsedToken, errorsx.WithStack(fosite.ErrInactiveToken) + return "", err } - - return parsedToken, err + return getTokenSignature(tokenStr) } -// GetSignature will return the signature of a token -func (j *ES256JWTStrategy) GetSignature(ctx context.Context, token string) (string, error) { +func getTokenSignature(token string) (string, error) { split := strings.Split(token, ".") if len(split) != 3 { return "", errors.New("Header, body and signature must all be set") @@ -193,9 +158,7 @@ func (j *ES256JWTStrategy) GetSignature(ctx context.Context, token string) (stri return split[2], nil } -// Hash will return a given hash based on the byte input or an error upon fail -func (j *ES256JWTStrategy) Hash(ctx context.Context, in []byte) ([]byte, error) { - // SigningMethodES256 +func hashSHA256(in []byte) ([]byte, error) { hash := sha256.New() _, err := hash.Write(in) if err != nil { @@ -204,11 +167,6 @@ func (j *ES256JWTStrategy) Hash(ctx context.Context, in []byte) ([]byte, error) return hash.Sum([]byte{}), nil } -// GetSigningMethodLength will return the length of the signing method -func (j *ES256JWTStrategy) GetSigningMethodLength() int { - return jwt.SigningMethodES256.Hash.Size() -} - func assign(a, b map[string]interface{}) map[string]interface{} { for k, w := range b { if _, ok := a[k]; ok { diff --git a/token/jwt/jwt_test.go b/token/jwt/jwt_test.go index b822a76a2..72f981076 100644 --- a/token/jwt/jwt_test.go +++ b/token/jwt/jwt_test.go @@ -30,8 +30,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - - "github.com/ory/fosite/internal" ) var header = &Headers{ @@ -48,13 +46,13 @@ func TestHash(t *testing.T) { { d: "RS256JWTStrategy", strategy: &RS256JWTStrategy{ - PrivateKey: internal.MustRSAKey(), + PrivateKey: MustRSAKey(), }, }, { d: "ES256JWTStrategy", strategy: &ES256JWTStrategy{ - PrivateKey: internal.MustECDSAKey(), + PrivateKey: MustECDSAKey(), }, }, } { @@ -103,19 +101,19 @@ func TestGenerateJWT(t *testing.T) { { d: "RS256JWTStrategy", strategy: &RS256JWTStrategy{ - PrivateKey: internal.MustRSAKey(), + PrivateKey: MustRSAKey(), }, resetKey: func(strategy JWTStrategy) { - strategy.(*RS256JWTStrategy).PrivateKey = internal.MustRSAKey() + strategy.(*RS256JWTStrategy).PrivateKey = MustRSAKey() }, }, { d: "ES256JWTStrategy", strategy: &ES256JWTStrategy{ - PrivateKey: internal.MustECDSAKey(), + PrivateKey: MustECDSAKey(), }, resetKey: func(strategy JWTStrategy) { - strategy.(*ES256JWTStrategy).PrivateKey = internal.MustECDSAKey() + strategy.(*ES256JWTStrategy).PrivateKey = MustECDSAKey() }, }, } { @@ -149,7 +147,6 @@ func TestGenerateJWT(t *testing.T) { token, sig, err = tc.strategy.Generate(context.TODO(), claims.ToMapClaims(), header) require.NoError(t, err) require.NotNil(t, token) - //t.Logf("%s.%s", token, sig) sig, err = tc.strategy.Validate(context.TODO(), token) require.Error(t, err) @@ -177,13 +174,13 @@ func TestValidateSignatureRejectsJWT(t *testing.T) { { d: "RS256JWTStrategy", strategy: &RS256JWTStrategy{ - PrivateKey: internal.MustRSAKey(), + PrivateKey: MustRSAKey(), }, }, { d: "ES256JWTStrategy", strategy: &ES256JWTStrategy{ - PrivateKey: internal.MustECDSAKey(), + PrivateKey: MustECDSAKey(), }, }, } { diff --git a/token/jwt/key_test.go b/token/jwt/key_test.go new file mode 100644 index 000000000..b0c723691 --- /dev/null +++ b/token/jwt/key_test.go @@ -0,0 +1,48 @@ +/* + * Copyright © 2015-2018 Aeneas Rekkas + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @author Aeneas Rekkas + * @copyright 2015-2018 Aeneas Rekkas + * @license Apache-2.0 + * + */ + +// REMARK: Copied here from fosite/internal to avoid circular dependencies only for test data + +package jwt + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" +) + +func MustRSAKey() *rsa.PrivateKey { + // #nosec + key, err := rsa.GenerateKey(rand.Reader, 1024) + if err != nil { + panic(err) + } + return key +} + +func MustECDSAKey() *ecdsa.PrivateKey { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + panic(err) + } + return key +} diff --git a/token/jwt/map_claims.go b/token/jwt/map_claims.go new file mode 100644 index 000000000..c5b4e988f --- /dev/null +++ b/token/jwt/map_claims.go @@ -0,0 +1,162 @@ +package jwt + +import ( + "crypto/subtle" + "encoding/json" + "errors" + "time" + // "fmt" +) + +var TimeFunc = time.Now + +// MapClaims provides backwards compatible validations not available in `go-jose`. +// It was taken from [here](https://raw.githubusercontent.com/form3tech-oss/jwt-go/master/map_claims.go). +// +// Claims type that uses the map[string]interface{} for JSON decoding +// This is the default claims type if you don't supply one +type MapClaims map[string]interface{} + +// Compares the aud claim against cmp. +// If required is false, this method will return true if the value matches or is unset +func (m MapClaims) VerifyAudience(cmp string, req bool) bool { + var aud []string + switch v := m["aud"].(type) { + case []string: + aud = v + case []interface{}: + for _, a := range v { + vs, ok := a.(string) + if !ok { + return false + } + aud = append(aud, vs) + } + case string: + aud = append(aud, v) + default: + return false + } + return verifyAud(aud, cmp, req) +} + +// Compares the exp claim against cmp. +// If required is false, this method will return true if the value matches or is unset +func (m MapClaims) VerifyExpiresAt(cmp int64, req bool) bool { + switch exp := m["exp"].(type) { + case float64: + return verifyExp(int64(exp), cmp, req) + case json.Number: + v, _ := exp.Int64() + return verifyExp(v, cmp, req) + } + return !req +} + +// Compares the iat claim against cmp. +// If required is false, this method will return true if the value matches or is unset +func (m MapClaims) VerifyIssuedAt(cmp int64, req bool) bool { + switch iat := m["iat"].(type) { + case float64: + return verifyIat(int64(iat), cmp, req) + case json.Number: + v, _ := iat.Int64() + return verifyIat(v, cmp, req) + } + return !req +} + +// Compares the iss claim against cmp. +// If required is false, this method will return true if the value matches or is unset +func (m MapClaims) VerifyIssuer(cmp string, req bool) bool { + iss, _ := m["iss"].(string) + return verifyIss(iss, cmp, req) +} + +// Compares the nbf claim against cmp. +// If required is false, this method will return true if the value matches or is unset +func (m MapClaims) VerifyNotBefore(cmp int64, req bool) bool { + switch nbf := m["nbf"].(type) { + case float64: + return verifyNbf(int64(nbf), cmp, req) + case json.Number: + v, _ := nbf.Int64() + return verifyNbf(v, cmp, req) + } + return !req +} + +// Validates time based claims "exp, iat, nbf". +// There is no accounting for clock skew. +// As well, if any of the above claims are not in the token, it will still +// be considered a valid claim. +func (m MapClaims) Valid() error { + vErr := new(ValidationError) + now := TimeFunc().Unix() + + if !m.VerifyExpiresAt(now, false) { + vErr.Inner = errors.New("Token is expired") + vErr.Errors |= ValidationErrorExpired + } + + if !m.VerifyIssuedAt(now, false) { + vErr.Inner = errors.New("Token used before issued") + vErr.Errors |= ValidationErrorIssuedAt + } + + if !m.VerifyNotBefore(now, false) { + vErr.Inner = errors.New("Token is not valid yet") + vErr.Errors |= ValidationErrorNotValidYet + } + + if vErr.valid() { + return nil + } + + return vErr +} + +func verifyAud(aud []string, cmp string, required bool) bool { + if len(aud) == 0 { + return !required + } + + for _, a := range aud { + if subtle.ConstantTimeCompare([]byte(a), []byte(cmp)) != 0 { + return true + } + } + return false +} + +func verifyExp(exp int64, now int64, required bool) bool { + if exp == 0 { + return !required + } + return now <= exp +} + +func verifyIat(iat int64, now int64, required bool) bool { + if iat == 0 { + return !required + } + return now >= iat +} + +func verifyIss(iss string, cmp string, required bool) bool { + if iss == "" { + return !required + } + if subtle.ConstantTimeCompare([]byte(iss), []byte(cmp)) != 0 { + return true + } else { + return false + } +} + +func verifyNbf(nbf int64, now int64, required bool) bool { + if nbf == 0 { + return !required + } + return now >= nbf +} diff --git a/token/jwt/map_claims_test.go b/token/jwt/map_claims_test.go new file mode 100644 index 000000000..e264dda56 --- /dev/null +++ b/token/jwt/map_claims_test.go @@ -0,0 +1,95 @@ +package jwt + +import "testing" + +// Test taken from taken from [here](https://raw.githubusercontent.com/form3tech-oss/jwt-go/master/map_claims_test.go). +func Test_mapClaims_list_aud(t *testing.T) { + mapClaims := MapClaims{ + "aud": []string{"foo"}, + } + want := true + got := mapClaims.VerifyAudience("foo", true) + + if want != got { + t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + } +} + +// This is a custom test to check that an empty +// list with require == false returns valid +func Test_mapClaims_empty_list_aud(t *testing.T) { + mapClaims := MapClaims{ + "aud": []string{}, + } + want := true + got := mapClaims.VerifyAudience("foo", false) + + if want != got { + t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + } +} +func Test_mapClaims_list_interface_aud(t *testing.T) { + mapClaims := MapClaims{ + "aud": []interface{}{"foo"}, + } + want := true + got := mapClaims.VerifyAudience("foo", true) + + if want != got { + t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + } +} +func Test_mapClaims_string_aud(t *testing.T) { + mapClaims := MapClaims{ + "aud": "foo", + } + want := true + got := mapClaims.VerifyAudience("foo", true) + + if want != got { + t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + } +} + +func Test_mapClaims_list_aud_no_match(t *testing.T) { + mapClaims := MapClaims{ + "aud": []string{"bar"}, + } + want := false + got := mapClaims.VerifyAudience("foo", true) + + if want != got { + t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + } +} +func Test_mapClaims_string_aud_fail(t *testing.T) { + mapClaims := MapClaims{ + "aud": "bar", + } + want := false + got := mapClaims.VerifyAudience("foo", true) + + if want != got { + t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + } +} + +func Test_mapClaims_string_aud_no_claim(t *testing.T) { + mapClaims := MapClaims{} + want := false + got := mapClaims.VerifyAudience("foo", true) + + if want != got { + t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + } +} + +func Test_mapClaims_string_aud_no_claim_not_required(t *testing.T) { + mapClaims := MapClaims{} + want := false + got := mapClaims.VerifyAudience("foo", false) + + if want != got { + t.Fatalf("Failed to verify claims, wanted: %v got %v", want, got) + } +} diff --git a/token/jwt/token.go b/token/jwt/token.go new file mode 100644 index 000000000..37c8f4666 --- /dev/null +++ b/token/jwt/token.go @@ -0,0 +1,232 @@ +package jwt + +import ( + "encoding/base64" + "encoding/json" + "fmt" + "reflect" + + "github.com/ory/x/errorsx" + "gopkg.in/square/go-jose.v2" + "gopkg.in/square/go-jose.v2/jwt" +) + +// Token represets a JWT Token +// This token provide an adaptation to +// transit from [jwt-go](https://github.com/dgrijalva/jwt-go) +// to [go-jose](https://github.com/square/go-jose) +// It provides method signatures compatible with jwt-go but implemented +// using go-json +type Token struct { + Header map[string]interface{} // The first segment of the token + Claims MapClaims // The second segment of the token + Method jose.SignatureAlgorithm + valid bool +} + +const ( + SigningMethodNone = jose.SignatureAlgorithm("none") + // This key should be use to correctly sign and verify alg:none JWT tokens + UnsafeAllowNoneSignatureType unsafeNoneMagicConstant = "none signing method allowed" +) + +type unsafeNoneMagicConstant string + +// Valid informs if the token was verified against a given verification key +// and claims are valid +func (t *Token) Valid() bool { + return t.valid +} + +// Claims is a port from https://github.com/dgrijalva/jwt-go/blob/master/claims.go +// including its validation methods, which are not available in go-jose library +// +// > For a type to be a Claims object, it must just have a Valid method that determines +// if the token is invalid for any supported reason +type Claims interface { + Valid() error +} + +// NewWithClaims creates an unverified Token with the given claims and signing method +func NewWithClaims(method jose.SignatureAlgorithm, claims MapClaims) *Token { + return &Token{ + Claims: claims, + Method: method, + Header: map[string]interface{}{}, + } +} + +func (t *Token) toJoseHeader() map[jose.HeaderKey]interface{} { + h := map[jose.HeaderKey]interface{}{} + for k, v := range t.Header { + h[jose.HeaderKey(k)] = v + } + return h +} + +// SignedString provides a compatible `jwt-go` Token.SignedString method +// +// > Get the complete, signed token +func (t *Token) SignedString(k interface{}) (rawToken string, err error) { + if _, ok := k.(unsafeNoneMagicConstant); ok { + rawToken, err = unsignedToken(t) + return + + } + var signer jose.Signer + key := jose.SigningKey{ + Algorithm: t.Method, + Key: k, + } + opts := &jose.SignerOptions{ExtraHeaders: t.toJoseHeader()} + signer, err = jose.NewSigner(key, opts) + if err != nil { + err = errorsx.WithStack(err) + return + } + + // A explicit conversion from type alias MapClaims + // to map[string]interface{} is required because the + // go-jose CompactSerialize() only support explicit maps + // as claims or structs but not type aliases from maps. + claims := map[string]interface{}(t.Claims) + rawToken, err = jwt.Signed(signer).Claims(claims).CompactSerialize() + if err != nil { + err = &ValidationError{Errors: ValidationErrorClaimsInvalid, Inner: err} + return + } + return +} + +func unsignedToken(t *Token) (string, error) { + t.Header["alg"] = "none" + hbytes, err := json.Marshal(&t.Header) + if err != nil { + return "", errorsx.WithStack(err) + } + bbytes, err := json.Marshal(&t.Claims) + if err != nil { + return "", errorsx.WithStack(err) + } + h := base64.RawURLEncoding.EncodeToString(hbytes) + b := base64.RawURLEncoding.EncodeToString(bbytes) + return fmt.Sprintf("%v.%v.", h, b), nil +} + +func newToken(parsedToken *jwt.JSONWebToken, claims MapClaims) (*Token, error) { + token := &Token{Claims: claims} + if len(parsedToken.Headers) != 1 { + return nil, &ValidationError{text: fmt.Sprintf("only one header supported, got %v", len(parsedToken.Headers)), Errors: ValidationErrorMalformed} + } + + // copy headers + h := parsedToken.Headers[0] + token.Header = map[string]interface{}{ + "alg": h.Algorithm, + } + if h.KeyID != "" { + token.Header["kid"] = h.KeyID + } + for k, v := range h.ExtraHeaders { + token.Header[string(k)] = v + } + + token.Method = jose.SignatureAlgorithm(h.Algorithm) + + return token, nil +} + +// Parse methods use this callback function to supply +// the key for verification. The function receives the parsed, +// but unverified Token. This allows you to use properties in the +// Header of the token (such as `kid`) to identify which key to use. +type Keyfunc func(*Token) (interface{}, error) + +func Parse(tokenString string, keyFunc Keyfunc) (*Token, error) { + return ParseWithClaims(tokenString, MapClaims{}, keyFunc) +} + +// Parse, validate, and return a token. +// keyFunc will receive the parsed token and should return the key for validating. +// If everything is kosher, err will be nil +func ParseWithClaims(rawToken string, claims MapClaims, keyFunc Keyfunc) (*Token, error) { + // Parse the token. + parsedToken, err := jwt.ParseSigned(rawToken) + if err != nil { + return &Token{}, &ValidationError{Errors: ValidationErrorMalformed, text: err.Error()} + } + + // fill unverified claims + // This conversion is required because go-jose supports + // only marshalling structs or maps but not alias types from maps + // + // The KeyFunc(*Token) function requires the claims to be set into the + // Token, that is an unverified token, therefore an UnsafeClaimsWithoutVerification is done first + // then with the returned key, the claims gets verified. + if err := parsedToken.UnsafeClaimsWithoutVerification(&claims); err != nil { + return nil, &ValidationError{Errors: ValidationErrorClaimsInvalid, text: err.Error()} + } + + // creates an usafe token + token, err := newToken(parsedToken, claims) + if err != nil { + return nil, err + } + + if keyFunc == nil { + // keyFunc was not provided. short circuiting validation + return token, &ValidationError{Errors: ValidationErrorUnverifiable, text: "no Keyfunc was provided."} + } + + // Call keyFunc callback to get verification key + verificationKey, err := keyFunc(token) + if err != nil { + // keyFunc returned an error + if ve, ok := err.(*ValidationError); ok { + return token, ve + } + return token, &ValidationError{Errors: ValidationErrorUnverifiable, Inner: err} + } + if verificationKey == nil { + return token, &ValidationError{Errors: ValidationErrorSignatureInvalid, text: "keyfunc returned a nil verification key"} + } + // To verify signature go-jose requires a pointer to + // public key instead of the public key value. + // The pointer values provides that pointer. + // E.g. transform rsa.PublicKey -> *rsa.PublicKey + verificationKey = pointer(verificationKey) + + // verify signature with returned key + _, validNoneKey := verificationKey.(*unsafeNoneMagicConstant) + isSignedToken := !(token.Method == SigningMethodNone && validNoneKey) + if isSignedToken { + if err := parsedToken.Claims(verificationKey, &claims); err != nil { + return token, &ValidationError{Errors: ValidationErrorSignatureInvalid, text: err.Error()} + } + } + + // Validate claims + // This validation is performed to be backwards compatible + // with jwt-go library behavior + if err := claims.Valid(); err != nil { + if e, ok := err.(*ValidationError); !ok { + err = &ValidationError{Inner: e, Errors: ValidationErrorClaimsInvalid} + } + return token, err + } + + // set token as verified and validated + token.valid = true + return token, nil +} + +// if underline value of v is not a pointer +// it creates a pointer of it and returns it +func pointer(v interface{}) interface{} { + if reflect.ValueOf(v).Kind() != reflect.Ptr { + value := reflect.New(reflect.ValueOf(v).Type()) + value.Elem().Set(reflect.ValueOf(v)) + return value.Interface() + } + return v +} diff --git a/token/jwt/token_test.go b/token/jwt/token_test.go new file mode 100644 index 000000000..f050f0af0 --- /dev/null +++ b/token/jwt/token_test.go @@ -0,0 +1,495 @@ +package jwt + +import ( + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gopkg.in/square/go-jose.v2" + + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "fmt" +) + +func TestUnsignedToken(t *testing.T) { + key := UnsafeAllowNoneSignatureType + token := NewWithClaims(SigningMethodNone, MapClaims{ + "aud": "foo", + "exp": time.Now().UTC().Add(time.Hour).Unix(), + "iat": time.Now().UTC().Unix(), + "sub": "nestor", + }) + rawToken, err := token.SignedString(key) + require.NoError(t, err) + require.NotEmpty(t, rawToken) + parts := strings.Split(rawToken, ".") + require.Len(t, parts, 3) + require.Empty(t, parts[2]) +} + +var keyFuncError error = fmt.Errorf("error loading key") +var ( + jwtTestDefaultKey *rsa.PublicKey = parseRSAPublicKeyFromPEM(defaultPubKeyPEM) + defaultKeyFunc Keyfunc = func(t *Token) (interface{}, error) { return jwtTestDefaultKey, nil } + emptyKeyFunc Keyfunc = func(t *Token) (interface{}, error) { return nil, nil } + errorKeyFunc Keyfunc = func(t *Token) (interface{}, error) { return nil, keyFuncError } + nilKeyFunc Keyfunc = nil +) + +// Many test cases where taken from https://github.com/dgrijalva/jwt-go/blob/master/parser_test.go +// Test cases related to json.Number where excluded because that is not supported by go-jose, +// it is not used in fosite and therefore not supported. +func TestParser_Parse(t *testing.T) { + var ( + defaultES256PrivateKey = MustECDSAKey() + defaultSigningKey = parseRSAPrivateKeyFromPEM(defaultPrivateKeyPEM) + publicECDSAKey = func(*Token) (interface{}, error) { return &defaultES256PrivateKey.PublicKey, nil } + noneKey = func(*Token) (interface{}, error) { return UnsafeAllowNoneSignatureType, nil } + randomKey = func(*Token) (interface{}, error) { + k, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + return &k.PublicKey, nil + } + ) + type expected struct { + errors uint32 + keyFunc Keyfunc + valid bool + claims MapClaims + } + type generate struct { + claims MapClaims + signingKey interface{} // defaultSigningKey + method jose.SignatureAlgorithm // default RS256 + } + type given struct { + name string + tokenString string + generate *generate + } + var jwtTestData = []struct { + expected + given + }{ + { + given: given{ + name: "basic", + tokenString: "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", + }, + expected: expected{ + keyFunc: defaultKeyFunc, + claims: MapClaims{"foo": "bar"}, + valid: true, + errors: 0, + }, + }, + { + given: given{ + name: "basic expired", + generate: &generate{ + claims: MapClaims{"foo": "bar", "exp": float64(time.Now().Unix() - 100)}, + }, + }, + expected: expected{ + keyFunc: defaultKeyFunc, + claims: MapClaims{"foo": "bar", "exp": float64(time.Now().Unix() - 100)}, + valid: false, + errors: ValidationErrorExpired, + }, + }, + { + given: given{ + name: "basic nbf", + generate: &generate{ + claims: MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100)}, + }, + }, + expected: expected{ + keyFunc: defaultKeyFunc, + claims: MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100)}, + valid: false, + errors: ValidationErrorNotValidYet, + }, + }, + { + given: given{ + name: "expired and nbf", + generate: &generate{ + claims: MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100), "exp": float64(time.Now().Unix() - 100)}, + }, + }, + expected: expected{ + keyFunc: defaultKeyFunc, + claims: MapClaims{"foo": "bar", "nbf": float64(time.Now().Unix() + 100), "exp": float64(time.Now().Unix() - 100)}, + valid: false, + errors: ValidationErrorNotValidYet | ValidationErrorExpired, + }, + }, + { + given: given{ + name: "basic invalid", + tokenString: "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.EhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", + }, + expected: expected{ + keyFunc: defaultKeyFunc, + claims: MapClaims{"foo": "bar"}, + valid: false, + errors: ValidationErrorSignatureInvalid, + }, + }, + { + given: given{ + name: "basic nokeyfunc", + tokenString: "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", + }, + expected: expected{ + keyFunc: nilKeyFunc, + claims: MapClaims{"foo": "bar"}, + valid: false, + errors: ValidationErrorUnverifiable, + }, + }, + { + given: given{ + name: "basic nokey", + tokenString: "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", + }, + expected: expected{ + keyFunc: emptyKeyFunc, + claims: MapClaims{"foo": "bar"}, + valid: false, + errors: ValidationErrorSignatureInvalid, + }, + }, + { + given: given{ + name: "basic errorkey", + tokenString: "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", + generate: &generate{ + claims: MapClaims{"foo": "bar"}, + }, + }, + expected: expected{ + keyFunc: errorKeyFunc, + claims: MapClaims{"foo": "bar"}, + valid: false, + errors: ValidationErrorUnverifiable, + }, + }, + { + given: given{ + name: "valid signing method", + generate: &generate{ + claims: MapClaims{"foo": "bar"}, + }, + }, + expected: expected{ + keyFunc: defaultKeyFunc, + claims: MapClaims{"foo": "bar"}, + valid: true, + errors: 0, + }, + }, + { + given: given{ + name: "invalid", + tokenString: "foo_invalid_token", + }, + expected: expected{ + keyFunc: defaultKeyFunc, + claims: MapClaims(nil), + valid: false, + errors: ValidationErrorMalformed, + }, + }, + { + given: given{ + name: "valid format invalid content", + tokenString: "foo.bar.baz", + }, + expected: expected{ + keyFunc: defaultKeyFunc, + claims: MapClaims(nil), + valid: false, + errors: ValidationErrorMalformed, + }, + }, + { + given: given{ + name: "wrong key, expected ECDSA got RSA", + tokenString: "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", + }, + expected: expected{ + keyFunc: publicECDSAKey, + claims: MapClaims{"foo": "bar"}, + valid: false, + errors: ValidationErrorSignatureInvalid, + }, + }, + { + given: given{ + name: "should fail, got RSA but found no key", + tokenString: "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", + }, + expected: expected{ + keyFunc: emptyKeyFunc, + claims: MapClaims{"foo": "bar"}, + valid: false, + errors: ValidationErrorSignatureInvalid, + }, + }, + { + given: given{ + name: "key does not match", + tokenString: "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.eyJmb28iOiJiYXIifQ.FhkiHkoESI_cG3NPigFrxEk9Z60_oXrOT2vGm9Pn6RDgYNovYORQmmA0zs1AoAOf09ly2Nx2YAg6ABqAYga1AcMFkJljwxTT5fYphTuqpWdy4BELeSYJx5Ty2gmr8e7RonuUztrdD5WfPqLKMm1Ozp_T6zALpRmwTIW0QPnaBXaQD90FplAg46Iy1UlDKr-Eupy0i5SLch5Q-p2ZpaL_5fnTIUDlxC3pWhJTyx_71qDI-mAA_5lE_VdroOeflG56sSmDxopPEG3bFlSu1eowyBfxtu0_CuVd-M42RU75Zc4Gsj6uV77MBtbMrf4_7M_NUTSgoIF3fRqxrj0NzihIBg", + }, + expected: expected{ + keyFunc: randomKey, + claims: MapClaims{"foo": "bar"}, + valid: false, + errors: ValidationErrorSignatureInvalid, + }, + }, + { + given: given{ + name: "used before issued", + generate: &generate{ + claims: MapClaims{"foo": "bar", "iat": float64(time.Now().Unix() + 500)}, + }, + }, + expected: expected{ + keyFunc: defaultKeyFunc, + claims: MapClaims{"foo": "bar", "iat": float64(time.Now().Unix() + 500)}, + valid: false, + errors: ValidationErrorIssuedAt, + }, + }, + { + given: given{ + name: "valid ECDSA signing method", + generate: &generate{ + claims: MapClaims{"foo": "bar"}, + signingKey: defaultES256PrivateKey, + method: jose.ES256, + }, + }, + expected: expected{ + keyFunc: publicECDSAKey, + claims: MapClaims{"foo": "bar"}, + valid: true, + errors: 0, + }, + }, + { + given: given{ + name: "should pass, valid NONE signing method", + generate: &generate{ + claims: MapClaims{"foo": "bar"}, + signingKey: UnsafeAllowNoneSignatureType, + method: SigningMethodNone, + }, + }, + expected: expected{ + keyFunc: noneKey, + claims: MapClaims{"foo": "bar"}, + valid: true, + errors: 0, + }, + }, + { + given: given{ + name: "should fail, expected RS256 but got NONE", + generate: &generate{ + claims: MapClaims{"foo": "bar"}, + signingKey: UnsafeAllowNoneSignatureType, + method: SigningMethodNone, + }, + }, + expected: expected{ + keyFunc: defaultKeyFunc, + claims: MapClaims{"foo": "bar"}, + valid: false, + errors: ValidationErrorSignatureInvalid, + }, + }, + { + given: given{ + name: "should fail, expected ECDSA but got NONE", + generate: &generate{ + claims: MapClaims{"foo": "bar"}, + signingKey: UnsafeAllowNoneSignatureType, + method: SigningMethodNone, + }, + }, + expected: expected{ + keyFunc: publicECDSAKey, + claims: MapClaims{"foo": "bar"}, + valid: false, + errors: ValidationErrorSignatureInvalid, + }, + }, + } + + // Iterate over test data set and run tests + for _, data := range jwtTestData { + t.Run(data.name, func(t *testing.T) { + if data.generate != nil { + signingKey := data.generate.signingKey + method := data.generate.method + if signingKey == nil { + // use test defaults + signingKey = defaultSigningKey + method = jose.RS256 + } + data.tokenString = makeSampleToken(data.generate.claims, method, signingKey) + } + + // Parse the token + var token *Token + var err error + + // Figure out correct claims type + token, err = ParseWithClaims(data.tokenString, MapClaims{}, data.keyFunc) + // Verify result matches expectation + assert.EqualValues(t, data.claims, token.Claims) + if data.valid && err != nil { + t.Errorf("[%v] Error while verifying token: %T:%v", data.name, err, err) + } + + if !data.valid && err == nil { + t.Errorf("[%v] Invalid token passed validation", data.name) + } + + if (err == nil && !token.Valid()) || (err != nil && token.Valid()) { + t.Errorf("[%v] Inconsistent behavior between returned error and token.Valid", data.name) + } + + if data.errors != 0 { + if err == nil { + t.Errorf("[%v] Expecting error. Didn't get one.", data.name) + } else { + + ve := err.(*ValidationError) + // compare the bitfield part of the error + if e := ve.Errors; e != data.errors { + t.Errorf("[%v] Errors don't match expectation. %v != %v", data.name, e, data.errors) + } + + if err.Error() == keyFuncError.Error() && ve.Inner != keyFuncError { + t.Errorf("[%v] Inner error does not match expectation. %v != %v", data.name, ve.Inner, keyFuncError) + } + } + } + }) + } +} + +func makeSampleToken(c MapClaims, m jose.SignatureAlgorithm, key interface{}) string { + token := NewWithClaims(m, c) + s, e := token.SignedString(key) + + if e != nil { + panic(e.Error()) + } + + return s +} + +func parseRSAPublicKeyFromPEM(key []byte) *rsa.PublicKey { + var err error + + // Parse PEM block + var block *pem.Block + if block, _ = pem.Decode(key); block == nil { + panic("not possible to decode") + } + + // Parse the key + var parsedKey interface{} + if parsedKey, err = x509.ParsePKIXPublicKey(block.Bytes); err != nil { + if cert, err := x509.ParseCertificate(block.Bytes); err == nil { + parsedKey = cert.PublicKey + } else { + panic(err) + } + } + + var pkey *rsa.PublicKey + var ok bool + if pkey, ok = parsedKey.(*rsa.PublicKey); !ok { + panic("not an *rsa.PublicKey") + } + + return pkey +} + +func parseRSAPrivateKeyFromPEM(key []byte) *rsa.PrivateKey { + var err error + + // Parse PEM block + var block *pem.Block + if block, _ = pem.Decode(key); block == nil { + panic("unable to decode") + } + + var parsedKey interface{} + if parsedKey, err = x509.ParsePKCS1PrivateKey(block.Bytes); err != nil { + if parsedKey, err = x509.ParsePKCS8PrivateKey(block.Bytes); err != nil { + panic(err) + } + } + + var pkey *rsa.PrivateKey + var ok bool + if pkey, ok = parsedKey.(*rsa.PrivateKey); !ok { + panic("not an rsa private key") + } + + return pkey +} + +var ( + defaultPubKeyPEM = []byte(` +-----BEGIN PUBLIC KEY----- +MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA4f5wg5l2hKsTeNem/V41 +fGnJm6gOdrj8ym3rFkEU/wT8RDtnSgFEZOQpHEgQ7JL38xUfU0Y3g6aYw9QT0hJ7 +mCpz9Er5qLaMXJwZxzHzAahlfA0icqabvJOMvQtzD6uQv6wPEyZtDTWiQi9AXwBp +HssPnpYGIn20ZZuNlX2BrClciHhCPUIIZOQn/MmqTD31jSyjoQoV7MhhMTATKJx2 +XrHhR+1DcKJzQBSTAGnpYVaqpsARap+nwRipr3nUTuxyGohBTSmjJ2usSeQXHI3b +ODIRe1AuTyHceAbewn8b462yEWKARdpd9AjQW5SIVPfdsz5B6GlYQ5LdYKtznTuy +7wIDAQAB +-----END PUBLIC KEY-----`) + defaultPrivateKeyPEM = []byte(` +-----BEGIN RSA PRIVATE KEY----- +MIIEowIBAAKCAQEA4f5wg5l2hKsTeNem/V41fGnJm6gOdrj8ym3rFkEU/wT8RDtn +SgFEZOQpHEgQ7JL38xUfU0Y3g6aYw9QT0hJ7mCpz9Er5qLaMXJwZxzHzAahlfA0i +cqabvJOMvQtzD6uQv6wPEyZtDTWiQi9AXwBpHssPnpYGIn20ZZuNlX2BrClciHhC +PUIIZOQn/MmqTD31jSyjoQoV7MhhMTATKJx2XrHhR+1DcKJzQBSTAGnpYVaqpsAR +ap+nwRipr3nUTuxyGohBTSmjJ2usSeQXHI3bODIRe1AuTyHceAbewn8b462yEWKA +Rdpd9AjQW5SIVPfdsz5B6GlYQ5LdYKtznTuy7wIDAQABAoIBAQCwia1k7+2oZ2d3 +n6agCAbqIE1QXfCmh41ZqJHbOY3oRQG3X1wpcGH4Gk+O+zDVTV2JszdcOt7E5dAy +MaomETAhRxB7hlIOnEN7WKm+dGNrKRvV0wDU5ReFMRHg31/Lnu8c+5BvGjZX+ky9 +POIhFFYJqwCRlopGSUIxmVj5rSgtzk3iWOQXr+ah1bjEXvlxDOWkHN6YfpV5ThdE +KdBIPGEVqa63r9n2h+qazKrtiRqJqGnOrHzOECYbRFYhexsNFz7YT02xdfSHn7gM +IvabDDP/Qp0PjE1jdouiMaFHYnLBbgvlnZW9yuVf/rpXTUq/njxIXMmvmEyyvSDn +FcFikB8pAoGBAPF77hK4m3/rdGT7X8a/gwvZ2R121aBcdPwEaUhvj/36dx596zvY +mEOjrWfZhF083/nYWE2kVquj2wjs+otCLfifEEgXcVPTnEOPO9Zg3uNSL0nNQghj +FuD3iGLTUBCtM66oTe0jLSslHe8gLGEQqyMzHOzYxNqibxcOZIe8Qt0NAoGBAO+U +I5+XWjWEgDmvyC3TrOSf/KCGjtu0TSv30ipv27bDLMrpvPmD/5lpptTFwcxvVhCs +2b+chCjlghFSWFbBULBrfci2FtliClOVMYrlNBdUSJhf3aYSG2Doe6Bgt1n2CpNn +/iu37Y3NfemZBJA7hNl4dYe+f+uzM87cdQ214+jrAoGAXA0XxX8ll2+ToOLJsaNT +OvNB9h9Uc5qK5X5w+7G7O998BN2PC/MWp8H+2fVqpXgNENpNXttkRm1hk1dych86 +EunfdPuqsX+as44oCyJGFHVBnWpm33eWQw9YqANRI+pCJzP08I5WK3osnPiwshd+ +hR54yjgfYhBFNI7B95PmEQkCgYBzFSz7h1+s34Ycr8SvxsOBWxymG5zaCsUbPsL0 +4aCgLScCHb9J+E86aVbbVFdglYa5Id7DPTL61ixhl7WZjujspeXZGSbmq0Kcnckb +mDgqkLECiOJW2NHP/j0McAkDLL4tysF8TLDO8gvuvzNC+WQ6drO2ThrypLVZQ+ry +eBIPmwKBgEZxhqa0gVvHQG/7Od69KWj4eJP28kq13RhKay8JOoN0vPmspXJo1HY3 +CKuHRG+AP579dncdUnOMvfXOtkdM4vk0+hWASBQzM9xzVcztCa+koAugjVaLS9A+ +9uQoqEeVNTckxx0S2bYevRy7hGQmUJTyQm3j1zEUR5jpdbL83Fbq +-----END RSA PRIVATE KEY-----`) +) diff --git a/token/jwt/validation_error.go b/token/jwt/validation_error.go new file mode 100644 index 000000000..80cd50ee3 --- /dev/null +++ b/token/jwt/validation_error.go @@ -0,0 +1,44 @@ +package jwt + +// Validation provides a backwards compatible error definition +// from `jwt-go` to `go-jose`. +// The sourcecode was taken from https://github.com/dgrijalva/jwt-go/blob/master/errors.go +// +// > The errors that might occur when parsing and validating a token +const ( + ValidationErrorMalformed uint32 = 1 << iota // Token is malformed + ValidationErrorUnverifiable // Token could not be verified because of signing problems + ValidationErrorSignatureInvalid // Signature validation failed + + // Standard Claim validation errors + ValidationErrorAudience // AUD validation failed + ValidationErrorExpired // EXP validation failed + ValidationErrorIssuedAt // IAT validation failed + ValidationErrorIssuer // ISS validation failed + ValidationErrorNotValidYet // NBF validation failed + ValidationErrorId // JTI validation failed + ValidationErrorClaimsInvalid // Generic claims validation error +) + +// The error from Parse if token is not valid +type ValidationError struct { + Inner error // stores the error returned by external dependencies, i.e.: KeyFunc + Errors uint32 // bitfield. see ValidationError... constants + text string // errors that do not have a valid error just have text +} + +// Validation error is an error type +func (e ValidationError) Error() string { + if e.Inner != nil { + return e.Inner.Error() + } else if e.text != "" { + return e.text + } else { + return "token is invalid" + } +} + +// No errors +func (e *ValidationError) valid() bool { + return e.Errors == 0 +}