Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

auth/jwt: MapClaims: passing #568

Merged
merged 3 commits into from
Jul 13, 2017
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions auth/jwt/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,13 @@ func NewSigner(kid string, key []byte, method jwt.SigningMethod, claims jwt.Clai
}
}

type claimsFactory func() jwt.Claims
Copy link
Member

@peterbourgon peterbourgon Jul 11, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be exported, ClaimsFactory.


// NewParser creates a new JWT token parsing middleware, specifying a
// jwt.Keyfunc interface, the signing method and the claims type to be used. NewParser
// adds the resulting claims to endpoint context or returns error on invalid token.
// Particularly useful for servers.
func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, claims jwt.Claims) endpoint.Middleware {
func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, newClaims claimsFactory) endpoint.Middleware {
return func(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
// tokenString is stored in the context from the transport handlers.
Expand All @@ -85,7 +87,7 @@ func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, claims jwt.Claims)
// of the token to identify which key to use, but the parsed token
// (head and claims) is provided to the callback, providing
// flexibility.
token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.ParseWithClaims(tokenString, newClaims(), func(token *jwt.Token) (interface{}, error) {
// Don't forget to validate the alg is what you expect:
if token.Method != method {
return nil, ErrUnexpectedSigningMethod
Expand Down
18 changes: 9 additions & 9 deletions auth/jwt/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func TestJWTParser(t *testing.T) {
return key, nil
}

parser := NewParser(keys, method, jwt.MapClaims{})(e)
parser := NewParser(keys, method, func() jwt.Claims { return jwt.MapClaims{} })(e)
Copy link
Member

@peterbourgon peterbourgon Jul 11, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This anonymous func should be provided as a top-level function in the package, e.g.

// MapClaimsFactory is a ClaimsFactory that returns 
// an empty jwt.MapClaims.
func MapClaimsFactory() jwt.Claims {
    return jwt.MapClaims{}
}


// No Token is passed into the parser
_, err := parser(context.Background(), struct{}{})
Expand All @@ -94,7 +94,7 @@ func TestJWTParser(t *testing.T) {
}

// Invalid Method is used in the parser
badParser := NewParser(keys, invalidMethod, jwt.MapClaims{})(e)
badParser := NewParser(keys, invalidMethod, func() jwt.Claims { return jwt.MapClaims{} })(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey)
_, err = badParser(ctx, struct{}{})
if err == nil {
Expand All @@ -110,7 +110,7 @@ func TestJWTParser(t *testing.T) {
return []byte("bad"), nil
}

badParser = NewParser(invalidKeys, method, jwt.MapClaims{})(e)
badParser = NewParser(invalidKeys, method, func() jwt.Claims { return jwt.MapClaims{} })(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey)
_, err = badParser(ctx, struct{}{})
if err == nil {
Expand All @@ -134,15 +134,15 @@ func TestJWTParser(t *testing.T) {
}

// Test for malformed token error response
parser = NewParser(keys, method, &jwt.StandardClaims{})(e)
parser = NewParser(keys, method, func() jwt.Claims { return &jwt.StandardClaims{} })(e)
Copy link
Member

@peterbourgon peterbourgon Jul 11, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly, this anonymous func should be provided as a top-level function in the package, e.g.

// StandardClaimsFactory is a ClaimsFactory that returns 
// an empty jwt.StandardClaims.
func StandardClaimsFactory() jwt.Claims {
    return &jwt.StandardClaims{}
}

ctx = context.WithValue(context.Background(), JWTTokenContextKey, malformedKey)
ctx1, err = parser(ctx, struct{}{})
if want, have := ErrTokenMalformed, err; want != have {
t.Fatalf("Expected %+v, got %+v", want, have)
}

// Test for expired token error response
parser = NewParser(keys, method, &jwt.StandardClaims{})(e)
parser = NewParser(keys, method, func() jwt.Claims { return &jwt.StandardClaims{} })(e)
expired := jwt.NewWithClaims(method, jwt.StandardClaims{ExpiresAt: time.Now().Unix() - 100})
token, err := expired.SignedString(key)
if err != nil {
Expand All @@ -155,7 +155,7 @@ func TestJWTParser(t *testing.T) {
}

// Test for not activated token error response
parser = NewParser(keys, method, &jwt.StandardClaims{})(e)
parser = NewParser(keys, method, func() jwt.Claims { return &jwt.StandardClaims{} })(e)
notactive := jwt.NewWithClaims(method, jwt.StandardClaims{NotBefore: time.Now().Unix() + 100})
token, err = notactive.SignedString(key)
if err != nil {
Expand All @@ -168,7 +168,7 @@ func TestJWTParser(t *testing.T) {
}

// test valid standard claims token
parser = NewParser(keys, method, &jwt.StandardClaims{})(e)
parser = NewParser(keys, method, func() jwt.Claims { return &jwt.StandardClaims{} })(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, standardSignedKey)
ctx1, err = parser(ctx, struct{}{})
if err != nil {
Expand All @@ -183,7 +183,7 @@ func TestJWTParser(t *testing.T) {
}

// test valid customized claims token
parser = NewParser(keys, method, &customClaims{})(e)
parser = NewParser(keys, method, func() jwt.Claims { return &customClaims{} })(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, customSignedKey)
ctx1, err = parser(ctx, struct{}{})
if err != nil {
Expand All @@ -204,7 +204,7 @@ func TestJWTParser(t *testing.T) {
func TestIssue562(t *testing.T) {
var (
kf = func(token *jwt.Token) (interface{}, error) { return []byte("secret"), nil }
e = NewParser(kf, jwt.SigningMethodHS256, jwt.MapClaims{})(endpoint.Nop)
e = NewParser(kf, jwt.SigningMethodHS256, func() jwt.Claims { return jwt.MapClaims{} })(endpoint.Nop)
key = JWTTokenContextKey
val = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiZ28ta2l0In0.14M2VmYyApdSlV_LZ88ajjwuaLeIFplB8JpyNy0A19E"
ctx = context.WithValue(context.Background(), key, val)
Expand Down