diff --git a/Changes b/Changes index 60379c80d..221d2b3c8 100644 --- a/Changes +++ b/Changes @@ -25,6 +25,22 @@ v2.0.16 UNRELEASED fail. Therefore this in itself should not pose any security risk, albeit allowing some illegally formated messages to be verified. + * [jwk] `jwk.Key` objects now have a `Validate()` method to validate the data + stored in the keys. However, this still does not necessarily mean that the key's + are valid for use in cryptographic operations. If `Validate()` is successful, + it only means that the keys are in the right _format_, including the presence + of required fields and that certain fields are have proper length, etc. + +[New Features] + * [jws] Added `jws.WithValidateKey()` to force calling `key.Validate()` before + signing or verification. + + * [jws] `jws.Sign()` now returns a special type of error that can hold the + individual errors from the signers. The stringification is still the same + as before to preserve backwards compatibility. + + * [jwk] Added `jwk.IsKeyValidationError` that checks if an error is an error + from `key.Validate()`. v2.0.15 19 20 Oct 2023 [Bug fixes] diff --git a/jwk/ecdsa.go b/jwk/ecdsa.go index 67a14ba63..dcbe2ca6c 100644 --- a/jwk/ecdsa.go +++ b/jwk/ecdsa.go @@ -226,3 +226,48 @@ func (k ecdsaPrivateKey) Thumbprint(hash crypto.Hash) ([]byte, error) { base64.EncodeToString(ybuf), ), nil } + +func ecdsaValidateKey(k interface { + Crv() jwa.EllipticCurveAlgorithm + X() []byte + Y() []byte +}, checkPrivate bool) error { + crv, ok := ecutil.CurveForAlgorithm(k.Crv()) + if !ok { + return fmt.Errorf(`invalid curve algorithm %q`, k.Crv()) + } + + keySize := ecutil.CalculateKeySize(crv) + if x := k.X(); len(x) != keySize { + return fmt.Errorf(`invalid "x" length (%d) for curve %q`, len(x), crv.Params().Name) + } + + if y := k.Y(); len(y) != keySize { + return fmt.Errorf(`invalid "y" length (%d) for curve %q`, len(y), crv.Params().Name) + } + + if checkPrivate { + if priv, ok := k.(interface{ D() []byte }); ok { + if len(priv.D()) != keySize { + return fmt.Errorf(`invalid "d" length (%d) for curve %q`, len(priv.D()), crv.Params().Name) + } + } else { + return fmt.Errorf(`missing "d" value`) + } + } + return nil +} + +func (k *ecdsaPrivateKey) Validate() error { + if err := ecdsaValidateKey(k, true); err != nil { + return NewKeyValidationError(fmt.Errorf(`jwk.ECDSAPrivateKey: %w`, err)) + } + return nil +} + +func (k *ecdsaPublicKey) Validate() error { + if err := ecdsaValidateKey(k, false); err != nil { + return NewKeyValidationError(fmt.Errorf(`jwk.ECDSAPublicKey: %w`, err)) + } + return nil +} diff --git a/jwk/interface_gen.go b/jwk/interface_gen.go index 6e4e79a04..9ee50516c 100644 --- a/jwk/interface_gen.go +++ b/jwk/interface_gen.go @@ -46,6 +46,20 @@ type Key interface { // Remove removes the field associated with the specified key. // There is no way to remove the `kty` (key type). You will ALWAYS be left with one field in a jwk.Key. Remove(string) error + // Validate performs _minimal_ checks if the data stored in the key are valid. + // By minimal, we mean that it does not check if the key is valid for use in + // cryptographic operations. For example, it does not check if an RSA key's + // `e` field is a valid exponent, or if the `n` field is a valid modulus. + // Instead, it checks for things such as the _presence_ of some required fields, + // or if certain keys' values are of particular length. + // + // Note that depending on th underlying key type, use of this method requires + // that multiple fields in the key are properly populated. For example, an EC + // key's "x", "y" fields cannot be validated unless the "crv" field is populated first. + // + // Validate is never called by `UnmarshalJSON()` or `Set`. It must explicitly be + // called by the user + Validate() error // Raw creates the corresponding raw key. For example, // EC types would create *ecdsa.PublicKey or *ecdsa.PrivateKey, diff --git a/jwk/jwk.go b/jwk/jwk.go index 9c5cbb04c..bf129e8c6 100644 --- a/jwk/jwk.go +++ b/jwk/jwk.go @@ -12,6 +12,7 @@ import ( "crypto/rsa" "crypto/x509" "encoding/pem" + "errors" "fmt" "io" "math/big" @@ -757,3 +758,32 @@ func IsPrivateKey(k Key) (bool, error) { } return false, fmt.Errorf("jwk.IsPrivateKey: %T is not an asymmetric key", k) } + +type keyValidationError struct { + err error +} + +func (e *keyValidationError) Error() string { + return fmt.Sprintf(`key validation failed: %s`, e.err) +} + +func (e *keyValidationError) Unwrap() error { + return e.err +} + +func (e *keyValidationError) Is(target error) bool { + _, ok := target.(*keyValidationError) + return ok +} + +// NewKeyValidationError wraps the given error with an error that denotes +// `key.Validate()` has failed. This error type should ONLY be used as +// return value from the `Validate()` method. +func NewKeyValidationError(err error) error { + return &keyValidationError{err: err} +} + +func IsKeyValidationError(err error) bool { + var kve keyValidationError + return errors.Is(err, &kve) +} diff --git a/jwk/jwk_internal_test.go b/jwk/jwk_internal_test.go index b3085e0a6..39b4fef41 100644 --- a/jwk/jwk_internal_test.go +++ b/jwk/jwk_internal_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/lestrrat-go/jwx/v2/cert" + "github.com/lestrrat-go/jwx/v2/internal/base64" "github.com/lestrrat-go/jwx/v2/internal/json" "github.com/lestrrat-go/jwx/v2/jwa" "github.com/stretchr/testify/assert" @@ -129,9 +130,18 @@ func TestIterator(t *testing.T) { { Extras: map[string]interface{}{ ECDSACrvKey: jwa.P256, - ECDSAXKey: []byte("MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4"), - ECDSAYKey: []byte("4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM"), - ECDSADKey: []byte("870MB6gfuTJ4HtUnUvYMyJpr5eUZNP4Bk43bVdj3eAE"), + ECDSAXKey: (func() []byte { + s, _ := base64.DecodeString("MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4") + return s + })(), + ECDSAYKey: (func() []byte { + s, _ := base64.DecodeString("4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM") + return s + })(), + ECDSADKey: (func() []byte { + s, _ := base64.DecodeString("870MB6gfuTJ4HtUnUvYMyJpr5eUZNP4Bk43bVdj3eAE") + return s + })(), }, Func: func() Key { return newECDSAPrivateKey() @@ -140,8 +150,14 @@ func TestIterator(t *testing.T) { { Extras: map[string]interface{}{ ECDSACrvKey: jwa.P256, - ECDSAXKey: []byte("MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4"), - ECDSAYKey: []byte("4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM"), + ECDSAXKey: (func() []byte { + s, _ := base64.DecodeString("MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4") + return s + })(), + ECDSAYKey: (func() []byte { + s, _ := base64.DecodeString("4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM") + return s + })(), }, Func: func() Key { return newECDSAPublicKey() @@ -184,6 +200,7 @@ func TestIterator(t *testing.T) { } if !assert.NoError(t, json.Unmarshal(buf, key2), `json.Unmarshal should succeed`) { + t.Logf("%s", buf) return } diff --git a/jwk/jwk_test.go b/jwk/jwk_test.go index fd2df83ae..9ed60d370 100644 --- a/jwk/jwk_test.go +++ b/jwk/jwk_test.go @@ -2217,3 +2217,54 @@ func TestGH947(t *testing.T) { var exported []byte require.Error(t, k.Raw(&exported), `(okpkey).Raw with 0-length OKP key should fail`) } + +func TestValidation(t *testing.T) { + { + key, err := jwxtest.GenerateRsaJwk() + require.NoError(t, err, `jwx.GenerateRsaJwk should succeed`) + require.NoError(t, key.Validate(), `key.Validate should succeed (vanilla key)`) + + require.NoError(t, key.Set(jwk.RSADKey, []byte(nil)), `key.Set should succeed`) + require.Error(t, key.Validate(), `key.Validate should fail`) + } + + { + key, err := jwxtest.GenerateEcdsaJwk() + require.NoError(t, err, `jwx.GenerateEcdsaJwk should succeed`) + require.NoError(t, key.Validate(), `key.Validate should succeed`) + + x := key.(jwk.ECDSAPrivateKey).X() + require.NoError(t, key.Set(jwk.ECDSAXKey, x[:len(x)/2]), `key.Set should succeed`) + require.Error(t, key.Validate(), `key.Validate should fail`) + + require.NoError(t, key.Set(jwk.ECDSAXKey, x), `key.Set should succeed`) + require.NoError(t, key.Validate(), `key.Validate should succeed`) + + require.NoError(t, key.Set(jwk.ECDSADKey, []byte(nil)), `key.Set should succeed`) + require.Error(t, key.Validate(), `key.Validate should fail`) + } + + { + key, err := jwxtest.GenerateEd25519Jwk() + require.NoError(t, err, `jwx.GenerateEd25519Jwk should succeed`) + require.NoError(t, key.Validate(), `key.Validate should succeed`) + x := key.(jwk.OKPPrivateKey).X() + require.NoError(t, key.Set(jwk.OKPXKey, []byte(nil)), `key.Set should succeed`) + require.Error(t, key.Validate(), `key.Validate should fail`) + + require.NoError(t, key.Set(jwk.OKPXKey, x), `key.Set should succeed`) + require.NoError(t, key.Validate(), `key.Validate should succeed`) + + require.NoError(t, key.Set(jwk.OKPDKey, []byte(nil)), `key.Set should succeed`) + require.Error(t, key.Validate(), `key.Validate should fail`) + } + + { + key, err := jwxtest.GenerateSymmetricJwk() + require.NoError(t, err, `jwx.GenerateSymmetricJwk should succeed`) + require.NoError(t, key.Validate(), `key.Validate should succeed`) + + require.NoError(t, key.Set(jwk.SymmetricOctetsKey, []byte(nil)), `key.Set should succeed`) + require.Error(t, key.Validate(), `key.Validate should fail`) + } +} diff --git a/jwk/okp.go b/jwk/okp.go index 9c01ec165..7754bf2a0 100644 --- a/jwk/okp.go +++ b/jwk/okp.go @@ -187,3 +187,41 @@ func (k okpPrivateKey) Thumbprint(hash crypto.Hash) ([]byte, error) { base64.EncodeToString(k.x), ), nil } + +func validateOKPKey(key interface { + Crv() jwa.EllipticCurveAlgorithm + X() []byte +}) error { + if key.Crv() == jwa.InvalidEllipticCurve { + return fmt.Errorf(`invalid curve algorithm`) + } + + if len(key.X()) == 0 { + return fmt.Errorf(`missing "x" field`) + } + + if priv, ok := key.(interface{ D() []byte }); ok { + if len(priv.D()) == 0 { + return fmt.Errorf(`missing "d" field`) + } + } + return nil +} + +func (k *okpPublicKey) Validate() error { + k.mu.RLock() + defer k.mu.RUnlock() + if err := validateOKPKey(k); err != nil { + return NewKeyValidationError(fmt.Errorf(`jwk.OKPPublicKey: %w`, err)) + } + return nil +} + +func (k *okpPrivateKey) Validate() error { + k.mu.RLock() + defer k.mu.RUnlock() + if err := validateOKPKey(k); err != nil { + return NewKeyValidationError(fmt.Errorf(`jwk.OKPPrivateKey: %w`, err)) + } + return nil +} diff --git a/jwk/rsa.go b/jwk/rsa.go index 5de6b6358..89c1a534e 100644 --- a/jwk/rsa.go +++ b/jwk/rsa.go @@ -241,3 +241,43 @@ func rsaThumbprint(hash crypto.Hash, key *rsa.PublicKey) ([]byte, error) { } return h.Sum(nil), nil } + +func validateRSAKey(key interface { + N() []byte + E() []byte +}, checkPrivate bool) error { + if len(key.N()) == 0 { + // Ideally we would like to check for the actual length, but unlike + // EC keys, we have nothing in the key itself that will tell us + // how many bits this key should have. + return fmt.Errorf(`missing "n" value`) + } + if len(key.E()) == 0 { + return fmt.Errorf(`missing "e" value`) + } + if checkPrivate { + if priv, ok := key.(interface{ D() []byte }); ok { + if len(priv.D()) == 0 { + return fmt.Errorf(`missing "d" value`) + } + } else { + return fmt.Errorf(`missing "d" value`) + } + } + + return nil +} + +func (k *rsaPrivateKey) Validate() error { + if err := validateRSAKey(k, true); err != nil { + return NewKeyValidationError(fmt.Errorf(`jwk.RSAPrivateKey: %w`, err)) + } + return nil +} + +func (k *rsaPublicKey) Validate() error { + if err := validateRSAKey(k, false); err != nil { + return NewKeyValidationError(fmt.Errorf(`jwk.RSAPublicKey: %w`, err)) + } + return nil +} diff --git a/jwk/symmetric.go b/jwk/symmetric.go index d2498e334..378f9c738 100644 --- a/jwk/symmetric.go +++ b/jwk/symmetric.go @@ -58,3 +58,10 @@ func (k *symmetricKey) PublicKey() (Key, error) { } return newKey, nil } + +func (k *symmetricKey) Validate() error { + if len(k.Octets()) == 0 { + return NewKeyValidationError(fmt.Errorf(`jwk.SymmetricKey: missing "k" field`)) + } + return nil +} diff --git a/jws/jws.go b/jws/jws.go index 7e9eecefb..a348c6186 100644 --- a/jws/jws.go +++ b/jws/jws.go @@ -28,6 +28,7 @@ import ( "crypto/ecdsa" "crypto/ed25519" "crypto/rsa" + "errors" "fmt" "io" "reflect" @@ -107,6 +108,18 @@ const ( var _ = fmtInvalid var _ = fmtMax +func validateKeyBeforeUse(key interface{}) error { + jwkKey, ok := key.(jwk.Key) + if !ok { + converted, err := jwk.FromRaw(key) + if err != nil { + return fmt.Errorf(`could not convert key of type %T to jwk.Key for validation: %w`, key, err) + } + jwkKey = converted + } + return jwkKey.Validate() +} + // Sign generates a JWS message for the given payload and returns // it in serialized form, which can be in either compact or // JSON format. Default is compact. @@ -149,6 +162,7 @@ func Sign(payload []byte, options ...SignOption) ([]byte, error) { var signers []*payloadSigner var detached bool var noneSignature *payloadSigner + var validateKey bool for _, option := range options { //nolint:forcetypeassert switch option.Ident() { @@ -185,6 +199,8 @@ func Sign(payload []byte, options ...SignOption) ([]byte, error) { return nil, fmt.Errorf(`jws.Sign: payload must be nil when jws.WithDetachedPayload() is specified`) } payload = option.Value().([]byte) + case identValidateKey{}: + validateKey = option.Value().(bool) } } @@ -239,6 +255,12 @@ func Sign(payload []byte, options ...SignOption) ([]byte, error) { // cheat. FIXXXXXXMEEEEEE detached: detached, } + + if validateKey { + if err := validateKeyBeforeUse(signer.key); err != nil { + return nil, fmt.Errorf(`jws.Verify: %w`, err) + } + } _, _, err := sig.Sign(payload, signer.signer, signer.key) if err != nil { return nil, fmt.Errorf(`failed to generate signature for signer #%d (alg=%s): %w`, i, signer.Algorithm(), err) @@ -290,6 +312,7 @@ func Verify(buf []byte, options ...VerifyOption) ([]byte, error) { var detachedPayload []byte var keyProviders []KeyProvider var keyUsed interface{} + var validateKey bool ctx := context.Background() @@ -316,6 +339,8 @@ func Verify(buf []byte, options ...VerifyOption) ([]byte, error) { keyUsed = option.Value() case identContext{}: ctx = option.Value().(context.Context) + case identValidateKey{}: + validateKey = option.Value().(bool) default: return nil, fmt.Errorf(`invalid jws.VerifyOption %q passed`, `With`+strings.TrimPrefix(fmt.Sprintf(`%T`, option.Ident()), `jws.ident`)) } @@ -350,6 +375,7 @@ func Verify(buf []byte, options ...VerifyOption) ([]byte, error) { verifyBuf := pool.GetBytesBuffer() defer pool.ReleaseBytesBuffer(verifyBuf) + var errs []error for i, sig := range msg.signatures { verifyBuf.Reset() @@ -386,12 +412,19 @@ func Verify(buf []byte, options ...VerifyOption) ([]byte, error) { //nolint:forcetypeassert alg := pair.alg.(jwa.SignatureAlgorithm) key := pair.key + + if validateKey { + if err := validateKeyBeforeUse(key); err != nil { + return nil, fmt.Errorf(`jws.Verify: %w`, err) + } + } verifier, err := NewVerifier(alg) if err != nil { return nil, fmt.Errorf(`failed to create verifier for algorithm %q: %w`, alg, err) } if err := verifier.Verify(verifyBuf.Bytes(), sig.signature, key); err != nil { + errs = append(errs, err) continue } @@ -409,7 +442,33 @@ func Verify(buf []byte, options ...VerifyOption) ([]byte, error) { } } } - return nil, fmt.Errorf(`could not verify message using any of the signatures or keys`) + return nil, &verifyError{errs: errs} +} + +type verifyError struct { + // Note: when/if we can ditch Go < 1.20, we can change this to a simple + // `err error`, where the value is the result of `errors.Join()` + // + // We also need to implement Unwrap: + // func (e *verifyError) Unwrap() error { + // return e.err + //} + // + // And finally, As() can go away + errs []error +} + +func (e *verifyError) Error() string { + return `could not verify message using any of the signatures or keys` +} + +func (e *verifyError) As(target interface{}) bool { + for _, err := range e.errs { + if errors.As(err, target) { + return true + } + } + return false } // get the value of b64 header field. diff --git a/jws/jws_test.go b/jws/jws_test.go index 3960e9e00..bf29b5a8b 100644 --- a/jws/jws_test.go +++ b/jws/jws_test.go @@ -1129,7 +1129,8 @@ func TestVerifyNonUniqueKid(t *testing.T) { Name: `match 2 keys via same "kid" and different key type / alg`, Key: func() jwk.Key { privateKey, _ := jwxtest.GenerateEcdsaKey(jwa.P256) - wrongKey, _ := jwk.PublicKeyOf(privateKey) + wrongKey, err := jwk.PublicKeyOf(privateKey) + require.NoError(t, err, `jwk.PublicKeyOf should succeed`) _ = wrongKey.Set(jwk.KeyIDKey, kid) _ = wrongKey.Set(jwk.AlgorithmKey, jwa.ES256K) return wrongKey @@ -2120,3 +2121,36 @@ func TestUnpaddedSignatureR(t *testing.T) { _, err = jws.Verify([]byte(unpadded), jws.WithKey(jwa.ES256, pubKey)) require.Error(t, err, `jws.Verify should fail`) } + +func TestValidateKey(t *testing.T) { + privKey, err := jwxtest.GenerateRsaJwk() + require.NoError(t, err, `jwxtest.GenerateRsaJwk should succeed`) + + signed, err := jws.Sign([]byte("Lorem Ipsum"), jws.WithKey(jwa.RS256, privKey), jws.WithValidateKey(true)) + require.NoError(t, err, `jws.Sign should succeed`) + + // This should fail because D is empty + require.NoError(t, privKey.Set(jwk.RSADKey, []byte(nil)), `jwk.Set should succeed`) + _, err = jws.Sign([]byte("Lorem Ipsum"), jws.WithKey(jwa.RS256, privKey), jws.WithValidateKey(true)) + require.Error(t, err, `jws.Sign should fail`) + + pubKey, err := jwk.PublicKeyOf(privKey) + require.NoError(t, err, `jwk.PublicKeyOf should succeed`) + + n := pubKey.(jwk.RSAPublicKey).N() + + // Set N to an empty value + require.NoError(t, pubKey.Set(jwk.RSANKey, []byte(nil)), `jwk.Set should succeed`) + + // This is going to fail regardless, because the public key is now + // invalid (empty N), but we want to make sure that it fails because + // of the validation failing + _, err = jws.Verify(signed, jws.WithKey(jwa.RS256, pubKey), jws.WithValidateKey(true)) + require.Error(t, err, `jws.Verify should fail`) + require.True(t, jwk.IsKeyValidationError(err), `jwk.IsKeyValidationError should return true`) + + // The following should now succeed, because N has been reinstated + require.NoError(t, pubKey.Set(jwk.RSANKey, n), `jwk.Set should succeed`) + _, err = jws.Verify(signed, jws.WithKey(jwa.RS256, pubKey), jws.WithValidateKey(true)) + require.NoError(t, err, `jws.Verify should succeed`) +} diff --git a/jws/options.yaml b/jws/options.yaml index f07d66d48..65b5352c1 100644 --- a/jws/options.yaml +++ b/jws/options.yaml @@ -88,6 +88,27 @@ options: `jwk.Key` here unless you are 100% sure that all keys that you have provided are instances of `jwk.Key` (remember that the jwx API allows users to specify a raw key such as *rsa.PublicKey) + - ident: ValidateKey + interface: SignVerifyOption + argument_type: bool + comment: | + WithValidateKey specifies whether the key used for signing or verification + should be validated before using. Note that this means calling + `key.Validate()` on the key, which in turn means that your key + must be a `jwk.Key` instance, or a key that can be converted to + a `jwk.Key` by calling `jwk.FromRaw()`. This means that your + custom hardware-backed keys will probably not work. + + You can directly call `key.Validate()` yourself if you need to + mix keys that cannot be converted to `jwk.Key`. + + Please also note that use of this option will also result in + one extra conversion of raw keys to a `jwk.Key` instance. If you + care about shaving off as much as possible, consider using a + pre-validated key instead of using this option to validate + the key on-demand each time. + + By default, the key is not validated. - ident: InferAlgorithmFromKey interface: WithKeySetSuboption argument_type: bool diff --git a/jws/options_gen.go b/jws/options_gen.go index 17f993927..ca834e103 100644 --- a/jws/options_gen.go +++ b/jws/options_gen.go @@ -139,6 +139,7 @@ type identPublicHeaders struct{} type identRequireKid struct{} type identSerialization struct{} type identUseDefault struct{} +type identValidateKey struct{} func (identContext) String() string { return "WithContext" @@ -204,6 +205,10 @@ func (identUseDefault) String() string { return "WithUseDefault" } +func (identValidateKey) String() string { + return "WithValidateKey" +} + func WithContext(v context.Context) VerifyOption { return &verifyOption{option.New(identContext{}, v)} } @@ -334,3 +339,24 @@ func WithCompact() SignOption { func WithUseDefault(v bool) WithKeySetSuboption { return &withKeySetSuboption{option.New(identUseDefault{}, v)} } + +// WithValidateKey specifies whether the key used for signing or verification +// should be validated before using. Note that this means calling +// `key.Validate()` on the key, which in turn means that your key +// must be a `jwk.Key` instance, or a key that can be converted to +// a `jwk.Key` by calling `jwk.FromRaw()`. This means that your +// custom hardware-backed keys will probably not work. +// +// You can directly call `key.Validate()` yourself if you need to +// mix keys that cannot be converted to `jwk.Key`. +// +// Please also note that use of this option will also result in +// one extra conversion of raw keys to a `jwk.Key` instance. If you +// care about shaving off as much as possible, consider using a +// pre-validated key instead of using this option to validate +// the key on-demand each time. +// +// By default, the key is not validated. +func WithValidateKey(v bool) SignVerifyOption { + return &signVerifyOption{option.New(identValidateKey{}, v)} +} diff --git a/jws/options_gen_test.go b/jws/options_gen_test.go index f3ba4b32f..75631281e 100644 --- a/jws/options_gen_test.go +++ b/jws/options_gen_test.go @@ -25,4 +25,5 @@ func TestOptionIdent(t *testing.T) { require.Equal(t, "WithRequireKid", identRequireKid{}.String()) require.Equal(t, "WithSerialization", identSerialization{}.String()) require.Equal(t, "WithUseDefault", identUseDefault{}.String()) + require.Equal(t, "WithValidateKey", identValidateKey{}.String()) } diff --git a/jwt/jwt_test.go b/jwt/jwt_test.go index 6c0f04240..9b8554e2a 100644 --- a/jwt/jwt_test.go +++ b/jwt/jwt_test.go @@ -1738,7 +1738,7 @@ func TestGH951(t *testing.T) { decrypted, err := jwe.Decrypt(serialized, jwe.WithKey(jwa.A128KW, sharedKey)) require.NoError(t, err, `jwe.Decrypt should succeed`) - verified, err := jwt.Parse(decrypted, jwt.WithKey(jwa.RS256, signKey)) + verified, err := jwt.Parse(decrypted, jwt.WithKey(jwa.RS256, signKey.PublicKey)) require.NoError(t, err, `jwt.Parse should succeed`) require.True(t, jwt.Equal(verified, token), `tokens should be equal`) diff --git a/jwx_test.go b/jwx_test.go index 00f8bd25c..40011ff7b 100644 --- a/jwx_test.go +++ b/jwx_test.go @@ -562,35 +562,43 @@ func TestGH996(t *testing.T) { symmetricKey := []byte(`abracadabra`) testcases := []struct { - Name string - Algorithm jwa.SignatureAlgorithm - Valid []interface{} - Invalid []interface{} + Name string + Algorithm jwa.SignatureAlgorithm + ValidSigningKeys []interface{} + InvalidSigningKeys []interface{} + ValidVerificationKeys []interface{} + InvalidVerificationKeys []interface{} }{ { - Name: `ECDSA`, - Algorithm: jwa.ES256, - Valid: []interface{}{ecdsaKey}, - Invalid: []interface{}{rsaKey, okpKey, symmetricKey}, + Name: `ECDSA`, + Algorithm: jwa.ES256, + ValidSigningKeys: []interface{}{ecdsaKey}, + InvalidSigningKeys: []interface{}{rsaKey, okpKey, symmetricKey}, + ValidVerificationKeys: []interface{}{ecdsaKey.PublicKey}, + InvalidVerificationKeys: []interface{}{rsaKey.PublicKey, okpKey.Public(), symmetricKey}, }, { - Name: `RSA`, - Algorithm: jwa.RS256, - Valid: []interface{}{rsaKey}, - Invalid: []interface{}{ecdsaKey, okpKey, symmetricKey}, + Name: `RSA`, + Algorithm: jwa.RS256, + ValidSigningKeys: []interface{}{rsaKey}, + InvalidSigningKeys: []interface{}{ecdsaKey, okpKey, symmetricKey}, + ValidVerificationKeys: []interface{}{rsaKey.PublicKey}, + InvalidVerificationKeys: []interface{}{ecdsaKey.PublicKey, okpKey.Public(), symmetricKey}, }, { - Name: `OKP`, - Algorithm: jwa.EdDSA, - Valid: []interface{}{okpKey}, - Invalid: []interface{}{ecdsaKey, rsaKey, symmetricKey}, + Name: `OKP`, + Algorithm: jwa.EdDSA, + ValidSigningKeys: []interface{}{okpKey}, + InvalidSigningKeys: []interface{}{ecdsaKey, rsaKey, symmetricKey}, + ValidVerificationKeys: []interface{}{okpKey.Public()}, + InvalidVerificationKeys: []interface{}{ecdsaKey.PublicKey, rsaKey.PublicKey, symmetricKey}, }, } for _, tc := range testcases { tc := tc t.Run(tc.Name, func(t *testing.T) { - for _, valid := range tc.Valid { + for _, valid := range tc.ValidSigningKeys { valid := valid t.Run(fmt.Sprintf("Sign Valid(%T)", valid), func(t *testing.T) { _, err := jws.Sign([]byte("Lorem Ipsum"), jws.WithKey(tc.Algorithm, valid)) @@ -598,7 +606,7 @@ func TestGH996(t *testing.T) { }) } - for _, invalid := range tc.Invalid { + for _, invalid := range tc.InvalidSigningKeys { invalid := invalid t.Run(fmt.Sprintf("Sign Invalid(%T)", invalid), func(t *testing.T) { _, err := jws.Sign([]byte("Lorem Ipsum"), jws.WithKey(tc.Algorithm, invalid)) @@ -606,10 +614,10 @@ func TestGH996(t *testing.T) { }) } - signed, err := jws.Sign([]byte("Lorem Ipsum"), jws.WithKey(tc.Algorithm, tc.Valid[0])) + signed, err := jws.Sign([]byte("Lorem Ipsum"), jws.WithKey(tc.Algorithm, tc.ValidSigningKeys[0])) require.NoError(t, err, `jws.Sign with valid key should succeed`) - for _, valid := range tc.Valid { + for _, valid := range tc.ValidVerificationKeys { valid := valid t.Run(fmt.Sprintf("Verify Valid(%T)", valid), func(t *testing.T) { _, err := jws.Verify(signed, jws.WithKey(tc.Algorithm, valid)) @@ -617,7 +625,7 @@ func TestGH996(t *testing.T) { }) } - for _, invalid := range tc.Invalid { + for _, invalid := range tc.InvalidVerificationKeys { invalid := invalid t.Run(fmt.Sprintf("Verify Invalid(%T)", invalid), func(t *testing.T) { _, err := jws.Verify(signed, jws.WithKey(tc.Algorithm, invalid)) diff --git a/tools/cmd/genjwk/main.go b/tools/cmd/genjwk/main.go index fd2c6a199..0ddad9117 100644 --- a/tools/cmd/genjwk/main.go +++ b/tools/cmd/genjwk/main.go @@ -314,7 +314,7 @@ func generateObject(o *codegen.Output, kt *KeyType, obj *codegen.Object) error { o.L("h.mu.Lock()") o.L("defer h.mu.Unlock()") o.L("return h.setNoLock(name, value)") - o.L("}") + o.L(`}`) o.LL("func (h *%s) setNoLock(name string, value interface{}) error {", structName) o.L("switch name {") @@ -646,6 +646,20 @@ func generateGenericHeaders(fields codegen.FieldList) error { o.LL("// Remove removes the field associated with the specified key.") o.L("// There is no way to remove the `kty` (key type). You will ALWAYS be left with one field in a jwk.Key.") o.L("Remove(string) error") + o.L("// Validate performs _minimal_ checks if the data stored in the key are valid.") + o.L("// By minimal, we mean that it does not check if the key is valid for use in") + o.L("// cryptographic operations. For example, it does not check if an RSA key's") + o.L("// `e` field is a valid exponent, or if the `n` field is a valid modulus.") + o.L("// Instead, it checks for things such as the _presence_ of some required fields,") + o.L("// or if certain keys' values are of particular length.") + o.L("//") + o.L("// Note that depending on th underlying key type, use of this method requires") + o.L("// that multiple fields in the key are properly populated. For example, an EC") + o.L("// key's \"x\", \"y\" fields cannot be validated unless the \"crv\" field is populated first.") + o.L("//") + o.L("// Validate is never called by `UnmarshalJSON()` or `Set`. It must explicitly be") + o.L("// called by the user") + o.L("Validate() error") o.LL("// Raw creates the corresponding raw key. For example,") o.L("// EC types would create *ecdsa.PublicKey or *ecdsa.PrivateKey,") o.L("// and OctetSeq types create a []byte key.")