Skip to content

Commit

Permalink
Merge pull request #568 from litriv/issue-562
Browse files Browse the repository at this point in the history
auth/jwt: MapClaims: passing
  • Loading branch information
peterbourgon authored Jul 13, 2017
2 parents c4aa78e + 37eab0a commit b42a850
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 11 deletions.
20 changes: 18 additions & 2 deletions auth/jwt/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,27 @@ func NewSigner(kid string, key []byte, method jwt.SigningMethod, claims jwt.Clai
}
}

// ClaimsFactory is a factory for jwt.Claims.
// Useful in NewParser middleware.
type ClaimsFactory func() jwt.Claims

// MapClaimsFactory is a ClaimsFactory that returns
// an empty jwt.MapClaims.
func MapClaimsFactory() jwt.Claims {
return jwt.MapClaims{}
}

// StandardClaimsFactory is a ClaimsFactory that returns
// an empty jwt.StandardClaims.
func StandardClaimsFactory() jwt.Claims {
return &jwt.StandardClaims{}
}

// NewParser creates a new JWT token parsing middleware, specifying a
// jwt.Keyfunc interface, the signing method and the claims type to be used. NewParser
// adds the resulting claims to endpoint context or returns error on invalid token.
// Particularly useful for servers.
func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, claims jwt.Claims) endpoint.Middleware {
func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, newClaims ClaimsFactory) endpoint.Middleware {
return func(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
// tokenString is stored in the context from the transport handlers.
Expand All @@ -85,7 +101,7 @@ func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, claims jwt.Claims)
// of the token to identify which key to use, but the parsed token
// (head and claims) is provided to the callback, providing
// flexibility.
token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.ParseWithClaims(tokenString, newClaims(), func(token *jwt.Token) (interface{}, error) {
// Don't forget to validate the alg is what you expect:
if token.Method != method {
return nil, ErrUnexpectedSigningMethod
Expand Down
18 changes: 9 additions & 9 deletions auth/jwt/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func TestJWTParser(t *testing.T) {
return key, nil
}

parser := NewParser(keys, method, jwt.MapClaims{})(e)
parser := NewParser(keys, method, MapClaimsFactory)(e)

// No Token is passed into the parser
_, err := parser(context.Background(), struct{}{})
Expand All @@ -94,7 +94,7 @@ func TestJWTParser(t *testing.T) {
}

// Invalid Method is used in the parser
badParser := NewParser(keys, invalidMethod, jwt.MapClaims{})(e)
badParser := NewParser(keys, invalidMethod, MapClaimsFactory)(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey)
_, err = badParser(ctx, struct{}{})
if err == nil {
Expand All @@ -110,7 +110,7 @@ func TestJWTParser(t *testing.T) {
return []byte("bad"), nil
}

badParser = NewParser(invalidKeys, method, jwt.MapClaims{})(e)
badParser = NewParser(invalidKeys, method, MapClaimsFactory)(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey)
_, err = badParser(ctx, struct{}{})
if err == nil {
Expand All @@ -134,15 +134,15 @@ func TestJWTParser(t *testing.T) {
}

// Test for malformed token error response
parser = NewParser(keys, method, &jwt.StandardClaims{})(e)
parser = NewParser(keys, method, StandardClaimsFactory)(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, malformedKey)
ctx1, err = parser(ctx, struct{}{})
if want, have := ErrTokenMalformed, err; want != have {
t.Fatalf("Expected %+v, got %+v", want, have)
}

// Test for expired token error response
parser = NewParser(keys, method, &jwt.StandardClaims{})(e)
parser = NewParser(keys, method, StandardClaimsFactory)(e)
expired := jwt.NewWithClaims(method, jwt.StandardClaims{ExpiresAt: time.Now().Unix() - 100})
token, err := expired.SignedString(key)
if err != nil {
Expand All @@ -155,7 +155,7 @@ func TestJWTParser(t *testing.T) {
}

// Test for not activated token error response
parser = NewParser(keys, method, &jwt.StandardClaims{})(e)
parser = NewParser(keys, method, StandardClaimsFactory)(e)
notactive := jwt.NewWithClaims(method, jwt.StandardClaims{NotBefore: time.Now().Unix() + 100})
token, err = notactive.SignedString(key)
if err != nil {
Expand All @@ -168,7 +168,7 @@ func TestJWTParser(t *testing.T) {
}

// test valid standard claims token
parser = NewParser(keys, method, &jwt.StandardClaims{})(e)
parser = NewParser(keys, method, StandardClaimsFactory)(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, standardSignedKey)
ctx1, err = parser(ctx, struct{}{})
if err != nil {
Expand All @@ -183,7 +183,7 @@ func TestJWTParser(t *testing.T) {
}

// test valid customized claims token
parser = NewParser(keys, method, &customClaims{})(e)
parser = NewParser(keys, method, func() jwt.Claims { return &customClaims{} })(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, customSignedKey)
ctx1, err = parser(ctx, struct{}{})
if err != nil {
Expand All @@ -204,7 +204,7 @@ func TestJWTParser(t *testing.T) {
func TestIssue562(t *testing.T) {
var (
kf = func(token *jwt.Token) (interface{}, error) { return []byte("secret"), nil }
e = NewParser(kf, jwt.SigningMethodHS256, jwt.MapClaims{})(endpoint.Nop)
e = NewParser(kf, jwt.SigningMethodHS256, MapClaimsFactory)(endpoint.Nop)
key = JWTTokenContextKey
val = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiZ28ta2l0In0.14M2VmYyApdSlV_LZ88ajjwuaLeIFplB8JpyNy0A19E"
ctx = context.WithValue(context.Background(), key, val)
Expand Down

0 comments on commit b42a850

Please sign in to comment.