Skip to content

Commit

Permalink
feat: transit from jwt-go to go-jose (ory#593)
Browse files Browse the repository at this point in the history
Closes ory#514

Co-authored-by: hackerman <3372410+aeneasr@users.noreply.github.com>
  • Loading branch information
narg95 and aeneasr committed May 22, 2021
1 parent 40e9171 commit 03b22f6
Show file tree
Hide file tree
Showing 23 changed files with 1,216 additions and 205 deletions.
21 changes: 9 additions & 12 deletions authorize_request_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@ import (
"net/http"
"strings"

"github.com/ory/fosite/token/jwt"
"github.com/ory/x/errorsx"
"gopkg.in/square/go-jose.v2"

jwt "github.com/dgrijalva/jwt-go"
"github.com/pkg/errors"

"github.com/ory/go-convenience/stringslice"
Expand Down Expand Up @@ -101,7 +102,7 @@ func (f *Fosite) authorizeRequestParametersFromOpenIDConnectRequest(request *Aut
assertion = string(body)
}

token, err := jwt.ParseWithClaims(assertion, new(jwt.MapClaims), func(t *jwt.Token) (interface{}, error) {
token, err := jwt.ParseWithClaims(assertion, jwt.MapClaims{}, func(t *jwt.Token) (interface{}, error) {
// request_object_signing_alg - OPTIONAL.
// JWS [JWS] alg algorithm [JWA] that MUST be used for signing Request Objects sent to the OP. All Request Objects from this Client MUST be rejected,
// if not signed with this algorithm. Request Objects are described in Section 6.1 of OpenID Connect Core 1.0 [OpenID.Core]. This algorithm MUST
Expand All @@ -115,22 +116,22 @@ func (f *Fosite) authorizeRequestParametersFromOpenIDConnectRequest(request *Aut
return jwt.UnsafeAllowNoneSignatureType, nil
}

switch t.Method.(type) {
case *jwt.SigningMethodRSA:
switch t.Method {
case jose.RS256, jose.RS384, jose.RS512:
key, err := f.findClientPublicJWK(oidcClient, t, true)
if err != nil {
return nil, wrapSigningKeyFailure(
ErrInvalidRequestObject.WithHint("Unable to retrieve RSA signing key from OAuth 2.0 Client."), err)
}
return key, nil
case *jwt.SigningMethodECDSA:
case jose.ES256, jose.ES384, jose.ES512:
key, err := f.findClientPublicJWK(oidcClient, t, false)
if err != nil {
return nil, wrapSigningKeyFailure(
ErrInvalidRequestObject.WithHint("Unable to retrieve ECDSA signing key from OAuth 2.0 Client."), err)
}
return key, nil
case *jwt.SigningMethodRSAPSS:
case jose.PS256, jose.PS384, jose.PS512:
key, err := f.findClientPublicJWK(oidcClient, t, true)
if err != nil {
return nil, wrapSigningKeyFailure(
Expand All @@ -155,12 +156,8 @@ func (f *Fosite) authorizeRequestParametersFromOpenIDConnectRequest(request *Aut
return errorsx.WithStack(ErrInvalidRequestObject.WithHint("Unable to verify the request object because its claims could not be validated, check if the expiry time is set correctly.").WithWrap(err).WithDebug(err.Error()))
}

claims, ok := token.Claims.(*jwt.MapClaims)
if !ok {
return errorsx.WithStack(ErrInvalidRequestObject.WithHint("Unable to type assert claims from request object.").WithDebugf(`Got claims of type %T but expected type '*jwt.MapClaims'.`, token.Claims))
}

for k, v := range *claims {
claims := token.Claims
for k, v := range claims {
request.Form.Set(k, fmt.Sprintf("%s", v))
}

Expand Down
15 changes: 10 additions & 5 deletions authorize_request_handler_oidc_request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@ import (

"github.com/pkg/errors"

jwt "github.com/dgrijalva/jwt-go"
"github.com/ory/fosite/token/jwt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
jose "gopkg.in/square/go-jose.v2"
)

func mustGenerateAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.PrivateKey, kid string) string {
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token := jwt.NewWithClaims(jose.RS256, claims)
if kid != "" {
token.Header["kid"] = kid
}
Expand All @@ -50,7 +50,7 @@ func mustGenerateAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.PrivateK
}

func mustGenerateHSAssertion(t *testing.T, claims jwt.MapClaims) string {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
token := jwt.NewWithClaims(jose.HS256, claims)
tokenString, err := token.SignedString([]byte("aaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbcccccccccccccccccccccddddddddddddddddddddddd"))
require.NoError(t, err)
return tokenString
Expand Down Expand Up @@ -217,10 +217,15 @@ func TestAuthorizeRequestParametersFromOpenIDConnectRequest(t *testing.T) {
if tc.expectErrReason != "" {
real := new(RFC6749Error)
require.True(t, errors.As(err, &real))
assert.EqualValues(t, real.Reason(), tc.expectErrReason)
assert.EqualValues(t, tc.expectErrReason, real.Reason())
}
} else {
require.NoError(t, err)
if err != nil {
real := new(RFC6749Error)
errors.As(err, &real)
require.NoErrorf(t, err, "Hint: %v\nDebug:%v", real.HintField, real.DebugField)
}
require.NoErrorf(t, err, "%+v", err)
require.Equal(t, len(tc.expectForm), len(req.Form))
for k, v := range tc.expectForm {
assert.EqualValues(t, v, req.Form[k])
Expand Down
37 changes: 16 additions & 21 deletions client_authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import (

"github.com/ory/x/errorsx"

jwt "github.com/dgrijalva/jwt-go"
"github.com/ory/fosite/token/jwt"
"github.com/pkg/errors"
jose "gopkg.in/square/go-jose.v2"
)
Expand Down Expand Up @@ -90,17 +90,16 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt
var clientID string
var client Client

token, err := jwt.ParseWithClaims(assertion, new(jwt.MapClaims), func(t *jwt.Token) (interface{}, error) {
token, err := jwt.ParseWithClaims(assertion, jwt.MapClaims{}, func(t *jwt.Token) (interface{}, error) {
var err error
clientID, _, err = clientCredentialsFromRequestBody(form, false)
if err != nil {
return nil, err
}

if clientID == "" {
if claims, ok := t.Claims.(*jwt.MapClaims); !ok {
return nil, errorsx.WithStack(ErrRequestUnauthorized.WithHint("Unable to type assert claims from client_assertion.").WithDebugf(`Expected claims to be of type '*jwt.MapClaims' but got '%T'.`, t.Claims))
} else if sub, ok := (*claims)["sub"].(string); !ok {
claims := t.Claims
if sub, ok := claims["sub"].(string); !ok {
return nil, errorsx.WithStack(ErrInvalidClient.WithHint("The claim 'sub' from the client_assertion JSON Web Token is undefined."))
} else {
clientID = sub
Expand Down Expand Up @@ -135,18 +134,18 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt
if oidcClient.GetTokenEndpointAuthSigningAlgorithm() != fmt.Sprintf("%s", t.Header["alg"]) {
return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("The 'client_assertion' uses signing algorithm '%s' but the requested OAuth 2.0 Client enforces signing algorithm '%s'.", t.Header["alg"], oidcClient.GetTokenEndpointAuthSigningAlgorithm()))
}

if _, ok := t.Method.(*jwt.SigningMethodRSA); ok {
switch t.Method {
case jose.RS256, jose.RS384, jose.RS512:
return f.findClientPublicJWK(oidcClient, t, true)
} else if _, ok := t.Method.(*jwt.SigningMethodECDSA); ok {
case jose.ES256, jose.ES384, jose.ES512:
return f.findClientPublicJWK(oidcClient, t, false)
} else if _, ok := t.Method.(*jwt.SigningMethodRSAPSS); ok {
case jose.PS256, jose.PS384, jose.PS512:
return f.findClientPublicJWK(oidcClient, t, true)
} else if _, ok := t.Method.(*jwt.SigningMethodHMAC); ok {
case jose.HS256, jose.HS384, jose.HS512:
return nil, errorsx.WithStack(ErrInvalidClient.WithHint("This authorization server does not support client authentication method 'client_secret_jwt'."))
default:
return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("The 'client_assertion' request parameter uses unsupported signing algorithm '%s'.", t.Header["alg"]))
}

return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("The 'client_assertion' request parameter uses unsupported signing algorithm '%s'.", t.Header["alg"]))
})
if err != nil {
// Do not re-process already enhanced errors
Expand All @@ -162,19 +161,15 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt
return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Unable to verify the request object because its claims could not be validated, check if the expiry time is set correctly.").WithWrap(err).WithDebug(err.Error()))
}

claims, ok := token.Claims.(*jwt.MapClaims)
if !ok {
return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Unable to type assert claims from request parameter 'client_assertion'.").WithDebugf("Got claims of type %T but expected type '*jwt.MapClaims'.", token.Claims))
}

claims := token.Claims
var jti string
if !claims.VerifyIssuer(clientID, true) {
return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Claim 'iss' from 'client_assertion' must match the 'client_id' of the OAuth 2.0 Client."))
} else if f.TokenURL == "" {
return nil, errorsx.WithStack(ErrMisconfiguration.WithHint("The authorization server's token endpoint URL has not been set."))
} else if sub, ok := (*claims)["sub"].(string); !ok || sub != clientID {
} else if sub, ok := claims["sub"].(string); !ok || sub != clientID {
return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Claim 'sub' from 'client_assertion' must match the 'client_id' of the OAuth 2.0 Client."))
} else if jti, ok = (*claims)["jti"].(string); !ok || len(jti) == 0 {
} else if jti, ok = claims["jti"].(string); !ok || len(jti) == 0 {
return nil, errorsx.WithStack(ErrInvalidClient.WithHint("Claim 'jti' from 'client_assertion' must be set but is not."))
} else if f.Store.ClientAssertionJWTValid(ctx, jti) != nil {
return nil, errorsx.WithStack(ErrJTIKnown.WithHint("Claim 'jti' from 'client_assertion' MUST only be used once."))
Expand All @@ -183,7 +178,7 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt
// type conversion according to jwt.MapClaims.VerifyExpiresAt
var expiry int64
err = nil
switch exp := (*claims)["exp"].(type) {
switch exp := claims["exp"].(type) {
case float64:
expiry = int64(exp)
case json.Number:
Expand All @@ -199,7 +194,7 @@ func (f *Fosite) DefaultClientAuthenticationStrategy(ctx context.Context, r *htt
return nil, err
}

if auds, ok := (*claims)["aud"].([]interface{}); !ok {
if auds, ok := claims["aud"].([]interface{}); !ok {
if !claims.VerifyAudience(f.TokenURL, true) {
return nil, errorsx.WithStack(ErrInvalidClient.WithHintf("Claim 'audience' from 'client_assertion' must match the authorization server's token endpoint '%s'.", f.TokenURL))
}
Expand Down
9 changes: 5 additions & 4 deletions client_authentication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import (
"testing"
"time"

jwt "github.com/dgrijalva/jwt-go"
"github.com/ory/fosite/token/jwt"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand All @@ -46,23 +46,23 @@ import (
)

func mustGenerateRSAAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.PrivateKey, kid string) string {
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token := jwt.NewWithClaims(jose.RS256, claims)
token.Header["kid"] = kid
tokenString, err := token.SignedString(key)
require.NoError(t, err)
return tokenString
}

func mustGenerateECDSAAssertion(t *testing.T, claims jwt.MapClaims, key *ecdsa.PrivateKey, kid string) string {
token := jwt.NewWithClaims(jwt.SigningMethodES256, claims)
token := jwt.NewWithClaims(jose.ES256, claims)
token.Header["kid"] = kid
tokenString, err := token.SignedString(key)
require.NoError(t, err)
return tokenString
}

func mustGenerateHSAssertion(t *testing.T, claims jwt.MapClaims, key *rsa.PrivateKey, kid string) string {
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
token := jwt.NewWithClaims(jose.HS256, claims)
tokenString, err := token.SignedString([]byte("aaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbcccccccccccccccccccccddddddddddddddddddddddd"))
require.NoError(t, err)
return tokenString
Expand Down Expand Up @@ -503,6 +503,7 @@ func TestAuthenticateClient(t *testing.T) {
t.Logf("Error is: %s", validationError.Inner)
} else if errors.As(err, &rfcError) {
t.Logf("DebugField is: %s", rfcError.DebugField)
t.Logf("HintField is: %s", rfcError.HintField)
}
}
require.NoError(t, err)
Expand Down
2 changes: 0 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ replace github.com/gogo/protobuf => github.com/gogo/protobuf v1.3.2
require (
github.com/asaskevich/govalidator v0.0.0-20200428143746-21a406dcc535
github.com/dgraph-io/ristretto v0.0.3 // indirect
github.com/dgrijalva/jwt-go v3.2.0+incompatible
github.com/form3tech-oss/jwt-go v3.2.2+incompatible // indirect
github.com/golang/mock v1.4.4
github.com/gorilla/mux v1.7.3
github.com/gorilla/websocket v1.4.2
Expand Down
3 changes: 0 additions & 3 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,7 @@ github.com/fatih/structs v1.0.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga
github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M=
github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
github.com/fogleman/gg v1.3.0/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
github.com/form3tech-oss/jwt-go v3.2.1+incompatible h1:xdtqez379uWVJ9P3qQMX8W+F/nqsTdUvyMZB36tnacA=
github.com/form3tech-oss/jwt-go v3.2.1+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k=
github.com/form3tech-oss/jwt-go v3.2.2+incompatible h1:TcekIExNqud5crz4xD2pavyTgWiPvpYe4Xau31I0PRk=
github.com/form3tech-oss/jwt-go v3.2.2+incompatible/go.mod h1:pbq4aXjuKjdthFRnoDwaVPLA+WlJuPGy+QneDUgJi2k=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4=
github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ=
Expand Down
6 changes: 2 additions & 4 deletions handler/oauth2/introspector_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ import (
"context"
"time"

jwtx "github.com/dgrijalva/jwt-go"

"github.com/ory/fosite"
"github.com/ory/fosite/token/jwt"
)
Expand All @@ -37,8 +35,8 @@ type StatelessJWTValidator struct {
}

// AccessTokenJWTToRequest tries to reconstruct fosite.Request from a JWT.
func AccessTokenJWTToRequest(token *jwtx.Token) fosite.Requester {
mapClaims := token.Claims.(jwtx.MapClaims)
func AccessTokenJWTToRequest(token *jwt.Token) fosite.Requester {
mapClaims := token.Claims
claims := jwt.JWTClaims{}
claims.FromMapClaims(mapClaims)

Expand Down
25 changes: 12 additions & 13 deletions handler/oauth2/strategy_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import (

"github.com/ory/x/errorsx"

jwtx "github.com/dgrijalva/jwt-go"
"github.com/pkg/errors"

"github.com/ory/fosite"
Expand Down Expand Up @@ -99,36 +98,36 @@ func (h *DefaultJWTStrategy) ValidateAuthorizeCode(ctx context.Context, req fosi
return h.HMACSHAStrategy.ValidateAuthorizeCode(ctx, req, token)
}

func validate(ctx context.Context, jwtStrategy jwt.JWTStrategy, token string) (t *jwtx.Token, err error) {
func validate(ctx context.Context, jwtStrategy jwt.JWTStrategy, token string) (t *jwt.Token, err error) {
t, err = jwtStrategy.Decode(ctx, token)

if err == nil {
err = t.Claims.Valid()
}

if err != nil {
var e *jwtx.ValidationError
var e *jwt.ValidationError
if errors.As(err, &e) {
switch e.Errors {
case jwtx.ValidationErrorMalformed:
case jwt.ValidationErrorMalformed:
err = errorsx.WithStack(fosite.ErrInvalidTokenFormat.WithWrap(err).WithDebug(err.Error()))
case jwtx.ValidationErrorUnverifiable:
case jwt.ValidationErrorUnverifiable:
err = errorsx.WithStack(fosite.ErrTokenSignatureMismatch.WithWrap(err).WithDebug(err.Error()))
case jwtx.ValidationErrorSignatureInvalid:
case jwt.ValidationErrorSignatureInvalid:
err = errorsx.WithStack(fosite.ErrTokenSignatureMismatch.WithWrap(err).WithDebug(err.Error()))
case jwtx.ValidationErrorAudience:
case jwt.ValidationErrorAudience:
err = errorsx.WithStack(fosite.ErrTokenClaim.WithWrap(err).WithDebug(err.Error()))
case jwtx.ValidationErrorExpired:
case jwt.ValidationErrorExpired:
err = errorsx.WithStack(fosite.ErrTokenExpired.WithWrap(err).WithDebug(err.Error()))
case jwtx.ValidationErrorIssuedAt:
case jwt.ValidationErrorIssuedAt:
err = errorsx.WithStack(fosite.ErrTokenClaim.WithWrap(err).WithDebug(err.Error()))
case jwtx.ValidationErrorIssuer:
case jwt.ValidationErrorIssuer:
err = errorsx.WithStack(fosite.ErrTokenClaim.WithWrap(err).WithDebug(err.Error()))
case jwtx.ValidationErrorNotValidYet:
case jwt.ValidationErrorNotValidYet:
err = errorsx.WithStack(fosite.ErrTokenClaim.WithWrap(err).WithDebug(err.Error()))
case jwtx.ValidationErrorId:
case jwt.ValidationErrorId:
err = errorsx.WithStack(fosite.ErrTokenClaim.WithWrap(err).WithDebug(err.Error()))
case jwtx.ValidationErrorClaimsInvalid:
case jwt.ValidationErrorClaimsInvalid:
err = errorsx.WithStack(fosite.ErrTokenClaim.WithWrap(err).WithDebug(err.Error()))
default:
err = errorsx.WithStack(fosite.ErrRequestUnauthorized.WithWrap(err).WithDebug(err.Error()))
Expand Down
6 changes: 3 additions & 3 deletions handler/openid/flow_explicit_token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import (
"fmt"
"testing"

jwtgo "github.com/dgrijalva/jwt-go"
"github.com/golang/mock/gomock"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -124,10 +123,11 @@ func TestExplicit_PopulateTokenEndpointResponse(t *testing.T) {
check: func(t *testing.T, aresp *fosite.AccessResponse) {
assert.NotEmpty(t, aresp.GetExtra("id_token"))
idToken, _ := aresp.GetExtra("id_token").(string)
decodedIdToken, _ := jwtgo.Parse(idToken, func(token *jwtgo.Token) (interface{}, error) {
decodedIdToken, err := jwt.Parse(idToken, func(token *jwt.Token) (interface{}, error) {
return key.PublicKey, nil
})
claims, _ := decodedIdToken.Claims.(jwtgo.MapClaims)
require.NoError(t, err)
claims := decodedIdToken.Claims
assert.NotEmpty(t, claims["at_hash"])
},
},
Expand Down
6 changes: 3 additions & 3 deletions handler/openid/flow_refresh_token_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ package openid
import (
"testing"

jwtgo "github.com/dgrijalva/jwt-go"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand Down Expand Up @@ -164,10 +163,11 @@ func TestOpenIDConnectRefreshHandler_PopulateTokenEndpointResponse(t *testing.T)
check: func(t *testing.T, aresp *fosite.AccessResponse) {
assert.NotEmpty(t, aresp.GetExtra("id_token"))
idToken, _ := aresp.GetExtra("id_token").(string)
decodedIdToken, _ := jwtgo.Parse(idToken, func(token *jwtgo.Token) (interface{}, error) {
decodedIdToken, err := jwt.Parse(idToken, func(token *jwt.Token) (interface{}, error) {
return key.PublicKey, nil
})
claims, _ := decodedIdToken.Claims.(jwtgo.MapClaims)
require.NoError(t, err)
claims := decodedIdToken.Claims
assert.NotEmpty(t, claims["at_hash"])
},
},
Expand Down
Loading

0 comments on commit 03b22f6

Please sign in to comment.