Skip to content

Commit

Permalink
migrate from jwt-go to fosite/token/jwt
Browse files Browse the repository at this point in the history
  • Loading branch information
narg95 committed May 10, 2021
1 parent bc865c0 commit 92100e9
Show file tree
Hide file tree
Showing 13 changed files with 41 additions and 32 deletions.
9 changes: 2 additions & 7 deletions consent/strategy_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ import (

"github.com/ory/x/sqlcon"

jwtgo "github.com/dgrijalva/jwt-go"
"github.com/gorilla/sessions"
jwtgo "github.com/ory/fosite/token/jwt"
"github.com/pborman/uuid"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
Expand Down Expand Up @@ -191,12 +191,7 @@ func (s *DefaultStrategy) getIDTokenHintClaims(ctx context.Context, idTokenHint
return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithHint(err.Error()))
}

claims, ok := token.Claims.(jwtgo.MapClaims)
if !ok {
return nil, errorsx.WithStack(fosite.ErrInvalidRequest.WithHint("Failed to validate OpenID Connect request as decoding id token from id_token_hint to jwt.MapClaims failed."))
}

return claims, nil
return token.Claims, nil
}

func (s *DefaultStrategy) getSubjectFromIDTokenHint(ctx context.Context, idTokenHint string) (string, error) {
Expand Down
4 changes: 2 additions & 2 deletions consent/strategy_default_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import (
"net/http/httptest"
"testing"

jwtgo "github.com/dgrijalva/jwt-go"
jwtgo "github.com/ory/fosite/token/jwt"
"github.com/stretchr/testify/require"

"github.com/ory/fosite/token/jwt"
Expand Down Expand Up @@ -122,7 +122,7 @@ func newAuthCookieJar(t *testing.T, reg driver.Registry, u, sessionID string) ht
return cj
}

func genIDToken(t *testing.T, reg driver.Registry, c jwtgo.Claims) string {
func genIDToken(t *testing.T, reg driver.Registry, c jwtgo.MapClaims) string {
r, _, err := reg.OpenIDJWTStrategy().Generate(context.TODO(), c, jwt.NewHeaders())
require.NoError(t, err)
return r
Expand Down
2 changes: 1 addition & 1 deletion consent/strategy_logout_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (
"testing"
"time"

jwtgo "github.com/dgrijalva/jwt-go"
jwtgo "github.com/ory/fosite/token/jwt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ require (
github.com/DataDog/datadog-go v4.6.0+incompatible // indirect
github.com/cenkalti/backoff/v3 v3.0.0
github.com/containerd/containerd v1.4.4 // indirect
github.com/dgrijalva/jwt-go v3.2.0+incompatible
github.com/dgrijalva/jwt-go v3.2.0+incompatible // indirect
github.com/evanphx/json-patch v0.5.2
github.com/go-bindata/go-bindata v3.1.1+incompatible
github.com/go-openapi/errors v0.20.0
Expand Down
4 changes: 2 additions & 2 deletions internal/testhelpers/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"testing"
"time"

djwt "github.com/dgrijalva/jwt-go"
djwt "github.com/ory/fosite/token/jwt"
"github.com/stretchr/testify/assert"

"github.com/ory/fosite/token/jwt"
Expand Down Expand Up @@ -80,7 +80,7 @@ func DecodeIDToken(t *testing.T, token *oauth2.Token) gjson.Result {
require.True(t, ok)
assert.NotEmpty(t, idt)

body, err := djwt.DecodeSegment(strings.Split(idt, ".")[1])
body, err := x.DecodeSegment(strings.Split(idt, ".")[1])
require.NoError(t, err)

return gjson.ParseBytes(body)
Expand Down
4 changes: 2 additions & 2 deletions jwk/jwt_strategy.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import (

"github.com/ory/hydra/driver/config"

jwt2 "github.com/dgrijalva/jwt-go"
jwt2 "github.com/ory/fosite/token/jwt"
"github.com/pkg/errors"

"github.com/ory/fosite/token/jwt"
Expand Down Expand Up @@ -75,7 +75,7 @@ func (j *RS256JWTStrategy) GetSignature(ctx context.Context, token string) (stri
return j.RS256JWTStrategy.GetSignature(ctx, token)
}

func (j *RS256JWTStrategy) Generate(ctx context.Context, claims jwt2.Claims, header jwt.Mapper) (string, string, error) {
func (j *RS256JWTStrategy) Generate(ctx context.Context, claims jwt2.MapClaims, header jwt.Mapper) (string, string, error) {
if err := j.refresh(ctx); err != nil {
return "", "", err
}
Expand Down
2 changes: 1 addition & 1 deletion jwk/jwt_strategy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import (

"github.com/ory/hydra/internal"

jwt2 "github.com/dgrijalva/jwt-go"
jwt2 "github.com/ory/fosite/token/jwt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

Expand Down
2 changes: 1 addition & 1 deletion oauth2/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ import (

"github.com/ory/x/errorsx"

jwt2 "github.com/dgrijalva/jwt-go"
"github.com/julienschmidt/httprouter"
jwt2 "github.com/ory/fosite/token/jwt"
"github.com/pkg/errors"

"github.com/ory/fosite"
Expand Down
8 changes: 4 additions & 4 deletions oauth2/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ import (
"github.com/ory/hydra/internal"
"github.com/ory/x/urlx"

jwt2 "github.com/dgrijalva/jwt-go"
"github.com/golang/mock/gomock"
jwt2 "github.com/ory/fosite/token/jwt"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -360,9 +360,9 @@ func TestUserinfo(t *testing.T) {
return jwk.MustRSAPublic(key), nil
})
require.NoError(t, err)
assert.EqualValues(t, "alice", claims.Claims.(jwt2.MapClaims)["sub"])
assert.EqualValues(t, []interface{}{"foobar-client"}, claims.Claims.(jwt2.MapClaims)["aud"], "%#v", claims.Claims)
assert.NotEmpty(t, claims.Claims.(jwt2.MapClaims)["jti"])
assert.EqualValues(t, "alice", claims.Claims["sub"])
assert.EqualValues(t, []interface{}{"foobar-client"}, claims.Claims["aud"], "%#v", claims.Claims)
assert.NotEmpty(t, claims.Claims["jti"])
},
},
} {
Expand Down
11 changes: 5 additions & 6 deletions oauth2/oauth2_auth_code_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ import (
"github.com/ory/hydra/client"
"github.com/ory/hydra/internal/testhelpers"

djwt "github.com/dgrijalva/jwt-go"
"github.com/julienschmidt/httprouter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -201,7 +200,7 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) {
require.True(t, ok)
assert.NotEmpty(t, idt)

body, err := djwt.DecodeSegment(strings.Split(idt, ".")[1])
body, err := x.DecodeSegment(strings.Split(idt, ".")[1])
require.NoError(t, err)

claims := gjson.ParseBytes(body)
Expand Down Expand Up @@ -241,7 +240,7 @@ func TestAuthCodeWithDefaultStrategy(t *testing.T) {
}
require.Len(t, parts, 3)

body, err := djwt.DecodeSegment(parts[1])
body, err := x.DecodeSegment(parts[1])
require.NoError(t, err)

i := gjson.ParseBytes(body)
Expand Down Expand Up @@ -658,7 +657,7 @@ func TestAuthCodeWithMockStrategy(t *testing.T) {
return
}

body, err := djwt.DecodeSegment(strings.Split(token, ".")[1])
body, err := x.DecodeSegment(strings.Split(token, ".")[1])
require.NoError(t, err)

data := map[string]interface{}{}
Expand Down Expand Up @@ -867,13 +866,13 @@ func TestAuthCodeWithMockStrategy(t *testing.T) {
t.Skip()
}

body, err := djwt.DecodeSegment(strings.Split(token.AccessToken, ".")[1])
body, err := x.DecodeSegment(strings.Split(token.AccessToken, ".")[1])
require.NoError(t, err)

origPayload := map[string]interface{}{}
require.NoError(t, json.Unmarshal(body, &origPayload))

body, err = djwt.DecodeSegment(strings.Split(refreshedToken.AccessToken, ".")[1])
body, err = x.DecodeSegment(strings.Split(refreshedToken.AccessToken, ".")[1])
require.NoError(t, err)

refreshedPayload := map[string]interface{}{}
Expand Down
4 changes: 2 additions & 2 deletions oauth2/oauth2_client_credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import (
"github.com/google/uuid"
"github.com/tidwall/gjson"

"github.com/dgrijalva/jwt-go"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
goauth2 "golang.org/x/oauth2"
Expand All @@ -41,6 +40,7 @@ import (
hc "github.com/ory/hydra/client"
"github.com/ory/hydra/driver/config"
"github.com/ory/hydra/internal"
"github.com/ory/hydra/x"
)

func TestClientCredentials(t *testing.T) {
Expand Down Expand Up @@ -108,7 +108,7 @@ func TestClientCredentials(t *testing.T) {
return
}

body, err := jwt.DecodeSegment(strings.Split(token.AccessToken, ".")[1])
body, err := x.DecodeSegment(strings.Split(token.AccessToken, ".")[1])
require.NoError(t, err)

jwtClaims := gjson.ParseBytes(body)
Expand Down
6 changes: 3 additions & 3 deletions test/mock-client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ import (
"strings"
"time"

"github.com/dgrijalva/jwt-go"
"golang.org/x/oauth2"

hydra "github.com/ory/hydra/internal/httpclient/client"
"github.com/ory/hydra/internal/httpclient/client/admin"
"github.com/ory/hydra/x"
"github.com/ory/x/cmdx"
"github.com/ory/x/urlx"
)
Expand Down Expand Up @@ -173,7 +173,7 @@ func checkTokenResponse(token oauth2token) {
log.Fatalf("JWT Access Token does not seem to have three parts: %d - %+v - %v", len(parts), token, parts)
}

payload, err := jwt.DecodeSegment(parts[1])
payload, err := x.DecodeSegment(parts[1])
if err != nil {
log.Fatalf("Unable to decode id token segment: %s", err)
}
Expand Down Expand Up @@ -221,7 +221,7 @@ func checkTokenResponse(token oauth2token) {
log.Fatalf("ID Token does not seem to have three parts: %d - %+v - %v", len(parts), token, parts)
}

payload, err := jwt.DecodeSegment(parts[1])
payload, err := x.DecodeSegment(parts[1])
if err != nil {
log.Fatalf("Unable to decode id token segment: %s", err)
}
Expand Down
15 changes: 15 additions & 0 deletions x/jwt.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
package x

import (
"encoding/base64"
"strings"
)

// Decode JWT specific base64url encoding with padding stripped
func DecodeSegment(seg string) ([]byte, error) {
if l := len(seg) % 4; l > 0 {
seg += strings.Repeat("=", 4-l)
}

return base64.URLEncoding.DecodeString(seg)
}

0 comments on commit 92100e9

Please sign in to comment.