diff --git a/attribute.go b/attribute.go index b2ac945..61112ae 100644 --- a/attribute.go +++ b/attribute.go @@ -107,26 +107,27 @@ func MarshalAttributes(attrs []Attribute) ([]byte, error) { // It is recommend to use the AttributeDecoder type where possible instead of calling // UnmarshalAttributes and using package nlenc functions directly. func UnmarshalAttributes(b []byte) ([]Attribute, error) { - var attrs []Attribute - var i int - for { - if i > len(b) || len(b[i:]) == 0 { - break - } - var a Attribute - if err := (&a).unmarshal(b[i:]); err != nil { - return nil, err - } + ad, err := NewAttributeDecoder(b) + if err != nil { + return nil, err + } - if a.Length == 0 { - i += nlaHeaderLen - continue - } + // Return a nil slice when there are no attributes to decode. + if ad.Len() == 0 { + return nil, nil + } - i += nlaAlign(int(a.Length)) + attrs := make([]Attribute, 0, ad.Len()) + + for ad.Next() { + if ad.attr().Length != 0 { + attrs = append(attrs, ad.attr()) + } + } - attrs = append(attrs, a) + if err := ad.Err(); err != nil { + return nil, err } return attrs, nil @@ -148,10 +149,14 @@ type AttributeDecoder struct { // If not set, the native byte order will be used. ByteOrder binary.ByteOrder - // The attributes being worked on, and the iterator index into the slice of - // attributes. - attrs []Attribute - i int + // The current attribute being worked on. + a Attribute + + // The slice of input bytes and its iterator index. + b []byte + i int + + length int // Any error encountered while decoding attributes. err error @@ -160,17 +165,21 @@ type AttributeDecoder struct { // NewAttributeDecoder creates an AttributeDecoder that unpacks Attributes // from b and prepares the decoder for iteration. func NewAttributeDecoder(b []byte) (*AttributeDecoder, error) { - attrs, err := UnmarshalAttributes(b) - if err != nil { - return nil, err - } - return &AttributeDecoder{ + ad := &AttributeDecoder{ // By default, use native byte order. ByteOrder: nlenc.NativeEndian(), - attrs: attrs, - }, nil + b: b, + } + + var err error + ad.length, err = ad.available() + if err != nil { + return nil, err + } + + return ad, nil } // Next advances the decoder to the next netlink attribute. It returns false @@ -181,34 +190,93 @@ func (ad *AttributeDecoder) Next() bool { return false } - ad.i++ + // Exit if array pointer is at or beyond the end of the slice. + if ad.i >= len(ad.b) { + return false + } + + if err := ad.a.unmarshal(ad.b[ad.i:]); err != nil { + ad.err = err + return false + } + + // Advance the pointer by at least one header's length. + if int(ad.a.Length) < nlaHeaderLen { + ad.i += nlaHeaderLen + } else { + ad.i += nlaAlign(int(ad.a.Length)) + } - // More attributes? - return len(ad.attrs) >= ad.i + return true } // Type returns the Attribute.Type field of the current netlink attribute // pointed to by the decoder. // // Type masks off the high bits of the netlink attribute type which may contain -// the Nested and NetByteOrder flags. If you need direct access to these flags, -// consider using UnmarshalAttributes instead. +// the Nested and NetByteOrder flags. These can be obtained by calling TypeFlags. func (ad *AttributeDecoder) Type() uint16 { // Mask off any flags stored in the high bits. - return ad.attr().Type & attrTypeMask + return ad.a.Type & attrTypeMask +} + +// TypeFlags returns the two high bits of the Attribute.Type field of the current +// netlink attribute pointed to by the decoder. +// +// These bits of the netlink attribute type are used for the Nested and NetByteOrder +// flags, available as the Nested and NetByteOrder constants in this package. +func (ad *AttributeDecoder) TypeFlags() uint16 { + return ad.a.Type & ^attrTypeMask } // Len returns the number of netlink attributes pointed to by the decoder. -func (ad *AttributeDecoder) Len() int { return len(ad.attrs) } +func (ad *AttributeDecoder) Len() int { return ad.length } + +// count scans the input slice to count the number of netlink attributes +// that could be decoded by Next(). +func (ad *AttributeDecoder) available() (int, error) { + + var i, count int + for { + + // No more data to read. + if i >= len(ad.b) { + break + } + + // Make sure there's at least a header's worth + // of data to read on each iteration. + if len(ad.b[i:]) < nlaHeaderLen { + return 0, errInvalidAttribute + } + + // Extract the length of the attribute. + l := int(nlenc.Uint16(ad.b[i : i+2])) + + // Ignore zero-length attributes. + if l != 0 { + count++ + } + + // Advance by at least a header's worth of bytes. + if l < nlaHeaderLen { + l = nlaHeaderLen + } + + i += nlaAlign(l) + } + + return count, nil +} // attr returns the current Attribute pointed to by the decoder. func (ad *AttributeDecoder) attr() Attribute { - return ad.attrs[ad.i-1] + return ad.a } // data returns the Data field of the current Attribute pointed to by the decoder. func (ad *AttributeDecoder) data() []byte { - return ad.attr().Data + return ad.a.Data } // Err returns the first error encountered by the decoder. diff --git a/attribute_test.go b/attribute_test.go index eeecab2..01f9a35 100644 --- a/attribute_test.go +++ b/attribute_test.go @@ -754,6 +754,22 @@ func TestAttributeDecoderOK(t *testing.T) { }) }, }, + { + name: "typeflags", + attrs: []Attribute{{ + Type: 0xffff, + }}, + fn: func(ad *AttributeDecoder) { + + if diff := cmp.Diff(ad.Type(), uint16(0x3fff)); diff != "" { + panicf("unexpected Type (-want +got):\n%s", diff) + } + + if diff := cmp.Diff(ad.TypeFlags(), uint16(0xc000)); diff != "" { + panicf("unexpected TypeFlags (-want +got):\n%s", diff) + } + }, + }, } for _, tt := range tests {