Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add gluecose conformance tests #19

Merged
merged 3 commits into from
Mar 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
283 changes: 283 additions & 0 deletions conformance_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,283 @@
package cose_test

import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
"math/big"
"os"
"path/filepath"
"testing"

"github.com/veraison/go-cose"
)

type TestCase struct {
UUID string `json:"uuid"`
Title string `json:"title"`
Description string `json:"description"`
Key Key `json:"key"`
Alg string `json:"alg"`
Sign1 *Sign1 `json:"sign1::sign"`
Verify1 *Verify1 `json:"sign1::verify"`
}

type Key map[string]string

type Sign1 struct {
Payload string `json:"payload"`
ProtectedHeaders *CBOR `json:"protectedHeaders"`
UnprotectedHeaders *CBOR `json:"unprotectedHeaders"`
External string `json:"external"`
Detached bool `json:"detached"`
TBS CBOR `json:"tbsHex"`
Output CBOR `json:"expectedOutput"`
OutputLength int `json:"fixedOutputLength"`
}

type Verify1 struct {
TaggedCOSESign1 CBOR `json:"taggedCOSESign1"`
External string `json:"external"`
Verify bool `json:"shouldVerify"`
}

type CBOR struct {
CBORHex string `json:"cborHex"`
CBORDiag string `json:"cborDiag"`
}

// Conformance samples are taken from
// https://github.com/gluecose/test-vectors.
var testCases = []string{
"sign1-sign-0000",
"sign1-sign-0001",
"sign1-sign-0002",
"sign1-sign-0003",
"sign1-verify-0000",
"sign1-verify-0001",
"sign1-verify-0002",
"sign1-verify-0003",
}

func TestConformance(t *testing.T) {
for _, name := range testCases {
t.Run(name, func(t *testing.T) {
data, err := os.ReadFile(filepath.Join("testdata", name+".json"))
if err != nil {
t.Fatal(err)
}
var tc TestCase
err = json.Unmarshal(data, &tc)
if err != nil {
t.Fatal(err)
}
processTestCase(t, &tc)
})
}
}

func processTestCase(t *testing.T, tc *TestCase) {
if tc.Sign1 != nil {
testSign1(t, tc)
} else if tc.Verify1 != nil {
testVerify1(t, tc)
} else {
t.Fatal("test case not supported")
}
}

func testVerify1(t *testing.T, tc *TestCase) {
signer, err := getSigner(tc, false)
if err != nil {
t.Fatal(err)
}
var sigMsg cose.Sign1Message
err = sigMsg.UnmarshalCBOR(mustHexToBytes(tc.Verify1.TaggedCOSESign1.CBORHex))
if err != nil {
t.Fatal(err)
}
external := []byte("")
if tc.Verify1.External != "" {
external = mustHexToBytes(tc.Verify1.External)
}
err = sigMsg.Verify(external, *signer.Verifier())
if tc.Verify1.Verify && err != nil {
t.Fatal(err)
} else if !tc.Verify1.Verify && err == nil {
t.Fatal("Verify1 should have failed")
}
}

func testSign1(t *testing.T, tc *TestCase) {
signer, err := getSigner(tc, true)
if err != nil {
t.Fatal(err)
}
sig := tc.Sign1
sigMsg := cose.NewSign1Message()
sigMsg.Payload = mustHexToBytes(sig.Payload)
sigMsg.Headers, err = decodeHeaders(mustHexToBytes(sig.ProtectedHeaders.CBORHex), mustHexToBytes(sig.UnprotectedHeaders.CBORHex))
if err != nil {
t.Fatal(err)
}
external := []byte("")
if sig.External != "" {
external = mustHexToBytes(sig.External)
}
err = sigMsg.Sign(new(zeroSource), external, *signer)
if err != nil {
t.Fatal(err)
}
err = sigMsg.Verify(external, *signer.Verifier())
if err != nil {
t.Fatal(err)
}
got, err := sigMsg.MarshalCBOR()
if err != nil {
t.Fatal(err)
}
want := mustHexToBytes(sig.Output.CBORHex)
if sig.OutputLength > 0 {
got = got[:sig.OutputLength]
want = want[:sig.OutputLength]
}
if !bytes.Equal(want, got) {
t.Fatalf("unexpected output:\nwant: %x\n got: %x", want, got)
}
}

func getSigner(tc *TestCase, private bool) (*cose.Signer, error) {
pkey, err := getKey(tc.Key, private)
if err != nil {
return nil, err
}
alg := mustNameToAlg(tc.Alg)
signer, err := cose.NewSignerFromKey(alg, pkey)
if err != nil {
return nil, err
}
return signer, nil
}

func getKey(key Key, private bool) (crypto.PrivateKey, error) {
switch key["kty"] {
case "EC":
var c elliptic.Curve
switch key["crv"] {
case "P-224":
c = elliptic.P224()
case "P-256":
c = elliptic.P256()
case "P-384":
c = elliptic.P384()
case "P-521":
c = elliptic.P521()
default:
return nil, errors.New("unsupported EC curve: " + key["crv"])
}
pkey := &ecdsa.PrivateKey{
PublicKey: ecdsa.PublicKey{
X: mustBase64ToBigInt(key["x"]),
Y: mustBase64ToBigInt(key["y"]),
Curve: c,
},
}
if private {
pkey.D = mustBase64ToBigInt(key["d"])
}
return pkey, nil
}
return nil, errors.New("unsupported key type: " + key["kty"])
}

// zeroSource is an io.Reader that returns an unlimited number of zero bytes.
type zeroSource struct{}

func (zeroSource) Read(b []byte) (n int, err error) {
for i := range b {
b[i] = 0
}

return len(b), nil
}

func decodeHeaders(protected, unprotected []byte) (*cose.Headers, error) {
var hdr cose.Headers
hdr.Protected = make(map[interface{}]interface{})
hdr.Unprotected = make(map[interface{}]interface{})
err := hdr.DecodeProtected(protected)
if err != nil {
return nil, err
}
b, err := cose.Unmarshal(unprotected)
if err != nil {
return nil, err
}
err = hdr.DecodeUnprotected(b)
if err != nil {
return nil, err
}
hdr.Protected = fixHeader(hdr.Protected)
hdr.Unprotected = fixHeader(hdr.Unprotected)
return &hdr, nil
}

func fixHeader(m map[interface{}]interface{}) map[interface{}]interface{} {
ret := make(map[interface{}]interface{})
for k, v := range m {
switch k1 := k.(type) {
case int64:
k = int(k1)
}
switch v1 := v.(type) {
case int64:
v = int(v1)
}
ret[k] = v
}
return ret
}

func mustHexToInt(s string) int {
return int(mustHexToBigInt(s).Int64())
}

func mustHexToBytes(s string) []byte {
b, err := hex.DecodeString(s)
if err != nil {
panic(err)
}
return b
}

func mustHexToBigInt(s string) *big.Int {
return new(big.Int).SetBytes(mustHexToBytes(s))
}

func mustBase64ToBigInt(s string) *big.Int {
val, err := base64.RawURLEncoding.DecodeString(s)
if err != nil {
panic(err)
}
return new(big.Int).SetBytes(val)
}

// mustNameToAlg returns the algorithm associated to name.
// The content of name is not defined in any RFC,
// but it's what the test cases use to identify algorithms.
func mustNameToAlg(name string) *cose.Algorithm {
switch name {
case "ES256":
return cose.ES256
case "ES384":
return cose.ES384
case "ES512":
return cose.ES512
}
panic("algorithm name not found: " + name)
}
Loading