From c016a86d349633c127e70bb4a4c0553edf498ac3 Mon Sep 17 00:00:00 2001 From: Sergei Trofimov Date: Thu, 25 May 2023 11:43:31 +0100 Subject: [PATCH] Add COSE_Key support Add support for COSE_Key structure, and associated enums for key type, ECC curve, and key ops, as defined in RFC8152. Signed-off-by: Sergei Trofimov --- errors.go | 4 + go.mod | 7 +- go.sum | 9 +- key.go | 809 ++++++++++++++++++++++++++++++++++++++++++++++++++++ key_test.go | 657 ++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 1475 insertions(+), 11 deletions(-) create mode 100644 key.go create mode 100644 key_test.go diff --git a/errors.go b/errors.go index 8c240e2..7d16741 100644 --- a/errors.go +++ b/errors.go @@ -14,4 +14,8 @@ var ( ErrUnavailableHashFunc = errors.New("hash function is not available") ErrVerification = errors.New("verification error") ErrInvalidPubKey = errors.New("invalid public key") + ErrInvalidPrivKey = errors.New("invalid private key") + ErrNotPrivKey = errors.New("not a private key") + ErrSignOpNotSupported = errors.New("sign key_op not supported by key") + ErrVerifyOpNotSupported = errors.New("verify key_op not supported by key") ) diff --git a/go.mod b/go.mod index 82106a5..6cb8a1b 100644 --- a/go.mod +++ b/go.mod @@ -2,13 +2,14 @@ module github.com/veraison/go-cose go 1.18 -require github.com/fxamacker/cbor/v2 v2.4.0 +require ( + github.com/fxamacker/cbor/v2 v2.4.0 + github.com/stretchr/testify v1.8.3 +) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/stretchr/objx v0.5.0 // indirect - github.com/stretchr/testify v1.8.3 // indirect github.com/x448/float16 v0.8.4 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index d39fd90..c2215f6 100644 --- a/go.sum +++ b/go.sum @@ -1,21 +1,14 @@ -github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fxamacker/cbor/v2 v2.4.0 h1:ri0ArlOR+5XunOP8CRUowT0pSJOwhW098ZCUyskZD88= github.com/fxamacker/cbor/v2 v2.4.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/key.go b/key.go new file mode 100644 index 0000000..a884fef --- /dev/null +++ b/key.go @@ -0,0 +1,809 @@ +package cose + +import ( + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "errors" + "fmt" + "math/big" + "strconv" + + cbor "github.com/fxamacker/cbor/v2" +) + +const ( + // An inviald key_op value + KeyOpInvalid KeyOp = 0 + + // The key is used to create signatures. Requires private key fields. + KeyOpSign KeyOp = 1 + + // The key is used for verification of signatures. + KeyOpVerify KeyOp = 2 + + // The key is used for key transport encryption. + KeyOpEncrypt KeyOp = 3 + + // The key is used for key transport decryption. Requires private key fields. + KeyOpDecrypt KeyOp = 4 + + // The key is used for key wrap encryption. + KeyOpWrapKey KeyOp = 5 + + // The key is used for key wrap decryption. + KeyOpUnwrapKey KeyOp = 6 + + // The key is used for deriving keys. Requires private key fields. + KeyOpDeriveKey KeyOp = 7 + + // The key is used for deriving bits not to be used as a key. Requires + // private key fields. + KeyOpDeriveBits KeyOp = 8 + + // The key is used for creating MACs. + KeyOpMACCreate KeyOp = 9 + + // The key is used for validating MACs. + KeyOpMACVerify KeyOp = 10 +) + +// KeyOp represents a key_ops value used to restrict purposes for which a Key +// may be used. +type KeyOp int64 + +// KeyOpFromString returns the KeyOp corresponding to the specified name. +// The values are taken from https://www.rfc-editor.org/rfc/rfc7517#section-4.3 +func KeyOpFromString(val string) (KeyOp, error) { + switch val { + case "sign": + return KeyOpSign, nil + case "verify": + return KeyOpVerify, nil + case "encrypt": + return KeyOpEncrypt, nil + case "decrypt": + return KeyOpDecrypt, nil + case "wrapKey": + return KeyOpWrapKey, nil + case "unwrapKey": + return KeyOpUnwrapKey, nil + case "deriveKey": + return KeyOpDeriveKey, nil + case "deriveBits": + return KeyOpDeriveBits, nil + default: + return KeyOpInvalid, fmt.Errorf("unknown key_ops value %q", val) + } +} + +func (ko KeyOp) String() string { + switch ko { + case KeyOpSign: + return "sign" + case KeyOpVerify: + return "verify" + case KeyOpEncrypt: + return "encrypt" + case KeyOpDecrypt: + return "decrypt" + case KeyOpWrapKey: + return "wrapKey" + case KeyOpUnwrapKey: + return "unwrapKey" + case KeyOpDeriveKey: + return "deriveKey" + case KeyOpDeriveBits: + return "deriveBits" + case KeyOpMACCreate: + return "MAC create" + case KeyOpMACVerify: + return "MAC verify" + default: + return "unknown key_op value " + strconv.Itoa(int(ko)) + } +} + +func (ko KeyOp) IsSupported() bool { + return ko >= 1 && ko <= 10 +} + +// MarshalCBOR marshals the KeyOp as a CBOR int. +func (ko KeyOp) MarshalCBOR() ([]byte, error) { + return encMode.Marshal(int64(ko)) +} + +// UnmarshalCBOR populates the KeyOp from the provided CBOR value (must be int +// or tstr). +func (ko *KeyOp) UnmarshalCBOR(data []byte) error { + var raw intOrStr + + if err := raw.UnmarshalCBOR(data); err != nil { + return fmt.Errorf("invalid key_ops value %w", err) + } + + if raw.IsString() { + v, err := KeyOpFromString(raw.String()) + if err != nil { + return err + } + + *ko = v + } else { + v := raw.Int() + *ko = KeyOp(v) + + if !ko.IsSupported() { + return fmt.Errorf("unknown key_ops value %d", v) + } + } + + return nil +} + +// KeyType identifies the family of keys represented by the associated Key. +// This determines which files within the Key must be set in order for it to be +// valid. +type KeyType int64 + +const ( + // Invlaid key type + KeyTypeInvalid KeyType = 0 + // Octet Key Pair + KeyTypeOkp KeyType = 1 + // Elliptic Curve Keys w/ x- and y-coordinate pair + KeyTypeEc2 KeyType = 2 + // Symmetric Keys + KeyTypeSymmetric KeyType = 4 +) + +// KeyTypeFromString returns the KeyType corresponding to the specified name. +func KeyTypeFromString(v string) (KeyType, error) { + switch v { + case "OKP": + return KeyTypeOkp, nil + case "EC2": + return KeyTypeEc2, nil + case "Symmetric": + return KeyTypeSymmetric, nil + default: + return KeyTypeInvalid, fmt.Errorf("unknown key type value %q", v) + } +} + +func (kt KeyType) String() string { + switch kt { + case KeyTypeOkp: + return "OKP" + case KeyTypeEc2: + return "EC2" + case KeyTypeSymmetric: + return "Symmetric" + default: + return "unknown key type value " + strconv.Itoa(int(kt)) + } +} + +// MarshalCBOR marshals the KeyType as a CBOR int. +func (kt KeyType) MarshalCBOR() ([]byte, error) { + return encMode.Marshal(int(kt)) +} + +// UnmarshalCBOR populates the KeyType from the provided CBOR value (must be +// int or tstr). +func (kt *KeyType) UnmarshalCBOR(data []byte) error { + var raw intOrStr + + if err := raw.UnmarshalCBOR(data); err != nil { + return fmt.Errorf("invalid key type value: %w", err) + } + + if raw.IsString() { + v, err := KeyTypeFromString(raw.String()) + + if err != nil { + return err + } + + *kt = v + } else { + v := raw.Int() + + if v == 0 { + // 0 is reserved, and so can never be valid + return fmt.Errorf("invalid key type value 0") + } + + if v > 4 || v < 0 || v == 3 { + return fmt.Errorf("unknown key type value %d", v) + } + + *kt = KeyType(v) + } + + return nil +} + +const ( + + // Invalid/unrecognised curve + CurveInvalid Curve = 0 + + // NIST P-256 also known as secp256r1 + CurveP256 Curve = 1 + + // NIST P-384 also known as secp384r1 + CurveP384 Curve = 2 + + // NIST P-521 also known as secp521r1 + CurveP521 Curve = 3 + + // X25519 for use w/ ECDH only + CurveX25519 Curve = 4 + + // X448 for use w/ ECDH only + CurveX448 Curve = 5 + + // Ed25519 for use /w EdDSA only + CurveEd25519 Curve = 6 + + // Ed448 for use /w EdDSA only + CurveEd448 Curve = 7 +) + +// Curve represents the EC2/OKP key's curve. See: +// https://datatracker.ietf.org/doc/html/rfc8152#section-13.1 +type Curve int64 + +func CurveFromString(v string) (Curve, error) { + switch v { + case "P-256": + return CurveP256, nil + case "P-384": + return CurveP384, nil + case "P-521": + return CurveP521, nil + case "X25519": + return CurveX25519, nil + case "X448": + return CurveX448, nil + case "Ed25519": + return CurveEd25519, nil + case "Ed448": + return CurveEd448, nil + default: + return CurveInvalid, fmt.Errorf("unknown curve value %q", v) + } +} + +func (c Curve) String() string { + switch c { + case CurveP256: + return "P-256" + case CurveP384: + return "P-384" + case CurveP521: + return "P-521" + case CurveX25519: + return "X25519" + case CurveX448: + return "X448" + case CurveEd25519: + return "Ed25519" + case CurveEd448: + return "Ed448" + default: + return "unknown curve value " + strconv.Itoa(int(c)) + } +} + +// MarshalCBOR marshals the KeyType as a CBOR int. +func (c Curve) MarshalCBOR() ([]byte, error) { + return encMode.Marshal(int(c)) +} + +// UnmarshalCBOR populates the KeyType from the provided CBOR value (must be +// int or tstr). +func (c *Curve) UnmarshalCBOR(data []byte) error { + var raw intOrStr + + if err := raw.UnmarshalCBOR(data); err != nil { + return fmt.Errorf("invalid curve value: %w", err) + } + + if raw.IsString() { + v, err := CurveFromString(raw.String()) + + if err != nil { + return err + } + + *c = v + } else { + v := raw.Int() + + if v < 1 || v > 7 { + return fmt.Errorf("unknown curve value %d", v) + } + + *c = Curve(v) + } + + return nil +} + +// Key represents a COSE_Key structure, as defined by RFC8152. +// Note: currently, this does NOT support RFC8230 (RSA algorithms). +type Key struct { + // Common parameters. These are independent of the key type. Only + // KeyType common parameter MUST be set. + + // KeyType identifies the family of keys for this structure, and thus, + // which of the key-type-specific parameters need to be set. + KeyType KeyType `cbor:"1,keyasint"` + // KeyID is the identification value matched to the kid in the message. + KeyID []byte `cbor:"2,keyasint,omitempty"` + // KeyOps can be set to restrict the set of operations that the Key is used for. + KeyOps []KeyOp `cbor:"4,keyasint,omitempty"` + // BaseIV is the Base IV to be xor-ed with Partial IVs. + BaseIV []byte `cbor:"5,keyasint,omitempty"` + + // Algorithm is used to restrict the algorithm that is used with the + // key. If it is set, the application MUST verify that it matches the + // algorithm for which the Key is being used. + Algorithm Algorithm `cbor:"-"` + // Curve is EC identifier -- taken form "COSE Elliptic Curves" IANA registry. + // Populated from keyStruct.RawKeyParam when key type is EC2 or OKP. + Curve Curve `cbor:"-"` + // K is the key value. Populated from keyStruct.RawKeyParam when key + // type is Symmetric. + K []byte `cbor:"-"` + + // EC2/OKP params + + // X is the x-coordinate + X []byte `cbor:"-2,keyasint,omitempty"` + // Y is the y-coordinate (sign bits are not supported) + Y []byte `cbor:"-3,keyasint,omitempty"` + // D is the private key + D []byte `cbor:"-4,keyasint,omitempty"` +} + +// NewOkpKey returns a Key created using the provided Octet Key Pair data. +func NewOkpKey(alg Algorithm, x, d []byte) (*Key, error) { + if alg != AlgorithmEd25519 { + return nil, fmt.Errorf("unsupported algorithm %q", alg) + } + + key := &Key{ + KeyType: KeyTypeOkp, + Algorithm: alg, + Curve: CurveEd25519, + X: x, + D: d, + } + return key, key.Validate() +} + +// NewEc2Key returns a Key created using the provided elliptic curve key +// data. +func NewEc2Key(alg Algorithm, x, y, d []byte) (*Key, error) { + var curve Curve + + switch alg { + case AlgorithmES256: + curve = CurveP256 + case AlgorithmES384: + curve = CurveP384 + case AlgorithmES512: + curve = CurveP521 + default: + return nil, fmt.Errorf("unsupported algorithm %q", alg) + } + + key := &Key{ + KeyType: KeyTypeEc2, + Algorithm: alg, + Curve: curve, + X: x, + Y: y, + D: d, + } + return key, key.Validate() +} + +// NewSymmetricKey returns a Key created using the provided Symmetric key +// bytes. +func NewSymmetricKey(k []byte) (*Key, error) { + key := &Key{ + KeyType: KeyTypeSymmetric, + K: k, + } + return key, key.Validate() +} + +// NewKeyFromPublic returns a Key created using the provided crypto.PublicKey +// and Algorithm. +func NewKeyFromPublic(alg Algorithm, pub crypto.PublicKey) (*Key, error) { + switch alg { + case AlgorithmES256, AlgorithmES384, AlgorithmES512: + vk, ok := pub.(*ecdsa.PublicKey) + if !ok { + return nil, fmt.Errorf("%v: %w", alg, ErrInvalidPubKey) + } + + return NewEc2Key(alg, vk.X.Bytes(), vk.Y.Bytes(), nil) + case AlgorithmEd25519: + vk, ok := pub.(ed25519.PublicKey) + if !ok { + return nil, fmt.Errorf("%v: %w", alg, ErrInvalidPubKey) + } + + return NewOkpKey(alg, []byte(vk), nil) + default: + return nil, ErrAlgorithmNotSupported + } +} + +// NewKeyFromPrivate returns a Key created using provided crypto.PrivateKey +// and Algorithm. +func NewKeyFromPrivate(alg Algorithm, priv crypto.PrivateKey) (*Key, error) { + switch alg { + case AlgorithmES256, AlgorithmES384, AlgorithmES512: + sk, ok := priv.(*ecdsa.PrivateKey) + if !ok { + return nil, fmt.Errorf("%v: %w", alg, ErrInvalidPrivKey) + } + + return NewEc2Key(alg, sk.X.Bytes(), sk.Y.Bytes(), sk.D.Bytes()) + case AlgorithmEd25519: + sk, ok := priv.(ed25519.PrivateKey) + if !ok { + return nil, fmt.Errorf("%v: %w", alg, ErrInvalidPrivKey) + } + return NewOkpKey(alg, []byte(sk[32:]), []byte(sk[:32])) + default: + return nil, ErrAlgorithmNotSupported + } +} + +// Validate ensures that the parameters set inside the Key are internally +// consistent (e.g., that the key type is appropriate to the curve.) +func (k Key) Validate() error { + switch k.KeyType { + case KeyTypeEc2: + switch k.Curve { + case CurveP256, CurveP384, CurveP521: + // ok + default: + return fmt.Errorf( + "EC2 curve must be P-256, P-384, or P-521; found %q", + k.Curve.String(), + ) + } + case KeyTypeOkp: + switch k.Curve { + case CurveX25519, CurveX448, CurveEd25519, CurveEd448: + // ok + default: + return fmt.Errorf( + "OKP curve must be X25519, X448, Ed25519, or Ed448; found %q", + k.Curve.String(), + ) + } + case KeyTypeSymmetric: + default: + return errors.New(k.KeyType.String()) + } + + // If Algorithm is set, it must match the specified key parameters. + if k.Algorithm != AlgorithmInvalid { + expectedAlg, err := k.deriveAlgorithm() + if err != nil { + return err + } + + if k.Algorithm != expectedAlg { + return fmt.Errorf( + "found algorithm %q (expected %q)", + k.Algorithm.String(), + expectedAlg.String(), + ) + } + } + + return nil +} + +type keyalias Key + +type mashaledKey struct { + keyalias + + // RawAlgorithm contains the raw Algorithm value, this is necessary + // because cbor library ignores omitempty on types that implement the + // cbor.Marshaler interface. + RawAlgorithm cbor.RawMessage `cbor:"3,keyasint,omitempty"` + + // RawKeyParam contains the raw CBOR encoded data for the label -1. + // Depending on the KeyType this is used to populate either Curve or K + // below. + RawKeyParam cbor.RawMessage `cbor:"-1,keyasint,omitempty"` +} + +// MarshalCBOR encodes Key into a COSE_Key object. +func (k *Key) MarshalCBOR() ([]byte, error) { + tmp := mashaledKey{ + keyalias: keyalias(*k), + } + var err error + + if k.KeyType == KeyTypeSymmetric { + if tmp.RawKeyParam, err = encMode.Marshal(k.K); err != nil { + return nil, err + } + } else if k.KeyType == KeyTypeEc2 || k.KeyType == KeyTypeOkp { + if tmp.RawKeyParam, err = encMode.Marshal(k.Curve); err != nil { + return nil, err + } + } else { + return nil, fmt.Errorf("invalid key type: %q", k.KeyType.String()) + } + + if k.Algorithm != AlgorithmInvalid { + if tmp.RawAlgorithm, err = encMode.Marshal(k.Algorithm); err != nil { + return nil, err + } + } + + return encMode.Marshal(tmp) +} + +// UnmarshalCBOR decodes a COSE_Key object into Key. +func (k *Key) UnmarshalCBOR(data []byte) error { + var tmp mashaledKey + + if err := decMode.Unmarshal(data, &tmp); err != nil { + return err + } + *k = Key(tmp.keyalias) + + if tmp.RawAlgorithm != nil { + if err := decMode.Unmarshal(tmp.RawAlgorithm, &k.Algorithm); err != nil { + return err + } + } + + switch k.KeyType { + case KeyTypeEc2: + if tmp.RawKeyParam == nil { + return errors.New("missing Curve parameter (required for EC2 key type)") + } + + if err := decMode.Unmarshal(tmp.RawKeyParam, &k.Curve); err != nil { + return err + } + case KeyTypeOkp: + if tmp.RawKeyParam == nil { + return errors.New("missing Curve parameter (required for OKP key type)") + } + + if err := decMode.Unmarshal(tmp.RawKeyParam, &k.Curve); err != nil { + return err + } + case KeyTypeSymmetric: + if tmp.RawKeyParam == nil { + return errors.New("missing K parameter (required for Symmetric key type)") + } + + if err := decMode.Unmarshal(tmp.RawKeyParam, &k.K); err != nil { + return err + } + default: + // this should not be reachable as KeyType.UnmarshalCBOR would + // result in an error during decMode.Unmarshal() above, if the + // value in the data doesn't correspond to one of the above + // types. + return fmt.Errorf("unexpected key type %q", k.KeyType.String()) + } + + return k.Validate() +} + +// PublicKey returns a crypto.PublicKey generated using Key's parameters. +func (k *Key) PublicKey() (crypto.PublicKey, error) { + alg, err := k.deriveAlgorithm() + if err != nil { + return nil, err + } + + switch alg { + case AlgorithmES256, AlgorithmES384, AlgorithmES512: + var curve elliptic.Curve + + switch alg { + case AlgorithmES256: + curve = elliptic.P256() + case AlgorithmES384: + curve = elliptic.P384() + case AlgorithmES512: + curve = elliptic.P521() + } + + pub := &ecdsa.PublicKey{Curve: curve, X: new(big.Int), Y: new(big.Int)} + pub.X.SetBytes(k.X) + pub.Y.SetBytes(k.Y) + + return pub, nil + case AlgorithmEd25519: + return ed25519.PublicKey(k.X), nil + default: + return nil, ErrAlgorithmNotSupported + } +} + +// PrivateKey returns a crypto.PrivateKey generated using Key's parameters. +func (k *Key) PrivateKey() (crypto.PrivateKey, error) { + alg, err := k.deriveAlgorithm() + if err != nil { + return nil, err + } + + if len(k.D) == 0 { + return nil, ErrNotPrivKey + } + + switch alg { + case AlgorithmES256, AlgorithmES384, AlgorithmES512: + var curve elliptic.Curve + + switch alg { + case AlgorithmES256: + curve = elliptic.P256() + case AlgorithmES384: + curve = elliptic.P384() + case AlgorithmES512: + curve = elliptic.P521() + } + + priv := &ecdsa.PrivateKey{ + PublicKey: ecdsa.PublicKey{Curve: curve, X: new(big.Int), Y: new(big.Int)}, + D: new(big.Int), + } + priv.X.SetBytes(k.X) + priv.Y.SetBytes(k.Y) + priv.D.SetBytes(k.D) + + return priv, nil + case AlgorithmEd25519: + buf := make([]byte, ed25519.PrivateKeySize) + + copy(buf, k.D) + copy(buf[32:], k.X) + + return ed25519.PrivateKey(buf), nil + default: + return nil, ErrAlgorithmNotSupported + } +} + +// GetAlgorithm returns the Algorithm associated with Key. If Key.Algorithm is +// set, that is what is returned. Otherwise, the algorithm is inferred using +// Key.Curve. This method does NOT validate that Key.Algorithm, if set, aligns +// with Key.Curve. +func (k *Key) GetAlgorithm() (Algorithm, error) { + if k.Algorithm != AlgorithmInvalid { + return k.Algorithm, nil + } + + return k.deriveAlgorithm() +} + +// Signer returns a Signer created using Key. +func (k *Key) Signer() (Signer, error) { + if err := k.Validate(); err != nil { + return nil, err + } + + if k.KeyOps != nil { + signFound := false + + for _, kop := range k.KeyOps { + if kop == KeyOpSign { + signFound = true + break + } + } + + if !signFound { + return nil, ErrSignOpNotSupported + } + } + + priv, err := k.PrivateKey() + if err != nil { + return nil, err + } + + alg, err := k.GetAlgorithm() + if err != nil { + return nil, err + } + + var signer crypto.Signer + var ok bool + + switch alg { + case AlgorithmES256, AlgorithmES384, AlgorithmES512: + signer, ok = priv.(*ecdsa.PrivateKey) + if !ok { + return nil, ErrInvalidPrivKey + } + case AlgorithmEd25519: + signer, ok = priv.(ed25519.PrivateKey) + if !ok { + return nil, ErrInvalidPrivKey + } + default: + return nil, ErrAlgorithmNotSupported + } + + return NewSigner(alg, signer) +} + +// Verifier returns a Verifier created using Key. +func (k *Key) Verifier() (Verifier, error) { + if err := k.Validate(); err != nil { + return nil, err + } + + if k.KeyOps != nil { + verifyFound := false + + for _, kop := range k.KeyOps { + if kop == KeyOpVerify { + verifyFound = true + break + } + } + + if !verifyFound { + return nil, ErrVerifyOpNotSupported + } + } + + pub, err := k.PublicKey() + if err != nil { + return nil, err + } + + alg, err := k.GetAlgorithm() + if err != nil { + return nil, err + } + + return NewVerifier(alg, pub) +} + +// deriveAlgorithm derives the intended algorithm for the key from its curve. +func (k *Key) deriveAlgorithm() (Algorithm, error) { + switch k.KeyType { + case KeyTypeEc2, KeyTypeOkp: + switch k.Curve { + case CurveP256: + return AlgorithmES256, nil + case CurveP384: + return AlgorithmES384, nil + case CurveP521: + return AlgorithmES512, nil + case CurveEd25519: + return AlgorithmEd25519, nil + default: + return AlgorithmInvalid, fmt.Errorf("unsupported curve %q", k.Curve.String()) + } + default: + // Symmetric algorithms are not supported in the current inmplementation. + return AlgorithmInvalid, fmt.Errorf("unexpected key type %q", k.KeyType.String()) + } +} diff --git a/key_test.go b/key_test.go new file mode 100644 index 0000000..8d6f129 --- /dev/null +++ b/key_test.go @@ -0,0 +1,657 @@ +package cose + +import ( + "crypto" + "crypto/ecdsa" + "crypto/ed25519" + "crypto/elliptic" + "crypto/rand" + "testing" + + "github.com/fxamacker/cbor/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_KeyOp(t *testing.T) { + + tvs := []struct { + Name string + Value KeyOp + }{ + {"sign", KeyOpSign}, + {"verify", KeyOpVerify}, + {"encrypt", KeyOpEncrypt}, + {"decrypt", KeyOpDecrypt}, + {"wrapKey", KeyOpWrapKey}, + {"unwrapKey", KeyOpUnwrapKey}, + {"deriveKey", KeyOpDeriveKey}, + {"deriveBits", KeyOpDeriveBits}, + } + + for _, tv := range tvs { + assert.Equal(t, tv.Name, tv.Value.String()) + + data, err := cbor.Marshal(tv.Name) + require.NoError(t, err) + + var ko KeyOp + err = cbor.Unmarshal(data, &ko) + require.NoError(t, err) + assert.Equal(t, tv.Value, ko) + + data, err = cbor.Marshal(int(tv.Value)) + require.NoError(t, err) + + err = cbor.Unmarshal(data, &ko) + require.NoError(t, err) + assert.Equal(t, tv.Value, ko) + + } + + var ko KeyOp + + data := []byte{0x20} + err := ko.UnmarshalCBOR(data) + assert.EqualError(t, err, "unknown key_ops value -1") + + data = []byte{0x18, 0xff} + err = ko.UnmarshalCBOR(data) + assert.EqualError(t, err, "unknown key_ops value 255") + + data = []byte{0x63, 0x66, 0x6f, 0x6f} + err = ko.UnmarshalCBOR(data) + assert.EqualError(t, err, "unknown key_ops value \"foo\"") + + data = []byte{0x40} + err = ko.UnmarshalCBOR(data) + assert.EqualError(t, err, "invalid key_ops value must be int or string, found []uint8") + + assert.Equal(t, "MAC create", KeyOpMACCreate.String()) + assert.Equal(t, "MAC verify", KeyOpMACVerify.String()) + assert.Equal(t, "unknown key_op value 42", KeyOp(42).String()) +} + +func Test_KeyType(t *testing.T) { + + tvs := []struct { + Name string + Value KeyType + }{ + {"OKP", KeyTypeOkp}, + {"EC2", KeyTypeEc2}, + {"Symmetric", KeyTypeSymmetric}, + } + + for _, tv := range tvs { + assert.Equal(t, tv.Name, tv.Value.String()) + + data, err := cbor.Marshal(tv.Name) + require.NoError(t, err) + + var ko KeyType + err = cbor.Unmarshal(data, &ko) + require.NoError(t, err) + assert.Equal(t, tv.Value, ko) + + data, err = cbor.Marshal(int(tv.Value)) + require.NoError(t, err) + + err = cbor.Unmarshal(data, &ko) + require.NoError(t, err) + assert.Equal(t, tv.Value, ko) + + } + + var ko KeyType + + data := []byte{0x20} + err := ko.UnmarshalCBOR(data) + assert.EqualError(t, err, "unknown key type value -1") + + data = []byte{0x00} + err = ko.UnmarshalCBOR(data) + assert.EqualError(t, err, "invalid key type value 0") + + data = []byte{0x03} + err = ko.UnmarshalCBOR(data) + assert.EqualError(t, err, "unknown key type value 3") + + data = []byte{0x63, 0x66, 0x6f, 0x6f} + err = ko.UnmarshalCBOR(data) + assert.EqualError(t, err, "unknown key type value \"foo\"") + + data = []byte{0x40} + err = ko.UnmarshalCBOR(data) + assert.EqualError(t, err, "invalid key type value: must be int or string, found []uint8") +} + +func Test_Curve(t *testing.T) { + + tvs := []struct { + Name string + Value Curve + }{ + {"P-256", CurveP256}, + {"P-384", CurveP384}, + {"P-521", CurveP521}, + {"X25519", CurveX25519}, + {"X448", CurveX448}, + {"Ed25519", CurveEd25519}, + {"Ed448", CurveEd448}, + } + + for _, tv := range tvs { + assert.Equal(t, tv.Name, tv.Value.String()) + + data, err := cbor.Marshal(tv.Name) + require.NoError(t, err) + + var c Curve + err = cbor.Unmarshal(data, &c) + require.NoError(t, err) + assert.Equal(t, tv.Value, c) + + data, err = cbor.Marshal(int(tv.Value)) + require.NoError(t, err) + + err = cbor.Unmarshal(data, &c) + require.NoError(t, err) + assert.Equal(t, tv.Value, c) + + } + + var c Curve + + data := []byte{0x20} + err := c.UnmarshalCBOR(data) + assert.EqualError(t, err, "unknown curve value -1") + + data = []byte{0x00} + err = c.UnmarshalCBOR(data) + assert.EqualError(t, err, "unknown curve value 0") + + data = []byte{0x63, 0x66, 0x6f, 0x6f} + err = c.UnmarshalCBOR(data) + assert.EqualError(t, err, "unknown curve value \"foo\"") + + data = []byte{0x40} + err = c.UnmarshalCBOR(data) + assert.EqualError(t, err, "invalid curve value: must be int or string, found []uint8") + + assert.Equal(t, "unknown curve value 42", Curve(42).String()) +} + +func Test_Key_UnmarshalCBOR(t *testing.T) { + tvs := []struct { + Name string + Value []byte + WantErr string + Validate func(k *Key) + }{ + { + Name: "ok OKP", + Value: []byte{ + 0xa5, // map (5) + 0x01, 0x01, // kty: OKP + 0x03, 0x27, // alg: EdDSA w/ Ed25519 + 0x04, // key ops + 0x81, // array (1) + 0x02, // verify + 0x20, 0x06, // curve: Ed25519 + 0x21, 0x58, 0x20, // x-coordinate: bytes(32) + 0x15, 0x52, 0x2e, 0xf1, 0x57, 0x29, 0xcc, 0xf3, // 32-byte value + 0x95, 0x09, 0xea, 0x5c, 0x15, 0xa2, 0x6b, 0xe9, + 0x49, 0xe3, 0x88, 0x07, 0xa5, 0xc2, 0x6e, 0xf9, + 0x28, 0x14, 0x87, 0xef, 0x4a, 0xe6, 0x7b, 0x46, + }, + WantErr: "", + Validate: func(k *Key) { + assert.Equal(t, KeyTypeOkp, k.KeyType) + assert.Equal(t, AlgorithmEd25519, k.Algorithm) + assert.Equal(t, CurveEd25519, k.Curve) + assert.Equal(t, []KeyOp{KeyOpVerify}, k.KeyOps) + assert.Equal(t, []byte{ + 0x15, 0x52, 0x2e, 0xf1, 0x57, 0x29, 0xcc, 0xf3, + 0x95, 0x09, 0xea, 0x5c, 0x15, 0xa2, 0x6b, 0xe9, + 0x49, 0xe3, 0x88, 0x07, 0xa5, 0xc2, 0x6e, 0xf9, + 0x28, 0x14, 0x87, 0xef, 0x4a, 0xe6, 0x7b, 0x46, + }, + k.X, + ) + assert.Equal(t, []byte(nil), k.K) + }, + }, + { + Name: "invalid key type", + Value: []byte{ + 0xa1, // map (2) + 0x01, 0x00, // kty: invalid + }, + WantErr: "invalid key type value 0", + Validate: nil, + }, + { + Name: "missing curve OKP", + Value: []byte{ + 0xa1, // map (2) + 0x01, 0x01, // kty: OKP + }, + WantErr: "missing Curve parameter (required for OKP key type)", + Validate: nil, + }, + { + Name: "missing curve EC2", + Value: []byte{ + 0xa1, // map (2) + 0x01, 0x02, // kty: EC2 + }, + WantErr: "missing Curve parameter (required for EC2 key type)", + Validate: nil, + }, + { + Name: "invalid curve OKP", + Value: []byte{ + 0xa2, // map (2) + 0x01, 0x01, // kty: OKP + 0x20, 0x01, // curve: CurveP256 + }, + WantErr: "OKP curve must be X25519, X448, Ed25519, or Ed448; found \"P-256\"", + Validate: nil, + }, + { + Name: "invalid curve EC2", + Value: []byte{ + 0xa2, // map (2) + 0x01, 0x02, // kty: EC2 + 0x20, 0x06, // curve: CurveEd25519 + }, + WantErr: "EC2 curve must be P-256, P-384, or P-521; found \"Ed25519\"", + Validate: nil, + }, + { + Name: "ok Symmetric", + Value: []byte{ + 0xa4, // map (4) + 0x01, 0x04, // kty: Symmetric + 0x03, 0x38, 0x24, // alg: PS256 + 0x04, // key ops + 0x81, // array (1) + 0x02, // verify + 0x20, 0x58, 0x20, // k: bytes(32) + 0x15, 0x52, 0x2e, 0xf1, 0x57, 0x29, 0xcc, 0xf3, // 32-byte value + 0x95, 0x09, 0xea, 0x5c, 0x15, 0xa2, 0x6b, 0xe9, + 0x49, 0xe3, 0x88, 0x07, 0xa5, 0xc2, 0x6e, 0xf9, + 0x28, 0x14, 0x87, 0xef, 0x4a, 0xe6, 0x7b, 0x46, + }, + WantErr: "", + Validate: func(k *Key) { + assert.Equal(t, KeyTypeSymmetric, k.KeyType) + assert.Equal(t, AlgorithmPS256, k.Algorithm) + assert.EqualValues(t, 0, k.Curve) + assert.Equal(t, []KeyOp{KeyOpVerify}, k.KeyOps) + assert.Equal(t, []byte{ + 0x15, 0x52, 0x2e, 0xf1, 0x57, 0x29, 0xcc, 0xf3, + 0x95, 0x09, 0xea, 0x5c, 0x15, 0xa2, 0x6b, 0xe9, + 0x49, 0xe3, 0x88, 0x07, 0xa5, 0xc2, 0x6e, 0xf9, + 0x28, 0x14, 0x87, 0xef, 0x4a, 0xe6, 0x7b, 0x46, + }, + k.K, + ) + }, + }, + { + Name: "missing K", + Value: []byte{ + 0xa1, // map (1) + 0x01, 0x04, // kty: Symmetric + }, + WantErr: "missing K parameter (required for Symmetric key type)", + Validate: nil, + }, + { + Name: "wrong algorithm", + Value: []byte{ + 0xa4, // map (3) + 0x01, 0x01, // kty: OKP + 0x03, 0x26, // alg: ECDSA w/ SHA-256 + 0x20, 0x06, // curve: Ed25519 + 0x21, 0x58, 0x20, // x-coordinate: bytes(32) + 0x15, 0x52, 0x2e, 0xf1, 0x57, 0x29, 0xcc, 0xf3, // 32-byte value + 0x95, 0x09, 0xea, 0x5c, 0x15, 0xa2, 0x6b, 0xe9, + 0x49, 0xe3, 0x88, 0x07, 0xa5, 0xc2, 0x6e, 0xf9, + 0x28, 0x14, 0x87, 0xef, 0x4a, 0xe6, 0x7b, 0x46, + }, + WantErr: "found algorithm \"ES256\" (expected \"EdDSA\")", + Validate: nil, + }, + } + + for _, tv := range tvs { + t.Run(tv.Name, func(t *testing.T) { + var k Key + + err := k.UnmarshalCBOR(tv.Value) + if tv.WantErr != "" { + assert.EqualError(t, err, tv.WantErr) + } else { + tv.Validate(&k) + } + }) + } +} + +func Test_Key_MarshalCBOR(t *testing.T) { + k := Key{ + KeyType: KeyTypeOkp, + KeyOps: []KeyOp{KeyOpVerify, KeyOpEncrypt}, + X: []byte{ + 0x15, 0x52, 0x2e, 0xf1, 0x57, 0x29, 0xcc, 0xf3, + 0x95, 0x09, 0xea, 0x5c, 0x15, 0xa2, 0x6b, 0xe9, + 0x49, 0xe3, 0x88, 0x07, 0xa5, 0xc2, 0x6e, 0xf9, + 0x28, 0x14, 0x87, 0xef, 0x4a, 0xe6, 0x7b, 0x46, + }, + Algorithm: AlgorithmEd25519, + Curve: CurveEd25519, + } + + data, err := k.MarshalCBOR() + require.NoError(t, err) + assert.Equal(t, + []byte{ + 0xa5, // map (5) + 0x01, 0x01, // kty: OKP + 0x03, 0x27, // alg: EdDSA w/ Ed25519 + 0x04, // key ops + 0x82, // array (2) + 0x02, 0x03, // verify, encrypt + 0x20, 0x06, // curve: Ed25519 + 0x21, 0x58, 0x20, // x-coordinate: bytes(32) + 0x15, 0x52, 0x2e, 0xf1, 0x57, 0x29, 0xcc, 0xf3, // 32-byte value + 0x95, 0x09, 0xea, 0x5c, 0x15, 0xa2, 0x6b, 0xe9, + 0x49, 0xe3, 0x88, 0x07, 0xa5, 0xc2, 0x6e, 0xf9, + 0x28, 0x14, 0x87, 0xef, 0x4a, 0xe6, 0x7b, 0x46, + }, + data, + ) + + k = Key{ + KeyType: KeyTypeSymmetric, + K: []byte{ + 0x15, 0x52, 0x2e, 0xf1, 0x57, 0x29, 0xcc, 0xf3, + 0x95, 0x09, 0xea, 0x5c, 0x15, 0xa2, 0x6b, 0xe9, + 0x49, 0xe3, 0x88, 0x07, 0xa5, 0xc2, 0x6e, 0xf9, + 0x28, 0x14, 0x87, 0xef, 0x4a, 0xe6, 0x7b, 0x46, + }, + } + + data, err = k.MarshalCBOR() + require.NoError(t, err) + assert.Equal(t, + []byte{ + 0xa2, // map (2) + 0x01, 0x04, // kty: Symmetric + 0x20, 0x58, 0x20, // K: bytes(32) + 0x15, 0x52, 0x2e, 0xf1, 0x57, 0x29, 0xcc, 0xf3, // 32-byte value + 0x95, 0x09, 0xea, 0x5c, 0x15, 0xa2, 0x6b, 0xe9, + 0x49, 0xe3, 0x88, 0x07, 0xa5, 0xc2, 0x6e, 0xf9, + 0x28, 0x14, 0x87, 0xef, 0x4a, 0xe6, 0x7b, 0x46, + }, + data, + ) + + k.KeyType = KeyType(42) + _, err = k.MarshalCBOR() + assert.EqualError(t, err, "invalid key type: \"unknown key type value 42\"") +} + +func Test_Key_Create_and_Validate(t *testing.T) { + x := []byte{ + 0x30, 0xa0, 0x42, 0x4c, 0xd2, 0x1c, 0x29, 0x44, + 0x83, 0x8a, 0x2d, 0x75, 0xc9, 0x2b, 0x37, 0xe7, + 0x6e, 0xa2, 0x0d, 0x9f, 0x00, 0x89, 0x3a, 0x3b, + 0x4e, 0xee, 0x8a, 0x3c, 0x0a, 0xaf, 0xec, 0x3e, + } + + y := []byte{ + 0xe0, 0x4b, 0x65, 0xe9, 0x24, 0x56, 0xd9, 0x88, + 0x8b, 0x52, 0xb3, 0x79, 0xbd, 0xfb, 0xd5, 0x1e, + 0xe8, 0x69, 0xef, 0x1f, 0x0f, 0xc6, 0x5b, 0x66, + 0x59, 0x69, 0x5b, 0x6c, 0xce, 0x08, 0x17, 0x23, + } + + key, err := NewOkpKey(AlgorithmEd25519, x, nil) + require.NoError(t, err) + assert.Equal(t, KeyTypeOkp, key.KeyType) + assert.Equal(t, x, key.X) + + _, err = NewOkpKey(AlgorithmES256, x, nil) + assert.EqualError(t, err, "unsupported algorithm \"ES256\"") + + _, err = NewEc2Key(AlgorithmEd25519, x, y, nil) + assert.EqualError(t, err, "unsupported algorithm \"EdDSA\"") + + key, err = NewEc2Key(AlgorithmES256, x, y, nil) + require.NoError(t, err) + assert.Equal(t, KeyTypeEc2, key.KeyType) + assert.Equal(t, x, key.X) + assert.Equal(t, y, key.Y) + + key, err = NewSymmetricKey(x) + require.NoError(t, err) + assert.Equal(t, x, key.K) + + key.KeyType = KeyType(7) + err = key.Validate() + assert.EqualError(t, err, "unknown key type value 7") + + _, err = NewKeyFromPublic(AlgorithmES256, + crypto.PublicKey([]byte{0xde, 0xad, 0xbe, 0xef})) + assert.EqualError(t, err, "ES256: invalid public key") + + _, err = NewKeyFromPublic(AlgorithmEd25519, + crypto.PublicKey([]byte{0xde, 0xad, 0xbe, 0xef})) + assert.EqualError(t, err, "EdDSA: invalid public key") + + _, err = NewKeyFromPublic(AlgorithmInvalid, + crypto.PublicKey([]byte{0xde, 0xad, 0xbe, 0xef})) + assert.EqualError(t, err, "algorithm not supported") + + _, err = NewKeyFromPrivate(AlgorithmES256, + crypto.PublicKey([]byte{0xde, 0xad, 0xbe, 0xef})) + assert.EqualError(t, err, "ES256: invalid private key") + + _, err = NewKeyFromPrivate(AlgorithmEd25519, + crypto.PublicKey([]byte{0xde, 0xad, 0xbe, 0xef})) + assert.EqualError(t, err, "EdDSA: invalid private key") + + _, err = NewKeyFromPrivate(AlgorithmInvalid, + crypto.PublicKey([]byte{0xde, 0xad, 0xbe, 0xef})) + assert.EqualError(t, err, "algorithm not supported") +} + +func Test_Key_ed25519_signature_round_trip(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + key, err := NewKeyFromPrivate(AlgorithmEd25519, priv) + require.NoError(t, err) + assert.Equal(t, AlgorithmEd25519, key.Algorithm) + assert.Equal(t, CurveEd25519, key.Curve) + assert.EqualValues(t, pub, key.X) + assert.EqualValues(t, priv[:32], key.D) + + signer, err := key.Signer() + require.NoError(t, err) + + message := []byte("foo bar") + sig, err := signer.Sign(rand.Reader, message) + require.NoError(t, err) + + key, err = NewKeyFromPublic(AlgorithmEd25519, pub) + require.NoError(t, err) + + assert.Equal(t, AlgorithmEd25519, key.Algorithm) + assert.Equal(t, CurveEd25519, key.Curve) + assert.EqualValues(t, pub, key.X) + + verifier, err := key.Verifier() + require.NoError(t, err) + + err = verifier.Verify(message, sig) + assert.NoError(t, err) +} + +func Test_Key_ecdsa_signature_round_trip(t *testing.T) { + for _, tv := range []struct { + EC elliptic.Curve + Curve Curve + Algorithm Algorithm + }{ + {elliptic.P256(), CurveP256, AlgorithmES256}, + {elliptic.P384(), CurveP384, AlgorithmES384}, + {elliptic.P521(), CurveP521, AlgorithmES512}, + } { + t.Run(tv.Curve.String(), func(t *testing.T) { + priv, err := ecdsa.GenerateKey(tv.EC, rand.Reader) + require.NoError(t, err) + + key, err := NewKeyFromPrivate(tv.Algorithm, priv) + require.NoError(t, err) + assert.Equal(t, tv.Algorithm, key.Algorithm) + assert.Equal(t, tv.Curve, key.Curve) + assert.EqualValues(t, priv.X.Bytes(), key.X) + assert.EqualValues(t, priv.Y.Bytes(), key.Y) + assert.EqualValues(t, priv.D.Bytes(), key.D) + + signer, err := key.Signer() + require.NoError(t, err) + + message := []byte("foo bar") + sig, err := signer.Sign(rand.Reader, message) + require.NoError(t, err) + + pub := priv.Public() + + key, err = NewKeyFromPublic(tv.Algorithm, pub) + require.NoError(t, err) + + assert.Equal(t, tv.Algorithm, key.Algorithm) + assert.Equal(t, tv.Curve, key.Curve) + assert.EqualValues(t, priv.X.Bytes(), key.X) + assert.EqualValues(t, priv.Y.Bytes(), key.Y) + + verifier, err := key.Verifier() + require.NoError(t, err) + + err = verifier.Verify(message, sig) + assert.NoError(t, err) + }) + } +} + +func Test_Key_derive_algorithm(t *testing.T) { + k := Key{ + KeyType: KeyTypeOkp, + Curve: CurveX448, + } + + _, err := k.GetAlgorithm() + assert.EqualError(t, err, "unsupported curve \"X448\"") + + k = Key{ + KeyType: KeyTypeOkp, + Curve: CurveEd25519, + } + + alg, err := k.GetAlgorithm() + require.NoError(t, err) + assert.Equal(t, AlgorithmEd25519, alg) +} + +func Test_Key_signer_validation(t *testing.T) { + pub, priv, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + key, err := NewKeyFromPublic(AlgorithmEd25519, pub) + require.NoError(t, err) + + _, err = key.Signer() + require.Equal(t, err, ErrNotPrivKey) + + key, err = NewKeyFromPrivate(AlgorithmEd25519, priv) + require.NoError(t, err) + + key.KeyType = KeyTypeEc2 + _, err = key.Signer() + require.EqualError(t, err, "EC2 curve must be P-256, P-384, or P-521; found \"Ed25519\"") + + key.Curve = CurveP256 + _, err = key.Signer() + require.EqualError(t, err, "found algorithm \"EdDSA\" (expected \"ES256\")") + + key.KeyType = KeyTypeOkp + key.Algorithm = AlgorithmEd25519 + key.Curve = CurveEd25519 + key.KeyOps = []KeyOp{} + _, err = key.Signer() + require.Equal(t, err, ErrSignOpNotSupported) + + key.KeyOps = []KeyOp{KeyOpSign} + _, err = key.Signer() + require.NoError(t, err) + + key.Algorithm = AlgorithmES256 + _, err = key.Signer() + require.EqualError(t, err, "found algorithm \"ES256\" (expected \"EdDSA\")") + + key.Curve = CurveX448 + _, err = key.Signer() + assert.EqualError(t, err, "unsupported curve \"X448\"") + +} + +func Test_Key_verifier_validation(t *testing.T) { + pub, _, err := ed25519.GenerateKey(rand.Reader) + require.NoError(t, err) + + key, err := NewKeyFromPublic(AlgorithmEd25519, pub) + require.NoError(t, err) + + _, err = key.Verifier() + require.NoError(t, err) + + key.KeyType = KeyTypeEc2 + _, err = key.Verifier() + require.EqualError(t, err, "EC2 curve must be P-256, P-384, or P-521; found \"Ed25519\"") + + key.KeyType = KeyTypeOkp + key.KeyOps = []KeyOp{} + _, err = key.Verifier() + require.Equal(t, err, ErrVerifyOpNotSupported) + + key.KeyOps = []KeyOp{KeyOpVerify} + _, err = key.Verifier() + require.NoError(t, err) +} + +func Test_Key_crypto_keys(t *testing.T) { + k := Key{ + KeyType: KeyType(7), + } + + _, err := k.PublicKey() + assert.EqualError(t, err, "unexpected key type \"unknown key type value 7\"") + _, err = k.PrivateKey() + assert.EqualError(t, err, "unexpected key type \"unknown key type value 7\"") + + k = Key{ + KeyType: KeyTypeOkp, + Curve: CurveX448, + } + + _, err = k.PublicKey() + assert.EqualError(t, err, "unsupported curve \"X448\"") + _, err = k.PrivateKey() + assert.EqualError(t, err, "unsupported curve \"X448\"") +}