From 90a025fab77095ac9167264df6c0f9b263c64b79 Mon Sep 17 00:00:00 2001 From: Christian Banse Date: Tue, 21 Feb 2023 12:42:55 +0100 Subject: [PATCH] Adding canonical `Keyfunc` functions for RSA, ECDSA, EdDSA and HMAC This PR adds ready-to-use keyfunc functions for the various signing methods. This should simplify a lot of standard use-cases and also includes a proper signing method check. --- ecdsa.go | 7 +++++ ed25519.go | 7 +++++ example_test.go | 23 ++++++++-------- hmac.go | 5 ++++ parser.go | 14 ++-------- rsa.go | 7 +++++ test/helpers.go | 3 ++- token.go | 34 +++++++++++++++++++++++ token_test.go | 71 ++++++++++++++++++++++++++++++++++++++++--------- 9 files changed, 133 insertions(+), 38 deletions(-) diff --git a/ecdsa.go b/ecdsa.go index 4ccae2a8..c5b3bcf6 100644 --- a/ecdsa.go +++ b/ecdsa.go @@ -132,3 +132,10 @@ func (m *SigningMethodECDSA) Sign(signingString string, key interface{}) ([]byte return nil, err } } + +// ECDSAPublicKey represents a [Keyfunc] that returns the ECDSA key specified in +// key. Furthermore, it checks, whether the signing method matches +// [SigningMethodECDSA]. +func ECDSAPublicKey(key *ecdsa.PublicKey) Keyfunc { + return secureKeyFunc(key, []string{"ES256", "ES384", "ES512"}) +} diff --git a/ed25519.go b/ed25519.go index 3db00e4a..2e7fff8e 100644 --- a/ed25519.go +++ b/ed25519.go @@ -78,3 +78,10 @@ func (m *SigningMethodEd25519) Sign(signingString string, key interface{}) ([]by return sig, nil } + +// Ed25519PublicKey represents a [Keyfunc] that returns the Ed25519 key +// specified in key. Furthermore, it checks, whether the signing method matches +// [SigningMethodEdDSA]. +func Ed25519PublicKey(key ed25519.PublicKey) Keyfunc { + return secureKeyFunc(key, []string{"EdDSA"}) +} diff --git a/example_test.go b/example_test.go index 58fdea43..176a4cdc 100644 --- a/example_test.go +++ b/example_test.go @@ -80,9 +80,7 @@ func ExampleParseWithClaims_customClaimsType() { jwt.RegisteredClaims } - token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) { - return []byte("AllYourBase"), nil - }) + token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, jwt.PresharedKey([]byte("AllYourBase"))) if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid { fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer) @@ -103,9 +101,11 @@ func ExampleParseWithClaims_validationOptions() { jwt.RegisteredClaims } - token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) { - return []byte("AllYourBase"), nil - }, jwt.WithLeeway(5*time.Second)) + token, err := jwt.ParseWithClaims( + tokenString, &MyCustomClaims{}, + jwt.PresharedKey([]byte("AllYourBase")), + jwt.WithLeeway(5*time.Second), + ) if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid { fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer) @@ -138,9 +138,10 @@ func (m MyCustomClaims) Validate() error { func ExampleParseWithClaims_customValidation() { tokenString := "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJpc3MiOiJ0ZXN0IiwiYXVkIjoic2luZ2xlIn0.QAWg1vGvnqRuCFTMcPkjZljXHh8U3L_qUjszOtQbeaA" - token, err := jwt.ParseWithClaims(tokenString, &MyCustomClaims{}, func(token *jwt.Token) (interface{}, error) { - return []byte("AllYourBase"), nil - }, jwt.WithLeeway(5*time.Second)) + token, err := jwt.ParseWithClaims( + tokenString, &MyCustomClaims{}, + jwt.PresharedKey([]byte("AllYourBase")), + jwt.WithLeeway(5*time.Second)) if claims, ok := token.Claims.(*MyCustomClaims); ok && token.Valid { fmt.Printf("%v %v", claims.Foo, claims.RegisteredClaims.Issuer) @@ -156,9 +157,7 @@ func ExampleParse_errorChecking() { // Token from another example. This token is expired var tokenString = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJmb28iOiJiYXIiLCJleHAiOjE1MDAwLCJpc3MiOiJ0ZXN0In0.HE7fK0xOQwFEr4WDgRWj4teRPZ6i3GLwD5YCm6Pwu_c" - token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { - return []byte("AllYourBase"), nil - }) + token, err := jwt.Parse(tokenString, jwt.PresharedKey([]byte("AllYourBase"))) if token.Valid { fmt.Println("You look nice today") diff --git a/hmac.go b/hmac.go index 8609f4a8..4c0d6026 100644 --- a/hmac.go +++ b/hmac.go @@ -87,3 +87,8 @@ func (m *SigningMethodHMAC) Sign(signingString string, key interface{}) ([]byte, return nil, ErrInvalidKeyType } + +// PresharedKey represents a [Keyfunc] that simply returns the key specified in the byte slice. +func PresharedKey(key []byte) Keyfunc { + return secureKeyFunc(key, []string{"HS256", "HS384", "HS512"}) +} diff --git a/parser.go b/parser.go index f4386fba..df6108c3 100644 --- a/parser.go +++ b/parser.go @@ -4,7 +4,6 @@ import ( "bytes" "encoding/base64" "encoding/json" - "fmt" "strings" ) @@ -60,17 +59,8 @@ func (p *Parser) ParseWithClaims(tokenString string, claims Claims, keyFunc Keyf // Verify signing method is in the required set if p.validMethods != nil { - var signingMethodValid = false - var alg = token.Method.Alg() - for _, m := range p.validMethods { - if m == alg { - signingMethodValid = true - break - } - } - if !signingMethodValid { - // signing method is not in the listed set - return token, newError(fmt.Sprintf("signing method %v is invalid", alg), ErrTokenSignatureInvalid) + if err = token.hasValidSigningMethod(p.validMethods); err != nil { + return token, err } } diff --git a/rsa.go b/rsa.go index daff0943..5a259e5f 100644 --- a/rsa.go +++ b/rsa.go @@ -91,3 +91,10 @@ func (m *SigningMethodRSA) Sign(signingString string, key interface{}) ([]byte, return nil, err } } + +// RSAPublicKey represents a [Keyfunc] that returns the RSA key specified in +// key. Furthermore, it checks, whether the signing method matches +// [SigningMethodRSA]. +func RSAPublicKey(key *rsa.PublicKey) Keyfunc { + return secureKeyFunc(key, []string{"RS256", "RS384", "RS512"}) +} diff --git a/test/helpers.go b/test/helpers.go index 381c5f8a..a27c1d56 100644 --- a/test/helpers.go +++ b/test/helpers.go @@ -2,6 +2,7 @@ package test import ( "crypto" + "crypto/ecdsa" "crypto/rsa" "os" @@ -56,7 +57,7 @@ func LoadECPrivateKeyFromDisk(location string) crypto.PrivateKey { return key } -func LoadECPublicKeyFromDisk(location string) crypto.PublicKey { +func LoadECPublicKeyFromDisk(location string) *ecdsa.PublicKey { keyData, e := os.ReadFile(location) if e != nil { panic(e.Error()) diff --git a/token.go b/token.go index 163c02f1..d9035d43 100644 --- a/token.go +++ b/token.go @@ -3,6 +3,7 @@ package jwt import ( "encoding/base64" "encoding/json" + "fmt" ) // Keyfunc will be used by the Parse methods as a callback function to supply @@ -81,3 +82,36 @@ func (t *Token) SigningString() (string, error) { func (*Token) EncodeSegment(seg []byte) string { return base64.RawURLEncoding.EncodeToString(seg) } + +// hasValidSigningMethod is a utility function that checks, if the signing +// method of the token is included in the validMethods slice. +func (token *Token) hasValidSigningMethod(validMethods []string) error { + var signingMethodValid = false + var alg = token.Method.Alg() + for _, m := range validMethods { + if m == alg { + signingMethodValid = true + break + } + } + + if !signingMethodValid { + // signing method is not in the listed set + return newError(fmt.Sprintf("signing method %v is invalid", alg), ErrTokenSignatureInvalid) + } + + return nil +} + +// secureKeyFunc returns a secure [Keyfunc] for the specified key that also +// includes a signing method check. +func secureKeyFunc(key any, validMethods []string) Keyfunc { + return func(t *Token) (interface{}, error) { + // Check, if the signing method matches + if err := t.hasValidSigningMethod(validMethods); err != nil { + return nil, err + } + + return key, nil + } +} diff --git a/token_test.go b/token_test.go index 95709ade..94296443 100644 --- a/token_test.go +++ b/token_test.go @@ -1,17 +1,17 @@ -package jwt_test +package jwt import ( + "errors" + "reflect" "testing" - - "github.com/golang-jwt/jwt/v5" ) func TestToken_SigningString(t1 *testing.T) { type fields struct { Raw string - Method jwt.SigningMethod + Method SigningMethod Header map[string]interface{} - Claims jwt.Claims + Claims Claims Signature []byte Valid bool } @@ -25,12 +25,12 @@ func TestToken_SigningString(t1 *testing.T) { name: "", fields: fields{ Raw: "", - Method: jwt.SigningMethodHS256, + Method: SigningMethodHS256, Header: map[string]interface{}{ "typ": "JWT", - "alg": jwt.SigningMethodHS256.Alg(), + "alg": SigningMethodHS256.Alg(), }, - Claims: jwt.RegisteredClaims{}, + Claims: RegisteredClaims{}, Valid: false, }, want: "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.e30", @@ -39,7 +39,7 @@ func TestToken_SigningString(t1 *testing.T) { } for _, tt := range tests { t1.Run(tt.name, func(t1 *testing.T) { - t := &jwt.Token{ + t := &Token{ Raw: tt.fields.Raw, Method: tt.fields.Method, Header: tt.fields.Header, @@ -60,13 +60,13 @@ func TestToken_SigningString(t1 *testing.T) { } func BenchmarkToken_SigningString(b *testing.B) { - t := &jwt.Token{ - Method: jwt.SigningMethodHS256, + t := &Token{ + Method: SigningMethodHS256, Header: map[string]interface{}{ "typ": "JWT", - "alg": jwt.SigningMethodHS256.Alg(), + "alg": SigningMethodHS256.Alg(), }, - Claims: jwt.RegisteredClaims{}, + Claims: RegisteredClaims{}, } b.Run("BenchmarkToken_SigningString", func(b *testing.B) { b.ResetTimer() @@ -76,3 +76,48 @@ func BenchmarkToken_SigningString(b *testing.B) { } }) } + +func Test_secureKeyFunc(t *testing.T) { + type fields struct { + token *Token + } + type args struct { + key any + validMethods []string + } + tests := []struct { + name string + fields fields + args args + wantKey any + wantErr error + }{ + { + name: "invalid method", + fields: fields{&Token{Header: map[string]interface{}{"alg": "RS512"}, Method: SigningMethodRS512}}, + args: args{key: []byte("mysecret"), validMethods: []string{"HS256"}}, + wantKey: nil, + wantErr: ErrTokenSignatureInvalid, + }, + { + name: "correct method", + fields: fields{&Token{Header: map[string]interface{}{"alg": "HS256"}, Method: SigningMethodHS256}}, + args: args{key: []byte("mysecret"), validMethods: []string{"HS256"}}, + wantKey: []byte("mysecret"), + wantErr: nil, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + keyfunc := secureKeyFunc(tt.args.key, tt.args.validMethods) + gotKey, gotErr := keyfunc(tt.fields.token) + + if !reflect.DeepEqual(gotKey, tt.wantKey) { + t.Errorf("secureKeyFunc() key = %v, want %v", gotKey, tt.wantKey) + } + if (gotErr != nil) && !errors.Is(gotErr, tt.wantErr) { + t.Errorf("secureKeyFunc() err = %v, want %v", gotErr, tt.wantErr) + } + }) + } +}