Skip to content

Commit

Permalink
Add (jwk.Key).Validate (#1005)
Browse files Browse the repository at this point in the history
* Add (jwk.Key).Validate

* Appease linter

* Update WithValidateKey, and add tests+docs

* appease linter
  • Loading branch information
lestrrat authored Oct 27, 2023
1 parent 1ecc78f commit 1e3b478
Show file tree
Hide file tree
Showing 17 changed files with 451 additions and 30 deletions.
16 changes: 16 additions & 0 deletions Changes
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,22 @@ v2.0.16 UNRELEASED
fail. Therefore this in itself should not pose any security risk, albeit
allowing some illegally formated messages to be verified.

* [jwk] `jwk.Key` objects now have a `Validate()` method to validate the data
stored in the keys. However, this still does not necessarily mean that the key's
are valid for use in cryptographic operations. If `Validate()` is successful,
it only means that the keys are in the right _format_, including the presence
of required fields and that certain fields are have proper length, etc.

[New Features]
* [jws] Added `jws.WithValidateKey()` to force calling `key.Validate()` before
signing or verification.

* [jws] `jws.Sign()` now returns a special type of error that can hold the
individual errors from the signers. The stringification is still the same
as before to preserve backwards compatibility.

* [jwk] Added `jwk.IsKeyValidationError` that checks if an error is an error
from `key.Validate()`.

v2.0.15 19 20 Oct 2023
[Bug fixes]
Expand Down
45 changes: 45 additions & 0 deletions jwk/ecdsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,3 +226,48 @@ func (k ecdsaPrivateKey) Thumbprint(hash crypto.Hash) ([]byte, error) {
base64.EncodeToString(ybuf),
), nil
}

func ecdsaValidateKey(k interface {
Crv() jwa.EllipticCurveAlgorithm
X() []byte
Y() []byte
}, checkPrivate bool) error {
crv, ok := ecutil.CurveForAlgorithm(k.Crv())
if !ok {
return fmt.Errorf(`invalid curve algorithm %q`, k.Crv())
}

keySize := ecutil.CalculateKeySize(crv)
if x := k.X(); len(x) != keySize {
return fmt.Errorf(`invalid "x" length (%d) for curve %q`, len(x), crv.Params().Name)
}

if y := k.Y(); len(y) != keySize {
return fmt.Errorf(`invalid "y" length (%d) for curve %q`, len(y), crv.Params().Name)
}

if checkPrivate {
if priv, ok := k.(interface{ D() []byte }); ok {
if len(priv.D()) != keySize {
return fmt.Errorf(`invalid "d" length (%d) for curve %q`, len(priv.D()), crv.Params().Name)
}
} else {
return fmt.Errorf(`missing "d" value`)
}
}
return nil
}

func (k *ecdsaPrivateKey) Validate() error {
if err := ecdsaValidateKey(k, true); err != nil {
return NewKeyValidationError(fmt.Errorf(`jwk.ECDSAPrivateKey: %w`, err))
}
return nil
}

func (k *ecdsaPublicKey) Validate() error {
if err := ecdsaValidateKey(k, false); err != nil {
return NewKeyValidationError(fmt.Errorf(`jwk.ECDSAPublicKey: %w`, err))
}
return nil
}
14 changes: 14 additions & 0 deletions jwk/interface_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 30 additions & 0 deletions jwk/jwk.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"errors"
"fmt"
"io"
"math/big"
Expand Down Expand Up @@ -757,3 +758,32 @@ func IsPrivateKey(k Key) (bool, error) {
}
return false, fmt.Errorf("jwk.IsPrivateKey: %T is not an asymmetric key", k)
}

type keyValidationError struct {
err error
}

func (e *keyValidationError) Error() string {
return fmt.Sprintf(`key validation failed: %s`, e.err)
}

func (e *keyValidationError) Unwrap() error {
return e.err
}

func (e *keyValidationError) Is(target error) bool {
_, ok := target.(*keyValidationError)
return ok
}

// NewKeyValidationError wraps the given error with an error that denotes
// `key.Validate()` has failed. This error type should ONLY be used as
// return value from the `Validate()` method.
func NewKeyValidationError(err error) error {
return &keyValidationError{err: err}
}

