From 2c1e0b36ee81c10e2fe7dcbb3c3c6a44608fff62 Mon Sep 17 00:00:00 2001 From: Jesse Thompson Date: Fri, 18 Aug 2023 16:59:05 +0000 Subject: [PATCH] Expose a Reader for TLV metadata --- pkg/tlv/tlv.go | 157 +++++++++++++++++++++++++++++++------------- pkg/tlv/tlv_test.go | 85 +++++++++++++++++++++++- 2 files changed, 194 insertions(+), 48 deletions(-) diff --git a/pkg/tlv/tlv.go b/pkg/tlv/tlv.go index 1501aa55..ea1909fe 100644 --- a/pkg/tlv/tlv.go +++ b/pkg/tlv/tlv.go @@ -1,9 +1,9 @@ package tlv import ( + "bytes" "encoding/binary" "io" - "math/big" "github.com/Azure/adx-mon/pkg/pool" ) @@ -16,7 +16,12 @@ type TLV struct { type Tag uint16 -var buf = pool.NewBytes(1024) +var ( + buf = pool.NewBytes(1024) + magicn = Tag(0x1) +) + +const sizeOfHeader = binary.MaxVarintLen16 /* T */ + binary.MaxVarintLen32 /* L */ + binary.MaxVarintLen32 /* V */ func New(tag Tag, value []byte) *TLV { @@ -38,65 +43,125 @@ func (t *TLV) Encode() []byte { // Encode the TLVs by prefixing a TLV as a header that // contains the number of TLVs contained within. func Encode(tlvs ...*TLV) []byte { - var b []byte - - // First create our header - v := append(big.NewInt(int64(len(tlvs))).Bytes(), byte(0)) - header := &TLV{ - Tag: 0x1, - Length: uint32(len(v)), - Value: v, - } - b = append(b, header.Encode()...) + var b bytes.Buffer - // Now append all our elements for _, t := range tlvs { - b = append(b, t.Encode()...) + b.Write(t.Encode()) } - return b + // Header is TLV where V is a uint32 instead of a byte slice. + // T is a magic number 0x1 + // L is the number of TLVs + // V is the size in bytes of all the TLVs + v := make([]byte, sizeOfHeader) + binary.BigEndian.PutUint16(v, uint16(magicn)) // T + binary.BigEndian.PutUint32(v[binary.MaxVarintLen16:], uint32(b.Len())) // L + binary.BigEndian.PutUint32(v[binary.MaxVarintLen16+binary.MaxVarintLen32:], uint32(len(tlvs))) // V + + return append(v, b.Bytes()...) } -func Decode(s io.ReadSeeker) ([]*TLV, error) { - data := buf.Get(1024) - defer buf.Put(data) - data = data[0:] +type Reader struct { + source io.Reader + discovered bool + header []TLV + buf []byte +} - _, err := s.Read(data) - if err != nil { +func NewReader(r io.Reader) *Reader { + return &Reader{source: r} +} + +func (r *Reader) Read(p []byte) (n int, err error) { + // extract our header + if !r.discovered { + if err := r.decode(); err != nil { + return 0, err + } + } + // drain + if len(r.buf) != 0 { + n = copy(p, r.buf) + r.buf = r.buf[n:] + return + } + // fast path + n, err = r.source.Read(p) + return +} + +func (r *Reader) Header() ([]TLV, error) { + if r.discovered { + return r.header, nil + } + + if err := r.decode(); err != nil { return nil, err } - header := &TLV{ - Tag: Tag(binary.BigEndian.Uint16(data[0:])), - Length: binary.BigEndian.Uint32(data[binary.MaxVarintLen16:]), + return r.header, nil +} + +func (r *Reader) decode() error { + p := buf.Get(sizeOfHeader) + defer buf.Put(p) + + n, err := r.source.Read(p) + if err != nil { + return err + } + + // source has no header + if Tag(binary.BigEndian.Uint16(p)) != magicn { + r.discovered = true + // we need to keep these bytes around until someone calls Read + r.buf = make([]byte, len(p)) + copy(r.buf, p) + return nil } - header.Value = data[binary.MaxVarintLen16+binary.MaxVarintLen32 : binary.MaxVarintLen16+binary.MaxVarintLen32+header.Length] - elements := int(big.NewInt(0).SetBytes(header.Value[:header.Length-1]).Int64()) - - // Now decode all the TLVs - var ( - tlvs []*TLV - offset = binary.MaxVarintLen16 + binary.MaxVarintLen32 + int(header.Length) - ) + offset := binary.MaxVarintLen16 + + sizeOfElements := binary.BigEndian.Uint32(p[offset:]) + offset += binary.MaxVarintLen32 + elements := int(binary.BigEndian.Uint32(p[offset:])) + offset += binary.MaxVarintLen32 + + // at this point we know how much data we need from our source, so fill the buffer + if n < int(sizeOfElements) { + // read the remaining bytes needed to extract our header + l := &io.LimitedReader{R: r.source, N: int64(int(sizeOfElements))} + var read []byte + read, err = io.ReadAll(l) + if err != nil { + + // we thought we had a header, but we just got unlucky + // with the first byte being our magic number. + if err == io.EOF { + r.discovered = true + r.buf = make([]byte, len(read)+n) + copy(r.buf, p) + copy(r.buf[n:], read) + return nil + } + return err + } + // resize + p = append(p, read...) + } + + // no bounds checks are necessary, all sizes are known + r.header = make([]TLV, elements) for i := 0; i < elements; i++ { - t := &TLV{} - t.Tag = Tag(binary.BigEndian.Uint16(data[offset:])) + t := TLV{} + t.Tag = Tag(binary.BigEndian.Uint16(p[offset:])) offset += binary.MaxVarintLen16 - t.Length = binary.BigEndian.Uint32(data[offset:]) + t.Length = binary.BigEndian.Uint32(p[offset:]) offset += binary.MaxVarintLen32 - if offset+int(t.Length) > len(data) { - break - } - t.Value = data[offset : offset+int(t.Length)] + t.Value = p[offset : offset+int(t.Length)] offset += int(t.Length) - tlvs = append(tlvs, t) - } - - // Seek past our TLVs and at the beginning of our payload - if _, err := s.Seek(int64(offset), io.SeekStart); err != nil { - return nil, err + r.header[i] = t } - return tlvs, nil + r.discovered = true + return nil } diff --git a/pkg/tlv/tlv_test.go b/pkg/tlv/tlv_test.go index 5955f674..b401e790 100644 --- a/pkg/tlv/tlv_test.go +++ b/pkg/tlv/tlv_test.go @@ -1,8 +1,10 @@ package tlv_test import ( + "bytes" "io" "os" + "reflect" "testing" "github.com/Azure/adx-mon/pkg/tlv" @@ -30,12 +32,91 @@ func TestTLV(t *testing.T) { require.NoError(t, err) defer tf.Close() - tlvs, err := tlv.Decode(tf) + r := tlv.NewReader(tf) + tlvs, err := r.Header() require.NoError(t, err) require.Equal(t, 1, len(tlvs)) require.Equal(t, v, string(tlvs[0].Value)) - data, err := io.ReadAll(tf) + data, err := io.ReadAll(r) require.NoError(t, err) require.Equal(t, randomBytes, data) } + +func TestReader(t *testing.T) { + tests := []struct { + Name string + HeaderLen int + SkipHeader bool + }{ + { + Name: "single header entry", + HeaderLen: 1, + }, + { + Name: "this payload contains no tlv header", + }, + { + Name: "Invoke Read without first invoking Header even though TLVs exist", + HeaderLen: 2, + SkipHeader: true, + }, + { + Name: "Several header entries", + HeaderLen: 5, + }, + } + for _, tt := range tests { + t.Run(tt.Name, func(t *testing.T) { + b := []byte(tt.Name) + + if tt.HeaderLen != 0 { + var tlvs []*tlv.TLV + for i := 0; i < tt.HeaderLen; i++ { + tlvs = append(tlvs, tlv.New(tlv.Tag(i), []byte(tt.Name))) + } + b = append(tlv.Encode(tlvs...), b...) + } + + source := bytes.NewBuffer(b) + r := tlv.NewReader(source) + + if !tt.SkipHeader { + h, err := r.Header() + require.NoError(t, err) + require.Equal(t, tt.HeaderLen, len(h)) + for _, metadata := range h { + require.Equal(t, tt.Name, string(metadata.Value)) + } + } + + have, err := io.ReadAll(r) + require.NoError(t, err) + require.Equal(t, tt.Name, string(have)) + }) + } +} + +func TestUnluckyMagicNumber(t *testing.T) { + var b bytes.Buffer + _, err := b.Write([]byte{0x1}) + require.NoError(t, err) + _, err = b.WriteString("I kind of look like a TLV") + require.NoError(t, err) + r := tlv.NewReader(bytes.NewBuffer(b.Bytes())) + have, err := io.ReadAll(r) + require.NoError(t, err) + require.True(t, reflect.DeepEqual(have, b.Bytes())) +} + +func BenchmarkReader(b *testing.B) { + t := tlv.New(tlv.Tag(0x2), []byte("some tag payload")) + h := tlv.Encode(t) + p := bytes.NewReader(append(h, []byte("body payload")...)) + r := tlv.NewReader(p) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + io.ReadAll(r) + } +}