Skip to content

Commit

Permalink
Asymmetric algorithms (#2)
Browse files Browse the repository at this point in the history
Support for asymmetric algorithms (RSA and ECDSA families) for signing tokens
  • Loading branch information
kazarena authored Jul 12, 2017
1 parent 7eed1f4 commit 5f485d6
Show file tree
Hide file tree
Showing 2 changed files with 187 additions and 9 deletions.
68 changes: 59 additions & 9 deletions auth_jwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ import (
"strings"
"time"

"crypto/ecdsa"
"crypto/rsa"
"github.com/gin-gonic/gin"
"gopkg.in/dgrijalva/jwt-go.v3"
)
Expand All @@ -19,13 +21,22 @@ type GinJWTMiddleware struct {
// Realm name to display to the user. Required.
Realm string

// signing algorithm - possible values are HS256, HS384, HS512
// signing algorithm - possible values are:
//
// HS256, HS384, HS512
// RS256, RS384, RS512
// ES256, ES384, ES512
//
// Optional, default is HS256.
SigningAlgorithm string

// Secret key used for signing. Required.
// HMAC Secret key used for signing. Required for HSxxx algorithms
Key []byte

// Asymmetric keys used for signing. Required for RSxxx and ESxxx algorithms
SignKey interface{}
VerifyKey interface{}

// Duration that a jwt token is valid. Optional, defaults to one hour.
Timeout time.Duration

Expand Down Expand Up @@ -130,13 +141,49 @@ func (mw *GinJWTMiddleware) MiddlewareInit() error {
return errors.New("realm is required")
}

if mw.Key == nil {
return errors.New("secret key is required")
isSymmetricAlgo := mw.SigningAlgorithm == "HS256" || mw.SigningAlgorithm == "HS384" || mw.SigningAlgorithm == "HS512"

if isSymmetricAlgo {
if mw.Key == nil {
return errors.New("secret key is required")
}
// symmetrical algorithms use the same key for signing and verification of token
mw.SignKey = mw.Key
mw.VerifyKey = mw.Key
} else {
if isBadPrivateKey(mw.SignKey) {
return errors.New("private key is required")
}
if isBadPublicKey(mw.VerifyKey) {
return errors.New("public key is required")
}
}

return nil
}

func isBadPrivateKey(key interface{}) bool {
switch v := key.(type) {
case *rsa.PrivateKey:
return v == nil
case *ecdsa.PrivateKey:
return v == nil
default:
return true
}
}

func isBadPublicKey(key interface{}) bool {
switch v := key.(type) {
case *rsa.PublicKey:
return v == nil
case *ecdsa.PublicKey:
return v == nil
default:
return true
}
}

// MiddlewareFunc makes GinJWTMiddleware implement the Middleware interface.
func (mw *GinJWTMiddleware) MiddlewareFunc() gin.HandlerFunc {
if err := mw.MiddlewareInit(); err != nil {
Expand Down Expand Up @@ -180,7 +227,10 @@ func (mw *GinJWTMiddleware) middlewareImpl(c *gin.Context) {
func (mw *GinJWTMiddleware) LoginHandler(c *gin.Context) {

// Initial middleware default setting.
mw.MiddlewareInit()
if err := mw.MiddlewareInit(); err != nil {
mw.unauthorized(c, http.StatusInternalServerError, err.Error())
return
}

var loginVals Login

Expand Down Expand Up @@ -220,7 +270,7 @@ func (mw *GinJWTMiddleware) LoginHandler(c *gin.Context) {
claims["exp"] = expire.Unix()
claims["orig_iat"] = mw.TimeFunc().Unix()

tokenString, err := token.SignedString(mw.Key)
tokenString, err := token.SignedString(mw.SignKey)

if err != nil {
mw.unauthorized(c, http.StatusUnauthorized, "Create JWT Token faild")
Expand Down Expand Up @@ -260,7 +310,7 @@ func (mw *GinJWTMiddleware) RefreshHandler(c *gin.Context) {
newClaims["exp"] = expire.Unix()
newClaims["orig_iat"] = origIat

tokenString, err := newToken.SignedString(mw.Key)
tokenString, err := newToken.SignedString(mw.SignKey)

if err != nil {
mw.unauthorized(c, http.StatusUnauthorized, "Create JWT Token faild")
Expand Down Expand Up @@ -301,7 +351,7 @@ func (mw *GinJWTMiddleware) TokenGenerator(userID string) string {
claims["exp"] = mw.TimeFunc().Add(mw.Timeout).Unix()
claims["orig_iat"] = mw.TimeFunc().Unix()

tokenString, _ := token.SignedString(mw.Key)
tokenString, _ := token.SignedString(mw.SignKey)

return tokenString
}
Expand Down Expand Up @@ -364,7 +414,7 @@ func (mw *GinJWTMiddleware) parseToken(c *gin.Context) (*jwt.Token, error) {
return nil, errors.New("invalid signing algorithm")
}

return mw.Key, nil
return mw.VerifyKey, nil
})
}

Expand Down
128 changes: 128 additions & 0 deletions auth_jwt_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,32 @@ import (

var (
key = []byte("secret key")

rsaPrivateKeyString = `-----BEGIN RSA PRIVATE KEY-----
MIICXQIBAAKBgQDm5P69FhprYEz6BI6Dt0KaXheNG5LMiahMGsmW2/4ydwWQnB1t
Lf5OhRxJV8NV+k+e1HiP+ovzNWJ610hjDMhTtRahmgs0HAJ8kpQe4QCZAtHgbc6q
OIKK0c8+v0UGYqVrxJA0bASIhjTXOjPLvZqEU2p2IMacrjLecXKTW0/YEwIDAQAB
AoGBAKFI5pSIow3MaBjhI/foBHM2NLdRwnpz0gbPU2+43li8ATwhgQCp9xE8NCUb
VAxz3DgzbMAOIMJT0SXDygG+hRN4GCRX7xqtLt67t38Nr25Qgf8V+NPbLp4sHPFo
Fk2ODt5XxfE1Ca4tNYBSNPg8ozz+xjRPhuqT5lskXPVNrZ2xAkEA+Fmp0bDa1SSo
LAGg0YUee6NmMh+VoyuhSKNfkKGNSzPYz0PBFljtkYP0C16RHXBs/BdIc7tqSiIN
gFFer9IsmwJBAO4Br8MCjiGv8nXe8tx/IViJR0XM67SGHl8P9XSNa3p6Ih+F2nbG
rlPR2B4quVEFyKkRohUPkbs5ahrle/FqLekCQDHMIM4IDUkRyZrRVMLOU3dtIy/H
v4RxWiyrfZ0Nl7xNkBq3Nj9Z44D7GXMyKhziDyhZLtDt8nkc7OIe7sKIfSMCQQDF
pBTmZXrNsqQvCYK3Y8K3GNhcuDyLXkxeOIxlywITZNRtROQTeg1NgZZsBqJ5C8qD
yybDQniL9rOLvkFcSgXxAkA5GW4lJpmo2ZxDynVjfKkOlmwpAGXHWk4ta8vVzhEQ
blwQMKuzuVTPek5c2R3RXbSxxdivaFoIdbcYzWEPtqu4
-----END RSA PRIVATE KEY-----`

rsaPublicKeyString = `-----BEGIN PUBLIC KEY-----
MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDm5P69FhprYEz6BI6Dt0KaXheN
G5LMiahMGsmW2/4ydwWQnB1tLf5OhRxJV8NV+k+e1HiP+ovzNWJ610hjDMhTtRah
mgs0HAJ8kpQe4QCZAtHgbc6qOIKK0c8+v0UGYqVrxJA0bASIhjTXOjPLvZqEU2p2
IMacrjLecXKTW0/YEwIDAQAB
-----END PUBLIC KEY-----`

rsaPrivateKey, _ = jwt.ParseRSAPrivateKeyFromPEM([]byte(rsaPrivateKeyString))
rsaPublicKey, _ = jwt.ParseRSAPublicKeyFromPEM([]byte(rsaPublicKeyString))
)

func makeTokenString(SigningAlgorithm string, username string) string {
Expand All @@ -33,6 +59,18 @@ func makeTokenString(SigningAlgorithm string, username string) string {
return tokenString
}

func makeAsymmetricTokenString(SigningAlgorithm string, username string) string {

token := jwt.New(jwt.GetSigningMethod(SigningAlgorithm))
claims := token.Claims.(jwt.MapClaims)
claims["id"] = username
claims["exp"] = time.Now().Add(time.Hour).Unix()
claims["orig_iat"] = time.Now().Unix()
tokenString, _ := token.SignedString(rsaPrivateKey)

return tokenString
}

func TestMissingRealm(t *testing.T) {

authMiddleware := &GinJWTMiddleware{
Expand Down Expand Up @@ -75,6 +113,51 @@ func TestMissingKey(t *testing.T) {
assert.Equal(t, "secret key is required", err.Error())
}

func TestMissingPrivateKey(t *testing.T) {

authMiddleware := &GinJWTMiddleware{
Realm: "test zone",
SigningAlgorithm: "RS256",
Timeout: time.Hour,
MaxRefresh: time.Hour * 24,
Authenticator: func(userId string, password string, c *gin.Context) (string, bool) {
if userId == "admin" && password == "admin" {
return "", true
}

return "", false
},
}

err := authMiddleware.MiddlewareInit()

assert.Error(t, err)
assert.Equal(t, "private key is required", err.Error())
}

func TestMissingPublicKey(t *testing.T) {

authMiddleware := &GinJWTMiddleware{
Realm: "test zone",
SigningAlgorithm: "RS256",
SignKey: rsaPrivateKey,
Timeout: time.Hour,
MaxRefresh: time.Hour * 24,
Authenticator: func(userId string, password string, c *gin.Context) (string, bool) {
if userId == "admin" && password == "admin" {
return "", true
}

return "", false
},
}

err := authMiddleware.MiddlewareInit()

assert.Error(t, err)
assert.Equal(t, "public key is required", err.Error())
}

func TestMissingTimeOut(t *testing.T) {

authMiddleware := &GinJWTMiddleware{
Expand Down Expand Up @@ -432,6 +515,51 @@ func TestAuthorizator(t *testing.T) {
})
}

func TestAuthorizatorRS256(t *testing.T) {
// the middleware to test
authMiddleware := &GinJWTMiddleware{
Realm: "test zone",
SigningAlgorithm: "RS256",
SignKey: rsaPrivateKey,
VerifyKey: rsaPublicKey,
Timeout: time.Hour,
MaxRefresh: time.Hour * 24,
Authenticator: func(userId string, password string, c *gin.Context) (string, bool) {
if userId == "admin" && password == "admin" {
return userId, true
}
return userId, false
},
Authorizator: func(userId string, c *gin.Context) bool {
if userId != "admin" {
return false
}

return true
},
}

handler := ginHandler(authMiddleware)

r := gofight.New()

r.GET("/auth/hello").
SetHeader(gofight.H{
"Authorization": "Bearer " + makeAsymmetricTokenString("RS256", "test"),
}).
Run(handler, func(r gofight.HTTPResponse, rq gofight.HTTPRequest) {
assert.Equal(t, http.StatusForbidden, r.Code)
})

r.GET("/auth/hello").
SetHeader(gofight.H{
"Authorization": "Bearer " + makeAsymmetricTokenString("RS256", "admin"),
}).
Run(handler, func(r gofight.HTTPResponse, rq gofight.HTTPRequest) {
assert.Equal(t, http.StatusOK, r.Code)
})
}

func TestClaimsDuringAuthorization(t *testing.T) {
// the middleware to test
authMiddleware := &GinJWTMiddleware{
Expand Down

0 comments on commit 5f485d6

Please sign in to comment.