Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate generic headers #87

Merged
merged 1 commit into from
Jul 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func ExampleSignMessage() {
// create a signature holder
sigHolder := cose.NewSignature()
sigHolder.Headers.Protected.SetAlgorithm(cose.AlgorithmES512)
sigHolder.Headers.Unprotected[cose.HeaderLabelKeyID] = 1
sigHolder.Headers.Unprotected[cose.HeaderLabelKeyID] = []byte("1")

// create message to be signed
msgToSign := cose.NewSignMessage()
Expand Down Expand Up @@ -84,7 +84,7 @@ func ExampleSign1Message() {
msgToSign := cose.NewSign1Message()
msgToSign.Payload = []byte("hello world")
msgToSign.Headers.Protected.SetAlgorithm(cose.AlgorithmES512)
msgToSign.Headers.Unprotected[cose.HeaderLabelKeyID] = 1
msgToSign.Headers.Unprotected[cose.HeaderLabelKeyID] = []byte("1")

// create a signer
privateKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader)
Expand Down Expand Up @@ -157,7 +157,7 @@ func ExampleSign1() {
cose.HeaderLabelAlgorithm: cose.AlgorithmES512,
},
Unprotected: cose.UnprotectedHeader{
cose.HeaderLabelKeyID: 1,
cose.HeaderLabelKeyID: []byte("1"),
},
}
sig, err := cose.Sign1(rand.Reader, signer, headers, []byte("hello world"), nil)
Expand Down
181 changes: 111 additions & 70 deletions headers.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,8 @@ func (h ProtectedHeader) MarshalCBOR() ([]byte, error) {
if len(h) == 0 {
encoded = []byte{}
} else {
err := validateHeaderLabel(h)
err := validateHeaderParameters(h, true)
if err != nil {
return nil, err
}
if err = h.ensureCritical(); err != nil {
return nil, err
}
if err = ensureHeaderIV(h); err != nil {
return nil, fmt.Errorf("protected header: %w", err)
}
encoded, err = encMode.Marshal(map[interface{}]interface{}(h))
Expand Down Expand Up @@ -85,10 +79,7 @@ func (h *ProtectedHeader) UnmarshalCBOR(data []byte) error {
return err
}
candidate := ProtectedHeader(header)
if err := candidate.ensureCritical(); err != nil {
return err
}
if err := ensureHeaderIV(candidate); err != nil {
if err := validateHeaderParameters(candidate, true); err != nil {
return fmt.Errorf("protected header: %w", err)
}

Expand Down Expand Up @@ -140,29 +131,28 @@ func (h ProtectedHeader) Critical() ([]interface{}, error) {
if !ok {
return nil, nil
}
criticalLabels, ok := value.([]interface{})
if !ok {
return nil, errors.New("invalid crit header")
}
// if present, the array MUST have at least one value in it.
if len(criticalLabels) == 0 {
return nil, errors.New("empty crit header")
err := ensureCritical(value, h)
if err != nil {
return nil, err
}
return criticalLabels, nil
return value.([]interface{}), nil
}

// ensureCritical ensures all critical headers are present in the protected bucket.
func (h ProtectedHeader) ensureCritical() error {
labels, err := h.Critical()
if err != nil {
return err
func ensureCritical(value interface{}, headers map[interface{}]interface{}) error {
labels, ok := value.([]interface{})
if !ok {
return errors.New("invalid crit header")
}
// if present, the array MUST have at least one value in it.
if len(labels) == 0 {
return errors.New("empty crit header")
}
for _, label := range labels {
_, ok := normalizeLabel(label)
if !ok {
return fmt.Errorf("critical header label: require int / tstr type, got '%T': %v", label, label)
if !canInt(label) && !canTstr(label) {
return fmt.Errorf("require int / tstr type, got '%T': %v", label, label)
}
if _, ok := h[label]; !ok {
if _, ok := headers[label]; !ok {
return fmt.Errorf("missing critical header: %v", label)
}
}
Expand All @@ -179,13 +169,7 @@ func (h UnprotectedHeader) MarshalCBOR() ([]byte, error) {
if len(h) == 0 {
return []byte{0xa0}, nil
}
if err := validateHeaderLabel(h); err != nil {
return nil, err
}
if err := ensureNoCritical(h); err != nil {
return nil, fmt.Errorf("unprotected header: %w", err)
}
if err := ensureHeaderIV(h); err != nil {
if err := validateHeaderParameters(h, false); err != nil {
return nil, fmt.Errorf("unprotected header: %w", err)
}
return encMode.Marshal(map[interface{}]interface{}(h))
Expand Down Expand Up @@ -214,10 +198,7 @@ func (h *UnprotectedHeader) UnmarshalCBOR(data []byte) error {
if err := decMode.Unmarshal(data, &header); err != nil {
return err
}
if err := ensureNoCritical(header); err != nil {
return fmt.Errorf("unprotected header: %w", err)
}
if err := ensureHeaderIV(header); err != nil {
if err := validateHeaderParameters(header, false); err != nil {
return fmt.Errorf("unprotected header: %w", err)
}
*h = header
Expand Down Expand Up @@ -397,48 +378,108 @@ func hasLabel(h map[interface{}]interface{}, label interface{}) bool {
return ok
}

// ensureHeaderIV ensures IV and Partial IV are not both present in the header.
//
// Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-3.1
func ensureHeaderIV(h map[interface{}]interface{}) error {
if hasLabel(h, HeaderLabelIV) && hasLabel(h, HeaderLabelPartialIV) {
return errors.New("IV and PartialIV parameters must not both be present")
}
return nil
}

// ensureNoCritical ensures crit parameter is not present in the header.
//
// Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-3.1
func ensureNoCritical(h map[interface{}]interface{}) error {
if hasLabel(h, HeaderLabelCritical) {
return errors.New("unexpected crit parameter found")
}
return nil
}

// validateHeaderLabel validates if all header labels are integers or strings.
//
// label = int / tstr
//
// Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-1.4
func validateHeaderLabel(h map[interface{}]interface{}) error {
existing := make(map[interface{}]struct{})
for label := range h {
var ok bool
label, ok = normalizeLabel(label)
// validateHeaderParameters validates all headers conform to the spec.
func validateHeaderParameters(h map[interface{}]interface{}, protected bool) error {
existing := make(map[interface{}]struct{}, len(h))
for label, value := range h {
// Validate that all header labels are integers or strings.
// Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-1.4
label, ok := normalizeLabel(label)
if !ok {
return errors.New("cbor: header label: require int / tstr type")
return errors.New("header label: require int / tstr type")
}

// Validate that there are no duplicated labels.
// Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-3
if _, ok := existing[label]; ok {
return fmt.Errorf("cbor: header label: duplicated label: %v", label)
return fmt.Errorf("header label: duplicated label: %v", label)
} else {
existing[label] = struct{}{}
}

// Validate the generic parameters.
// Reference: https://datatracker.ietf.org/doc/html/rfc8152#section-3.1
switch label {
case HeaderLabelAlgorithm:
_, is_alg := value.(Algorithm)
if !is_alg && !canInt(value) && !canTstr(value) {
return errors.New("header parameter: alg: require int / tstr type")
}
case HeaderLabelCritical:
if !protected {
return errors.New("header parameter: crit: not allowed")
}
if err := ensureCritical(value, h); err != nil {
return fmt.Errorf("header parameter: crit: %w", err)
}
case HeaderLabelContentType:
if !canTstr(value) && !canUint(value) {
return errors.New("header parameter: content type: require tstr / uint type")
}
case HeaderLabelKeyID:
if !canBstr(value) {
return errors.New("header parameter: kid: require bstr type")
}
case HeaderLabelIV:
if !canBstr(value) {
return errors.New("header parameter: IV: require bstr type")
}
if hasLabel(h, HeaderLabelPartialIV) {
return errors.New("header parameter: IV and PartialIV: parameters must not both be present")
}
case HeaderLabelPartialIV:
if !canBstr(value) {
return errors.New("header parameter: Partial IV: require bstr type")
}
if hasLabel(h, HeaderLabelIV) {
return errors.New("header parameter: IV and PartialIV: parameters must not both be present")
}
}
}
return nil
}

// canUint reports whether v can be used as a CBOR uint type.
func canUint(v interface{}) bool {
switch v := v.(type) {
case uint, uint8, uint16, uint32, uint64:
return true
case int:
return v >= 0
case int8:
return v >= 0
case int16:
return v >= 0
case int32:
return v >= 0
case int64:
return v >= 0
}
return false
}

// canInt reports whether v can be used as a CBOR int type.
func canInt(v interface{}) bool {
switch v.(type) {
case int, int8, int16, int32, int64,
uint, uint8, uint16, uint32, uint64:
return true
}
return false
}

// canTstr reports whether v can be used as a CBOR tstr type.
func canTstr(v interface{}) bool {
_, ok := v.(string)
return ok
}

// canBstr reports whether v can be used as a CBOR bstr type.
func canBstr(v interface{}) bool {
_, ok := v.([]byte)
return ok
}

// normalizeLabel tries to cast label into a int64 or a string.
// Returns (nil, false) if the label type is not valid.
func normalizeLabel(label interface{}) (interface{}, bool) {
Expand Down
Loading