Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

auth/jwt: MapClaims: passing #568

Merged
merged 3 commits into from
Jul 13, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions auth/jwt/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,27 @@ func NewSigner(kid string, key []byte, method jwt.SigningMethod, claims jwt.Clai
}
}

// ClaimsFactory is a factory for jwt.Claims.
// Useful in NewParser middleware.
type ClaimsFactory func() jwt.Claims

// MapClaimsFactory is a ClaimsFactory that returns
// an empty jwt.MapClaims.
func MapClaimsFactory() jwt.Claims {
return jwt.MapClaims{}
}

// StandardClaimsFactory is a ClaimsFactory that returns
// an empty jwt.StandardClaims.
func StandardClaimsFactory() jwt.Claims {
return &jwt.StandardClaims{}
}

// NewParser creates a new JWT token parsing middleware, specifying a
// jwt.Keyfunc interface, the signing method and the claims type to be used. NewParser
// adds the resulting claims to endpoint context or returns error on invalid token.
// Particularly useful for servers.
func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, claims jwt.Claims) endpoint.Middleware {
func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, newClaims ClaimsFactory) endpoint.Middleware {
return func(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
// tokenString is stored in the context from the transport handlers.
Expand All @@ -85,7 +101,7 @@ func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, claims jwt.Claims)
// of the token to identify which key to use, but the parsed token
// (head and claims) is provided to the callback, providing
// flexibility.
token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.ParseWithClaims(tokenString, newClaims(), func(token *jwt.Token) (interface{}, error) {
// Don't forget to validate the alg is what you expect:
if token.Method != method {
return nil, ErrUnexpectedSigningMethod
Expand Down
18 changes: 9 additions & 9 deletions auth/jwt/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func TestJWTParser(t *testing.T) {
return key, nil
}

parser := NewParser(keys, method, jwt.MapClaims{})(e)
parser := NewParser(keys, method, MapClaimsFactory)(e)

// No Token is passed into the parser
_, err := parser(context.Background(), struct{}{})
Expand All @@ -94,7 +94,7 @@ func TestJWTParser(t *testing.T) {
}

// Invalid Method is used in the parser
badParser := NewParser(keys, invalidMethod, jwt.MapClaims{})(e)
badParser := NewParser(keys, invalidMethod, MapClaimsFactory)(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey)
_, err = badParser(ctx, struct{}{})
if err == nil {
Expand All @@ -110,7 +110,7 @@ func TestJWTParser(t *testing.T) {
return []byte("bad"), nil
}

badParser = NewParser(invalidKeys, method, jwt.MapClaims{})(e)
badParser = NewParser(invalidKeys, method, MapClaimsFactory)(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey)
_, err = badParser(ctx, struct{}{})
if err == nil {
Expand All @@ -134,15 +134,15 @@ func TestJWTParser(t *testing.T) {
}

// Test for malformed token error response
parser = NewParser(keys, method, &jwt.StandardClaims{})(e)
parser = NewParser(keys, method, StandardClaimsFactory)(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, malformedKey)
ctx1, err = parser(ctx, struct{}{})
if want, have := ErrTokenMalformed, err; want != have {
t.Fatalf("Expected %+v, got %+v", want, have)
}

// Test for expired token error response
parser = NewParser(keys, method, &jwt.StandardClaims{})(e)
parser = NewParser(keys, method, StandardClaimsFactory)(e)
expired := jwt.NewWithClaims(method, jwt.StandardClaims{ExpiresAt: time.Now().Unix() - 100})
token, err := expired.SignedString(key)
if err != nil {
Expand All @@ -155,7 +155,7 @@ func TestJWTParser(t *testing.T) {
}

// Test for not activated token error response
parser = NewParser(keys, method, &jwt.StandardClaims{})(e)
parser = NewParser(keys, method, StandardClaimsFactory)(e)
notactive := jwt.NewWithClaims(method, jwt.StandardClaims{NotBefore: time.Now().Unix() + 100})
token, err = notactive.SignedString(key)
if err != nil {
Expand All @@ -168,7 +168,7 @@ func TestJWTParser(t *testing.T) {
}

// test valid standard claims token
parser = NewParser(keys, method, &jwt.StandardClaims{})(e)
parser = NewParser(keys, method, StandardClaimsFactory)(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, standardSignedKey)
ctx1, err = parser(ctx, struct{}{})
if err != nil {
Expand All @@ -183,7 +183,7 @@ func TestJWTParser(t *testing.T) {
}

// test valid customized claims token
parser = NewParser(keys, method, &customClaims{})(e)
parser = NewParser(keys, method, func() jwt.Claims { return &customClaims{} })(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, customSignedKey)
ctx1, err = parser(ctx, struct{}{})
if err != nil {
Expand All @@ -204,7 +204,7 @@ func TestJWTParser(t *testing.T) {
func TestIssue562(t *testing.T) {
var (
kf = func(token *jwt.Token) (interface{}, error) { return []byte("secret"), nil }
e = NewParser(kf, jwt.SigningMethodHS256, jwt.MapClaims{})(endpoint.Nop)
e = NewParser(kf, jwt.SigningMethodHS256, MapClaimsFactory)(endpoint.Nop)
key = JWTTokenContextKey
val = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiZ28ta2l0In0.14M2VmYyApdSlV_LZ88ajjwuaLeIFplB8JpyNy0A19E"
ctx = context.WithValue(context.Background(), key, val)
Expand Down