func IsKeyValidationError(err error) bool {
var kve keyValidationError
return errors.Is(err, &kve)
}
27 changes: 22 additions & 5 deletions jwk/jwk_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"testing"

"github.com/lestrrat-go/jwx/v2/cert"
"github.com/lestrrat-go/jwx/v2/internal/base64"
"github.com/lestrrat-go/jwx/v2/internal/json"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -129,9 +130,18 @@ func TestIterator(t *testing.T) {
{
Extras: map[string]interface{}{
ECDSACrvKey: jwa.P256,
ECDSAXKey: []byte("MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4"),
ECDSAYKey: []byte("4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM"),
ECDSADKey: []byte("870MB6gfuTJ4HtUnUvYMyJpr5eUZNP4Bk43bVdj3eAE"),
ECDSAXKey: (func() []byte {
s, _ := base64.DecodeString("MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4")
return s
})(),
ECDSAYKey: (func() []byte {
s, _ := base64.DecodeString("4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM")
return s
})(),
ECDSADKey: (func() []byte {
s, _ := base64.DecodeString("870MB6gfuTJ4HtUnUvYMyJpr5eUZNP4Bk43bVdj3eAE")
return s
})(),
},
Func: func() Key {
return newECDSAPrivateKey()
Expand All @@ -140,8 +150,14 @@ func TestIterator(t *testing.T) {
{
Extras: map[string]interface{}{
ECDSACrvKey: jwa.P256,
ECDSAXKey: []byte("MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4"),
ECDSAYKey: []byte("4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM"),
ECDSAXKey: (func() []byte {
s, _ := base64.DecodeString("MKBCTNIcKUSDii11ySs3526iDZ8AiTo7Tu6KPAqv7D4")
return s
})(),
ECDSAYKey: (func() []byte {
s, _ := base64.DecodeString("4Etl6SRW2YiLUrN5vfvVHuhp7x8PxltmWWlbbM4IFyM")
return s
})(),
},
Func: func() Key {
return newECDSAPublicKey()
Expand Down Expand Up @@ -184,6 +200,7 @@ func TestIterator(t *testing.T) {
}

if !assert.NoError(t, json.Unmarshal(buf, key2), `json.Unmarshal should succeed`) {
t.Logf("%s", buf)
return
}

Expand Down
51 changes: 51 additions & 0 deletions jwk/jwk_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2217,3 +2217,54 @@ func TestGH947(t *testing.T) {
var exported []byte
require.Error(t, k.Raw(&exported), `(okpkey).Raw with 0-length OKP key should fail`)
}

func TestValidation(t *testing.T) {
{
key, err := jwxtest.GenerateRsaJwk()
require.NoError(t, err, `jwx.GenerateRsaJwk should succeed`)
require.NoError(t, key.Validate(), `key.Validate should succeed (vanilla key)`)

require.NoError(t, key.Set(jwk.RSADKey, []byte(nil)), `key.Set should succeed`)
require.Error(t, key.Validate(), `key.Validate should fail`)
}

{
key, err := jwxtest.GenerateEcdsaJwk()
require.NoError(t, err, `jwx.GenerateEcdsaJwk should succeed`)
require.NoError(t, key.Validate(), `key.Validate should succeed`)

x := key.(jwk.ECDSAPrivateKey).X()
require.NoError(t, key.Set(jwk.ECDSAXKey, x[:len(x)/2]), `key.Set should succeed`)
require.Error(t, key.Validate(), `key.Validate should fail`)

require.NoError(t, key.Set(jwk.ECDSAXKey, x), `key.Set should succeed`)
require.NoError(t, key.Validate(), `key.Validate should succeed`)

require.NoError(t, key.Set(jwk.ECDSADKey, []byte(nil)), `key.Set should succeed`)
require.Error(t, key.Validate(), `key.Validate should fail`)
}

{
key, err := jwxtest.GenerateEd25519Jwk()
require.NoError(t, err, `jwx.GenerateEd25519Jwk should succeed`)
require.NoError(t, key.Validate(), `key.Validate should succeed`)
x := key.(jwk.OKPPrivateKey).X()
require.NoError(t, key.Set(jwk.OKPXKey, []byte(nil)), `key.Set should succeed`)
require.Error(t, key.Validate(), `key.Validate should fail`)

require.NoError(t, key.Set(jwk.OKPXKey, x), `key.Set should succeed`)
require.NoError(t, key.Validate(), `key.Validate should succeed`)

require.NoError(t, key.Set(jwk.OKPDKey, []byte(nil)), `key.Set should succeed`)
require.Error(t, key.Validate(), `key.Validate should fail`)
}

{
key, err := jwxtest.GenerateSymmetricJwk()
require.NoError(t, err, `jwx.GenerateSymmetricJwk should succeed`)
require.NoError(t, key.Validate(), `key.Validate should succeed`)

require.NoError(t, key.Set(jwk.SymmetricOctetsKey, []byte(nil)), `key.Set should succeed`)
require.Error(t, key.Validate(), `key.Validate should fail`)
}
}
38 changes: 38 additions & 0 deletions jwk/okp.go
Original file line number Diff line number Diff line change
Expand Up @@ -187,3 +187,41 @@ func (k okpPrivateKey) Thumbprint(hash crypto.Hash) ([]byte, error) {
base64.EncodeToString(k.x),
), nil
}

func validateOKPKey(key interface {
Crv() jwa.EllipticCurveAlgorithm
X() []byte
}) error {
if key.Crv() == jwa.InvalidEllipticCurve {
return fmt.Errorf(`invalid curve algorithm`)
}

if len(key.X()) == 0 {
return fmt.Errorf(`missing "x" field`)
}

if priv, ok := key.(interface{ D() []byte }); ok {
if len(priv.D()) == 0 {
return fmt.Errorf(`missing "d" field`)
}
}
return nil
}

func (k *okpPublicKey) Validate() error {
k.mu.RLock()
defer k.mu.RUnlock()
if err := validateOKPKey(k); err != nil {
return NewKeyValidationError(fmt.Errorf(`jwk.OKPPublicKey: %w`, err))
}
return nil
}

func (k *okpPrivateKey) Validate() error {
k.mu.RLock()
defer k.mu.RUnlock()
if err := validateOKPKey(k); err != nil {
return NewKeyValidationError(fmt.Errorf(`jwk.OKPPrivateKey: %w`, err))
}
return nil
}
40 changes: 40 additions & 0 deletions jwk/rsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -241,3 +241,43 @@ func rsaThumbprint(hash crypto.Hash, key *rsa.PublicKey) ([]byte, error) {
}
return h.Sum(nil), nil
}

func validateRSAKey(key interface {
N() []byte
E() []byte
}, checkPrivate bool) error {
if len(key.N()) == 0 {
// Ideally we would like to check for the actual length, but unlike
// EC keys, we have nothing in the key itself that will tell us
// how many bits this key should have.
return fmt.Errorf(`missing "n" value`)
}
if len(key.E()) == 0 {
return fmt.Errorf(`missing "e" value`)
}
if checkPrivate {
if priv, ok := key.(interface{ D() []byte }); ok {
if len(priv.D()) == 0 {
return fmt.Errorf(`missing "d" value`)
}
} else {
return fmt.Errorf(`missing "d" value`)
}
}

return nil
}

func (k *rsaPrivateKey) Validate() error {
if err := validateRSAKey(k, true); err != nil {
return NewKeyValidationError(fmt.Errorf(`jwk.RSAPrivateKey: %w`, err))
}
return nil
}

func (k *rsaPublicKey) Validate() error {
if err := validateRSAKey(k, false); err != nil {
return NewKeyValidationError(fmt.Errorf(`jwk.RSAPublicKey: %w`, err))
}
return nil
}
7 changes: 7 additions & 0 deletions jwk/symmetric.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,10 @@ func (k *symmetricKey) PublicKey() (Key, error) {
}
return newKey, nil
}

func (k *symmetricKey) Validate() error {
if len(k.Octets()) == 0 {
return NewKeyValidationError(fmt.Errorf(`jwk.SymmetricKey: missing "k" field`))
}
return nil
}
Loading

0 comments on commit 1e3b478

Please sign in to comment.