Skip to content

Commit

Permalink
Merge pull request #66 from qmuntal/iv
Browse files Browse the repository at this point in the history
Ensure IV and Partial IV are not both present
  • Loading branch information
yogeshbdeshpande authored May 20, 2022
2 parents cb68acd + 6042760 commit 2fccfc9
Show file tree
Hide file tree
Showing 6 changed files with 209 additions and 15 deletions.
66 changes: 66 additions & 0 deletions headers.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ const (
HeaderLabelCritical int64 = 2
HeaderLabelContentType int64 = 3
HeaderLabelKeyID int64 = 4
HeaderLabelIV int64 = 5
HeaderLabelPartialIV int64 = 6
HeaderLabelCounterSignature int64 = 7
HeaderLabelCounterSignature0 int64 = 9
HeaderLabelX5Bag int64 = 32
Expand Down Expand Up @@ -43,6 +45,9 @@ func (h ProtectedHeader) MarshalCBOR() ([]byte, error) {
if err = h.ensureCritical(); err != nil {
return nil, err
}
if err = ensureHeaderIV(h); err != nil {
return nil, fmt.Errorf("protected header: %w", err)
}
encoded, err = encMode.Marshal(map[interface{}]interface{}(h))
if err != nil {
return nil, err
Expand Down Expand Up @@ -83,6 +88,9 @@ func (h *ProtectedHeader) UnmarshalCBOR(data []byte) error {
if err := candidate.ensureCritical(); err != nil {
return err
}
if err := ensureHeaderIV(candidate); err != nil {
return fmt.Errorf("protected header: %w", err)
}

// cast to type Algorithm if `alg` presents
if alg, err := candidate.Algorithm(); err == nil {
Expand Down Expand Up @@ -170,6 +178,9 @@ func (h UnprotectedHeader) MarshalCBOR() ([]byte, error) {
if err := validateHeaderLabel(h); err != nil {
return nil, err
}
if err := ensureHeaderIV(h); err != nil {
return nil, fmt.Errorf("unprotected header: %w", err)
}
return encMode.Marshal(map[interface{}]interface{}(h))
}

Expand All @@ -196,6 +207,9 @@ func (h *UnprotectedHeader) UnmarshalCBOR(data []byte) error {
if err := decMode.Unmarshal(data, &header); err != nil {
return err
}
if err := ensureHeaderIV(header); err != nil {
return fmt.Errorf("unprotected header: %w", err)
}
*h = header
return nil
}
Expand Down Expand Up @@ -253,6 +267,23 @@ type Headers struct {
Unprotected UnprotectedHeader
}

// marshal encoded both headers.
// It returns RawProtected and RawUnprotected if those are set.
func (h *Headers) marshal() (cbor.RawMessage, cbor.RawMessage, error) {
if err := h.ensureIV(); err != nil {
return nil, nil, err
}
protected, err := h.MarshalProtected()
if err != nil {
return nil, nil, err
}
unprotected, err := h.MarshalUnprotected()
if err != nil {
return nil, nil, err
}
return protected, unprotected, nil
}

// MarshalProtected encodes the protected header.
// RawProtected is returned if it is not set to nil.
func (h *Headers) MarshalProtected() ([]byte, error) {
Expand Down Expand Up @@ -280,6 +311,9 @@ func (h *Headers) UnmarshalFromRaw() error {
if err := decMode.Unmarshal(h.RawUnprotected, &h.Unprotected); err != nil {
return fmt.Errorf("cbor: invalid unprotected header: %w", err)
}
if err := h.ensureIV(); err != nil {
return err
}
return nil
}

Expand Down Expand Up @@ -331,6 +365,38 @@ func (h *Headers) ensureVerificationAlgorithm(alg Algorithm, external []byte) er
return err
}

// ensureIV ensures IV and Partial IV are not both present
// in the protected and unprotected headers.
// It does not check if they are both present within one header,
// as it will be checked later on.
//
// Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-3.1
func (h *Headers) ensureIV() error {
if hasLabel(h.Protected, HeaderLabelIV) && hasLabel(h.Unprotected, HeaderLabelPartialIV) {
return errors.New("IV (protected) and PartialIV (unprotected) parameters must not both be present")
}
if hasLabel(h.Protected, HeaderLabelPartialIV) && hasLabel(h.Unprotected, HeaderLabelIV) {
return errors.New("IV (unprotected) and PartialIV (protected) parameters must not both be present")
}
return nil
}

// hasLabel returns true if h contains label.
func hasLabel(h map[interface{}]interface{}, label interface{}) bool {
_, ok := h[label]
return ok
}

// ensureHeaderIV ensures IV and Partial IV are not both present in the header.
//
// Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-3.1
func ensureHeaderIV(h map[interface{}]interface{}) error {
if hasLabel(h, HeaderLabelIV) && hasLabel(h, HeaderLabelPartialIV) {
return errors.New("IV and PartialIV parameters must not both be present")
}
return nil
}

// validateHeaderLabel validates if all header labels are integers or strings.
//
// label = int / tstr
Expand Down
30 changes: 30 additions & 0 deletions headers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,14 @@ func TestProtectedHeader_MarshalCBOR(t *testing.T) {
},
wantErr: true,
},
{
name: "iv and partial iv present",
h: ProtectedHeader{
HeaderLabelIV: "foo",
HeaderLabelPartialIV: "bar",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down Expand Up @@ -265,6 +273,13 @@ func TestProtectedHeader_UnmarshalCBOR(t *testing.T) {
},
wantErr: true,
},
{
name: "iv and partial iv present",
data: []byte{
0x4b, 0xa2, 0x5, 0x63, 0x66, 0x6f, 0x6f, 0x6, 0x63, 0x62, 0x61, 0x72,
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down Expand Up @@ -517,6 +532,14 @@ func TestUnprotectedHeader_MarshalCBOR(t *testing.T) {
},
wantErr: true,
},
{
name: "iv and partial iv present",
h: UnprotectedHeader{
HeaderLabelIV: "foo",
HeaderLabelPartialIV: "bar",
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down Expand Up @@ -617,6 +640,13 @@ func TestUnprotectedHeader_UnmarshalCBOR(t *testing.T) {
},
wantErr: true,
},
{
name: "iv and partial iv present",
data: []byte{
0xa2, 0x5, 0x63, 0x66, 0x6f, 0x6f, 0x6, 0x63, 0x62, 0x61, 0x72,
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down
12 changes: 2 additions & 10 deletions sign.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,7 @@ func (s *Signature) MarshalCBOR() ([]byte, error) {
if len(s.Signature) == 0 {
return nil, ErrEmptySignature
}
protected, err := s.Headers.MarshalProtected()
if err != nil {
return nil, err
}
unprotected, err := s.Headers.MarshalUnprotected()
protected, unprotected, err := s.Headers.marshal()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -329,11 +325,7 @@ func (m *SignMessage) MarshalCBOR() ([]byte, error) {
if len(m.Signatures) == 0 {
return nil, ErrNoSignatures
}
protected, err := m.Headers.MarshalProtected()
if err != nil {
return nil, err
}
unprotected, err := m.Headers.MarshalUnprotected()
protected, unprotected, err := m.Headers.marshal()
if err != nil {
return nil, err
}
Expand Down
6 changes: 1 addition & 5 deletions sign1.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,7 @@ func (m *Sign1Message) MarshalCBOR() ([]byte, error) {
if len(m.Signature) == 0 {
return nil, ErrEmptySignature
}
protected, err := m.Headers.MarshalProtected()
if err != nil {
return nil, err
}
unprotected, err := m.Headers.MarshalUnprotected()
protected, unprotected, err := m.Headers.marshal()
if err != nil {
return nil, err
}
Expand Down
58 changes: 58 additions & 0 deletions sign1_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,40 @@ func TestSign1Message_MarshalCBOR(t *testing.T) {
},
wantErr: true,
},
{
name: "protected has IV and unprotected has PartialIV error",
m: &Sign1Message{
Headers: Headers{
Protected: ProtectedHeader{
HeaderLabelAlgorithm: AlgorithmES256,
HeaderLabelIV: "",
},
Unprotected: UnprotectedHeader{
HeaderLabelPartialIV: "",
},
},
Payload: []byte("foo"),
Signature: []byte("bar"),
},
wantErr: true,
},
{
name: "protected has PartialIV and unprotected has IV error",
m: &Sign1Message{
Headers: Headers{
Protected: ProtectedHeader{
HeaderLabelAlgorithm: AlgorithmES256,
HeaderLabelPartialIV: "",
},
Unprotected: UnprotectedHeader{
HeaderLabelIV: "",
},
},
Payload: []byte("foo"),
Signature: []byte("bar"),
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down Expand Up @@ -324,6 +358,30 @@ func TestSign1Message_UnmarshalCBOR(t *testing.T) {
},
wantErr: true,
},
{
name: "protected has IV and unprotected has PartialIV",
data: []byte{
0xd2, // tag
0x84,
0x46, 0xa1, 0x5, 0x63, 0x66, 0x6f, 0x6f, // protected
0xa1, 0x6, 0x63, 0x62, 0x61, 0x72, // unprotected
0xf6, // payload
0x43, 0x62, 0x61, 0x72, // signature
},
wantErr: true,
},
{
name: "protected has PartialIV and unprotected has IV",
data: []byte{
0xd2, // tag
0x84,
0x46, 0xa1, 0x6, 0x63, 0x66, 0x6f, 0x6f, // protected
0xa1, 0x5, 0x63, 0x62, 0x61, 0x72, // unprotected
0xf6, // payload
0x43, 0x62, 0x61, 0x72, // signature
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down
52 changes: 52 additions & 0 deletions sign_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,38 @@ func TestSignature_MarshalCBOR(t *testing.T) {
},
wantErr: true,
},
{
name: "protected has IV and unprotected has PartialIV error",
s: &Signature{
Headers: Headers{
Protected: ProtectedHeader{
HeaderLabelAlgorithm: AlgorithmES256,
HeaderLabelIV: "",
},
Unprotected: UnprotectedHeader{
HeaderLabelPartialIV: "",
},
},
Signature: []byte("bar"),
},
wantErr: true,
},
{
name: "protected has PartialIV and unprotected has IV error",
s: &Signature{
Headers: Headers{
Protected: ProtectedHeader{
HeaderLabelAlgorithm: AlgorithmES256,
HeaderLabelPartialIV: "",
},
Unprotected: UnprotectedHeader{
HeaderLabelIV: "",
},
},
Signature: []byte("bar"),
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down Expand Up @@ -227,6 +259,26 @@ func TestSignature_UnmarshalCBOR(t *testing.T) {
},
wantErr: true,
},
{
name: "protected has IV and unprotected has PartialIV",
data: []byte{
0x83,
0x46, 0xa1, 0x5, 0x63, 0x66, 0x6f, 0x6f, // protected
0xa1, 0x6, 0x63, 0x62, 0x61, 0x72, // unprotected
0x43, 0x62, 0x61, 0x72, // signature
},
wantErr: true,
},
{
name: "protected has PartialIV and unprotected has IV",
data: []byte{
0x83,
0x46, 0xa1, 0x6, 0x63, 0x66, 0x6f, 0x6f, // protected
0xa1, 0x5, 0x63, 0x62, 0x61, 0x72, // unprotected
0x43, 0x62, 0x61, 0x72, // signature
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand Down

0 comments on commit 2fccfc9

Please sign in to comment.