Skip to content

Commit

Permalink
Merge pull request #101 from shizhMSFT/content
Browse files Browse the repository at this point in the history
refactor!: accept message content instead of digest for sign and verify
  • Loading branch information
SteveLasker authored Aug 24, 2022
2 parents 66b4a5a + aa2f81c commit 9d2fab6
Show file tree
Hide file tree
Showing 18 changed files with 340 additions and 391 deletions.
122 changes: 18 additions & 104 deletions algorithm.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,7 @@ package cose

import (
"crypto"
"hash"
"strconv"
"sync"
)

// Algorithms supported by this library.
Expand Down Expand Up @@ -43,36 +41,13 @@ const (
// Algorithm represents an IANA algorithm entry in the COSE Algorithms registry.
// Algorithms with string values are not supported.
//
// See Also
// # See Also
//
// COSE Algorithms: https://www.iana.org/assignments/cose/cose.xhtml#algorithms
//
// RFC 8152 16.4: https://datatracker.ietf.org/doc/html/rfc8152#section-16.4
type Algorithm int64

// extAlgorithm describes an extended algorithm, which is not implemented this
// library.
type extAlgorithm struct {
// Name of the algorithm.
Name string

// Hash is the hash algorithm associated with the algorithm.
// If HashFunc is present, Hash is ignored.
// If HashFunc is not present and Hash is set to 0, no hash is used.
Hash crypto.Hash

// HashFunc is the hash algorithm associated with the algorithm.
// HashFunc is preferred in the case that the hash algorithm is not
// supported by the golang built-in crypto hashes.
// For regular scenarios, use Hash instead.
HashFunc func() hash.Hash
}

var (
extAlgorithms map[Algorithm]extAlgorithm
extMu sync.RWMutex
)

// String returns the name of the algorithm
func (a Algorithm) String() string {
switch a {
Expand All @@ -92,100 +67,39 @@ func (a Algorithm) String() string {
// As stated in RFC 8152 8.2, only the pure EdDSA version is used for
// COSE.
return "EdDSA"
default:
return "unknown algorithm value " + strconv.Itoa(int(a))
}
extMu.RLock()
alg, ok := extAlgorithms[a]
extMu.RUnlock()
if ok {
return alg.Name
}
return "unknown algorithm value " + strconv.Itoa(int(a))
}

// hashFunc returns the hash associated with the algorithm supported by this
// library.
func (a Algorithm) hashFunc() (crypto.Hash, bool) {
func (a Algorithm) hashFunc() crypto.Hash {
switch a {
case AlgorithmPS256, AlgorithmES256:
return crypto.SHA256, true
return crypto.SHA256
case AlgorithmPS384, AlgorithmES384:
return crypto.SHA384, true
return crypto.SHA384
case AlgorithmPS512, AlgorithmES512:
return crypto.SHA512, true
case AlgorithmEd25519:
return 0, true
}
return 0, false
}

// newHash returns a new hash instance for computing the digest specified in the
// algorithm.
// Returns nil if no hash is required for the message.
func (a Algorithm) newHash() (hash.Hash, error) {
h, ok := a.hashFunc()
if !ok {
extMu.RLock()
alg, ok := extAlgorithms[a]
extMu.RUnlock()
if !ok {
return nil, ErrUnknownAlgorithm
}
if alg.HashFunc != nil {
return alg.HashFunc(), nil
}
h = alg.Hash
}
if h == 0 {
// no hash required
return nil, nil
}
if h.Available() {
return h.New(), nil
return crypto.SHA512
default:
return 0
}
return nil, ErrUnavailableHashFunc
}

// computeHash computes the digest using the hash specified in the algorithm.
// Returns the input data if no hash is required for the message.
func (a Algorithm) computeHash(data []byte) ([]byte, error) {
h, err := a.newHash()
if err != nil {
return nil, err
}
if h == nil {
return data, nil
}
if _, err := h.Write(data); err != nil {
return nil, err
}
return h.Sum(nil), nil
return computeHash(a.hashFunc(), data)
}

// RegisterAlgorithm provides extensibility for the COSE library to support
// private algorithms or algorithms not yet registered in IANA.
// The existing algorithms cannot be re-registered.
// The parameter `hash` is the hash algorithm associated with the algorithm. If
// hashFunc is present, hash is ignored. If hashFunc is not present and hash is
// set to 0, no hash is used for this algorithm.
// The parameter `hashFunc` is preferred in the case that the hash algorithm is not
// supported by the golang built-in crypto hashes.
// It is safe for concurrent use by multiple goroutines.
func RegisterAlgorithm(alg Algorithm, name string, hash crypto.Hash, hashFunc func() hash.Hash) error {
if _, ok := alg.hashFunc(); ok {
return ErrAlgorithmRegistered
}
extMu.Lock()
defer extMu.Unlock()
if _, ok := extAlgorithms[alg]; ok {
return ErrAlgorithmRegistered
// computeHash computes the digest using the given hash.
func computeHash(h crypto.Hash, data []byte) ([]byte, error) {
if !h.Available() {
return nil, ErrUnavailableHashFunc
}
if extAlgorithms == nil {
extAlgorithms = make(map[Algorithm]extAlgorithm)
}
extAlgorithms[alg] = extAlgorithm{
Name: name,
Hash: hash,
HashFunc: hashFunc,
hh := h.New()
if _, err := hh.Write(data); err != nil {
return nil, err
}
return nil
return hh.Sum(nil), nil
}
143 changes: 29 additions & 114 deletions algorithm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,13 @@ package cose
import (
"crypto"
"crypto/sha256"
"hash"
"io"
"reflect"
"sync"
"testing"
)

// resetExtendedAlgorithm cleans up extended algorithms
func resetExtendedAlgorithm() {
extMu.Lock()
extAlgorithms = nil
extMu.Unlock()
}

func TestAlgorithm_String(t *testing.T) {
defer resetExtendedAlgorithm()

// register extended algorithms
algFoo := Algorithm(-102)
if err := RegisterAlgorithm(algFoo, "foo", 0, nil); err != nil {
t.Fatalf("RegisterAlgorithm() = %v", err)
}
algBar := Algorithm(-98)
if err := RegisterAlgorithm(algBar, "bar", 0, nil); err != nil {
t.Fatalf("RegisterAlgorithm() = %v", err)
}

// run tests
tests := []struct {
name string
Expand Down Expand Up @@ -69,16 +51,6 @@ func TestAlgorithm_String(t *testing.T) {
alg: AlgorithmEd25519,
want: "EdDSA",
},
{
name: "extended algorithm: foo",
alg: algFoo,
want: "foo",
},
{
name: "extended algorithm: bar",
alg: algBar,
want: "bar",
},
{
name: "unknown algorithm",
alg: 0,
Expand All @@ -95,26 +67,6 @@ func TestAlgorithm_String(t *testing.T) {
}

func TestAlgorithm_computeHash(t *testing.T) {
defer resetExtendedAlgorithm()

// register extended algorithms
algFoo := Algorithm(-102)
if err := RegisterAlgorithm(algFoo, "foo", crypto.SHA256, nil); err != nil {
t.Fatalf("RegisterAlgorithm() = %v", err)
}
algBar := Algorithm(-98)
if err := RegisterAlgorithm(algBar, "bar", 0, sha256.New); err != nil {
t.Fatalf("RegisterAlgorithm() = %v", err)
}
algPlain := Algorithm(-112)
if err := RegisterAlgorithm(algPlain, "plain", 0, nil); err != nil {
t.Fatalf("RegisterAlgorithm() = %v", err)
}
algUnavailableHash := Algorithm(-117)
if err := RegisterAlgorithm(algUnavailableHash, "unknown hash", 42, nil); err != nil {
t.Fatalf("RegisterAlgorithm() = %v", err)
}

// run tests
data := []byte("hello world")
tests := []struct {
Expand Down Expand Up @@ -178,39 +130,13 @@ func TestAlgorithm_computeHash(t *testing.T) {
},
},
{
name: "Ed25519",
alg: AlgorithmEd25519,
want: data,
},
{
name: "extended algorithm with crypto.Hash",
alg: algFoo,
want: []byte{
0xb9, 0x4d, 0x27, 0xb9, 0x93, 0x4d, 0x3e, 0x08, 0xa5, 0x2e, 0x52, 0xd7, 0xda, 0x7d, 0xab, 0xfa,
0xc4, 0x84, 0xef, 0xe3, 0x7a, 0x53, 0x80, 0xee, 0x90, 0x88, 0xf7, 0xac, 0xe2, 0xef, 0xcd, 0xe9,
},
},
{
name: "extended algorithm with hashFunc",
alg: algBar,
want: []byte{
0xb9, 0x4d, 0x27, 0xb9, 0x93, 0x4d, 0x3e, 0x08, 0xa5, 0x2e, 0x52, 0xd7, 0xda, 0x7d, 0xab, 0xfa,
0xc4, 0x84, 0xef, 0xe3, 0x7a, 0x53, 0x80, 0xee, 0x90, 0x88, 0xf7, 0xac, 0xe2, 0xef, 0xcd, 0xe9,
},
},
{
name: "extended algorithm without hash",
alg: algPlain,
want: data,
name: "Ed25519",
alg: AlgorithmEd25519,
wantErr: ErrUnavailableHashFunc,
},
{
name: "unknown algorithm",
alg: 0,
wantErr: ErrUnknownAlgorithm,
},
{
name: "unknown hash",
alg: algUnavailableHash,
wantErr: ErrUnavailableHashFunc,
},
}
Expand All @@ -228,46 +154,35 @@ func TestAlgorithm_computeHash(t *testing.T) {
}
}

func TestRegisterAlgorithm(t *testing.T) {
defer resetExtendedAlgorithm()
type badHash struct{}

// register existing algorithm (should fail)
if err := RegisterAlgorithm(AlgorithmES256, "ES256", crypto.SHA256, nil); err != ErrAlgorithmRegistered {
t.Errorf("RegisterAlgorithm() error = %v, wantErr %v", err, ErrAlgorithmRegistered)
}
func badHashNew() hash.Hash {
return &badHash{}
}

algFoo := Algorithm(-102)
// register external algorithm
if err := RegisterAlgorithm(algFoo, "foo", 0, nil); err != nil {
t.Errorf("RegisterAlgorithm() error = %v, wantErr %v", err, false)
}
func (*badHash) Write(p []byte) (n int, err error) {
return 0, io.EOF
}

// double register external algorithm (should fail)
if err := RegisterAlgorithm(algFoo, "foo", 0, nil); err != ErrAlgorithmRegistered {
t.Errorf("RegisterAlgorithm() error = %v, wantErr %v", err, ErrAlgorithmRegistered)
}
func (*badHash) Sum(b []byte) []byte {
return b
}

func TestRegisterAlgorithm_Concurrent(t *testing.T) {
defer resetExtendedAlgorithm()
func (*badHash) Reset() {}

// Register algorithms concurrently to ensure testing on race mode catches races.
var wg sync.WaitGroup
wg.Add(2)
func (*badHash) Size() int {
return 0
}
func (*badHash) BlockSize() int {
return 0
}

go func() {
defer wg.Done()
// register existing algorithm (should fail)
if err := RegisterAlgorithm(AlgorithmES256, "ES256", crypto.SHA256, nil); err != ErrAlgorithmRegistered {
t.Errorf("RegisterAlgorithm() error = %v, wantErr %v", err, ErrAlgorithmRegistered)
}
}()
go func() {
defer wg.Done()
// register external algorithm
if err := RegisterAlgorithm(Algorithm(-102), "foo", 0, nil); err != nil {
t.Errorf("RegisterAlgorithm() error = %v, wantErr %v", err, false)
}
}()
wg.Wait()
func Test_computeHash(t *testing.T) {
crypto.RegisterHash(crypto.SHA256, badHashNew)
defer crypto.RegisterHash(crypto.SHA256, sha256.New)

_, err := computeHash(crypto.SHA256, nil)
if err != io.EOF {
t.Fatalf("computeHash() error = %v, wantErr %v", err, io.EOF)
}
}
2 changes: 1 addition & 1 deletion bench_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ func newSign1Message() *cose.Sign1Message {
cose.HeaderLabelAlgorithm: cose.AlgorithmES256,
},
Unprotected: cose.UnprotectedHeader{
cose.HeaderLabelKeyID: 1,
cose.HeaderLabelKeyID: []byte{0x01},
},
},
Payload: make([]byte, 100),
Expand Down
Loading

0 comments on commit 9d2fab6

Please sign in to comment.