Skip to content

Commit

Permalink
Merge pull request #488 from cam-stitt/jwt-claims
Browse files Browse the repository at this point in the history
Add *WithClaims methods to jwt middleware for more advanced usage.
  • Loading branch information
peterbourgon authored Mar 16, 2017
2 parents b6f30a2 + b2b278c commit c9c7219
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 32 deletions.
21 changes: 8 additions & 13 deletions auth/jwt/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,14 @@ var (
ErrUnexpectedSigningMethod = errors.New("unexpected signing method")
)

// Claims is a map of arbitrary claim data.
type Claims map[string]interface{}

// NewSigner creates a new JWT token generating middleware, specifying key ID,
// signing string, signing method and the claims you would like it to contain.
// Tokens are signed with a Key ID header (kid) which is useful for determining
// the key to use for parsing. Particularly useful for clients.
func NewSigner(kid string, key []byte, method jwt.SigningMethod, claims Claims) endpoint.Middleware {
func NewSigner(kid string, key []byte, method jwt.SigningMethod, claims jwt.Claims) endpoint.Middleware {
return func(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
token := jwt.NewWithClaims(method, jwt.MapClaims(claims))
token := jwt.NewWithClaims(method, claims)
token.Header["kid"] = kid

// Sign and get the complete encoded token as a string using the secret
Expand All @@ -70,10 +67,10 @@ func NewSigner(kid string, key []byte, method jwt.SigningMethod, claims Claims)
}

// NewParser creates a new JWT token parsing middleware, specifying a
// jwt.Keyfunc interface and the signing method. NewParser adds the resulting
// claims to endpoint context or returns error on invalid token. Particularly
// useful for servers.
func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod) endpoint.Middleware {
// jwt.Keyfunc interface, the signing method and the claims type to be used. NewParser
// adds the resulting claims to endpoint context or returns error on invalid token.
// Particularly useful for servers.
func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod, claims jwt.Claims) endpoint.Middleware {
return func(next endpoint.Endpoint) endpoint.Endpoint {
return func(ctx context.Context, request interface{}) (response interface{}, err error) {
// tokenString is stored in the context from the transport handlers.
Expand All @@ -88,7 +85,7 @@ func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod) endpoint.Middlewar
// of the token to identify which key to use, but the parsed token
// (head and claims) is provided to the callback, providing
// flexibility.
token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (interface{}, error) {
// Don't forget to validate the alg is what you expect:
if token.Method != method {
return nil, ErrUnexpectedSigningMethod
Expand Down Expand Up @@ -119,9 +116,7 @@ func NewParser(keyFunc jwt.Keyfunc, method jwt.SigningMethod) endpoint.Middlewar
return nil, ErrTokenInvalid
}

if claims, ok := token.Claims.(jwt.MapClaims); ok {
ctx = context.WithValue(ctx, JWTClaimsContextKey, Claims(claims))
}
ctx = context.WithValue(ctx, JWTClaimsContextKey, token.Claims)

return next(ctx, request)
}
Expand Down
96 changes: 77 additions & 19 deletions auth/jwt/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,38 @@ import (
"context"
"testing"

"crypto/subtle"

jwt "github.com/dgrijalva/jwt-go"
"github.com/go-kit/kit/endpoint"
)

type customClaims struct {
MyProperty string `json:"my_property"`
jwt.StandardClaims
}

func (c customClaims) VerifyMyProperty(p string) bool {
return subtle.ConstantTimeCompare([]byte(c.MyProperty), []byte(p)) != 0
}

var (
kid = "kid"
key = []byte("test_signing_key")
method = jwt.SigningMethodHS256
invalidMethod = jwt.SigningMethodRS256
claims = Claims{"user": "go-kit"}
kid = "kid"
key = []byte("test_signing_key")
myProperty = "some value"
method = jwt.SigningMethodHS256
invalidMethod = jwt.SigningMethodRS256
mapClaims = jwt.MapClaims{"user": "go-kit"}
standardClaims = jwt.StandardClaims{Audience: "go-kit"}
myCustomClaims = customClaims{MyProperty: myProperty, StandardClaims: standardClaims}
// Signed tokens generated at https://jwt.io/
signedKey = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiZ28ta2l0In0.14M2VmYyApdSlV_LZ88ajjwuaLeIFplB8JpyNy0A19E"
invalidKey = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.e30.vKVCKto-Wn6rgz3vBdaZaCBGfCBDTXOENSo_X2Gq7qA"
signedKey = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJ1c2VyIjoiZ28ta2l0In0.14M2VmYyApdSlV_LZ88ajjwuaLeIFplB8JpyNy0A19E"
standardSignedKey = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJnby1raXQifQ.L5ypIJjCOOv3jJ8G5SelaHvR04UJuxmcBN5QW3m_aoY"
customSignedKey = "eyJhbGciOiJIUzI1NiIsImtpZCI6ImtpZCIsInR5cCI6IkpXVCJ9.eyJteV9wcm9wZXJ0eSI6InNvbWUgdmFsdWUiLCJhdWQiOiJnby1raXQifQ.s8F-IDrV4WPJUsqr7qfDi-3GRlcKR0SRnkTeUT_U-i0"
invalidKey = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.e30.vKVCKto-Wn6rgz3vBdaZaCBGfCBDTXOENSo_X2Gq7qA"
)

func TestSigner(t *testing.T) {
e := func(ctx context.Context, i interface{}) (interface{}, error) { return ctx, nil }

signer := NewSigner(kid, key, method, claims)(e)
func signingValidator(t *testing.T, signer endpoint.Endpoint, expectedKey string) {
ctx, err := signer(context.Background(), struct{}{})
if err != nil {
t.Fatalf("Signer returned error: %s", err)
Expand All @@ -32,19 +46,32 @@ func TestSigner(t *testing.T) {
t.Fatal("Token did not exist in context")
}

if token != signedKey {
t.Fatalf("JWT tokens did not match: expecting %s got %s", signedKey, token)
if token != expectedKey {
t.Fatalf("JWT tokens did not match: expecting %s got %s", expectedKey, token)
}
}

func TestNewSigner(t *testing.T) {
e := func(ctx context.Context, i interface{}) (interface{}, error) { return ctx, nil }

signer := NewSigner(kid, key, method, mapClaims)(e)
signingValidator(t, signer, signedKey)

signer = NewSigner(kid, key, method, standardClaims)(e)
signingValidator(t, signer, standardSignedKey)

signer = NewSigner(kid, key, method, myCustomClaims)(e)
signingValidator(t, signer, customSignedKey)
}

func TestJWTParser(t *testing.T) {
e := func(ctx context.Context, i interface{}) (interface{}, error) { return ctx, nil }

keys := func(token *jwt.Token) (interface{}, error) {
return key, nil
}

parser := NewParser(keys, method)(e)
parser := NewParser(keys, method, jwt.MapClaims{})(e)

// No Token is passed into the parser
_, err := parser(context.Background(), struct{}{})
Expand All @@ -64,7 +91,7 @@ func TestJWTParser(t *testing.T) {
}

// Invalid Method is used in the parser
badParser := NewParser(keys, invalidMethod)(e)
badParser := NewParser(keys, invalidMethod, jwt.MapClaims{})(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey)
_, err = badParser(ctx, struct{}{})
if err == nil {
Expand All @@ -80,7 +107,7 @@ func TestJWTParser(t *testing.T) {
return []byte("bad"), nil
}

badParser = NewParser(invalidKeys, method)(e)
badParser = NewParser(invalidKeys, method, jwt.MapClaims{})(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, signedKey)
_, err = badParser(ctx, struct{}{})
if err == nil {
Expand All @@ -94,12 +121,43 @@ func TestJWTParser(t *testing.T) {
t.Fatalf("Parser returned error: %s", err)
}

cl, ok := ctx1.(context.Context).Value(JWTClaimsContextKey).(Claims)
cl, ok := ctx1.(context.Context).Value(JWTClaimsContextKey).(jwt.MapClaims)
if !ok {
t.Fatal("Claims were not passed into context correctly")
}

if cl["user"] != mapClaims["user"] {
t.Fatalf("JWT Claims.user did not match: expecting %s got %s", mapClaims["user"], cl["user"])
}

parser = NewParser(keys, method, &jwt.StandardClaims{})(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, standardSignedKey)
ctx1, err = parser(ctx, struct{}{})
if err != nil {
t.Fatalf("Parser returned error: %s", err)
}
stdCl, ok := ctx1.(context.Context).Value(JWTClaimsContextKey).(*jwt.StandardClaims)
if !ok {
t.Fatal("Claims were not passed into context correctly")
}
if !stdCl.VerifyAudience("go-kit", true) {
t.Fatalf("JWT jwt.StandardClaims.Audience did not match: expecting %s got %s", standardClaims.Audience, stdCl.Audience)
}

if cl["user"] != claims["user"] {
t.Fatalf("JWT Claims.user did not match: expecting %s got %s", claims["user"], cl["user"])
parser = NewParser(keys, method, &customClaims{})(e)
ctx = context.WithValue(context.Background(), JWTTokenContextKey, customSignedKey)
ctx1, err = parser(ctx, struct{}{})
if err != nil {
t.Fatalf("Parser returned error: %s", err)
}
custCl, ok := ctx1.(context.Context).Value(JWTClaimsContextKey).(*customClaims)
if !ok {
t.Fatal("Claims were not passed into context correctly")
}
if !custCl.VerifyAudience("go-kit", true) {
t.Fatalf("JWT customClaims.Audience did not match: expecting %s got %s", standardClaims.Audience, custCl.Audience)
}
if !custCl.VerifyMyProperty(myProperty) {
t.Fatalf("JWT customClaims.MyProperty did not match: expecting %s got %s", myProperty, custCl.MyProperty)
}
}

0 comments on commit c9c7219

Please sign in to comment.