Skip to content

Commit

Permalink
Expose a Reader for TLV metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
Jesse Thompson committed Aug 18, 2023
1 parent 14a0d3d commit b8bf573
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 29 deletions.
102 changes: 75 additions & 27 deletions pkg/tlv/tlv.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package tlv

import (
"bytes"
"encoding/binary"
"io"
"math/big"
Expand All @@ -16,7 +17,10 @@ type TLV struct {

type Tag uint16

var buf = pool.NewBytes(1024)
var (
buf = pool.NewBytes(1024)
magicn = Tag(0x1)
)

func New(tag Tag, value []byte) *TLV {

Expand All @@ -43,7 +47,7 @@ func Encode(tlvs ...*TLV) []byte {
// First create our header
v := append(big.NewInt(int64(len(tlvs))).Bytes(), byte(0))
header := &TLV{
Tag: 0x1,
Tag: magicn,
Length: uint32(len(v)),
Value: v,
}
Expand All @@ -57,46 +61,90 @@ func Encode(tlvs ...*TLV) []byte {
return b
}

func Decode(s io.ReadSeeker) ([]*TLV, error) {
data := buf.Get(1024)
defer buf.Put(data)
data = data[0:]
type Reader struct {
source io.Reader
discovered bool
header []*TLV
buf []byte
offset int
}

func NewReader(r io.Reader) *Reader {
return &Reader{source: r}
}

// Read implements io.Reader
func (r *Reader) Read(p []byte) (n int, err error) {
if len(r.buf) > 0 {
n = copy(p, r.buf)
r.buf = r.buf[n:]
return
}
if r.discovered {
n, err = r.source.Read(p)
return
}

_, err := s.Read(data)
r.buf = make([]byte, len(p))
n, err = r.source.Read(r.buf)
if err != nil {
return nil, err
return
}

ht := Tag(binary.BigEndian.Uint16(r.buf[0:]))
// source has no header
if ht != magicn {
r.discovered = true
copy(p, r.buf)
return
}

header := &TLV{
Tag: Tag(binary.BigEndian.Uint16(data[0:])),
Length: binary.BigEndian.Uint32(data[binary.MaxVarintLen16:]),
Tag: ht,
Length: binary.BigEndian.Uint32(r.buf[binary.MaxVarintLen16:]),
}
header.Value = data[binary.MaxVarintLen16+binary.MaxVarintLen32 : binary.MaxVarintLen16+binary.MaxVarintLen32+header.Length]

header.Value = r.buf[binary.MaxVarintLen16+binary.MaxVarintLen32 : binary.MaxVarintLen16+binary.MaxVarintLen32+header.Length]
elements := int(big.NewInt(0).SetBytes(header.Value[:header.Length-1]).Int64())

// Now decode all the TLVs
var (
tlvs []*TLV
offset = binary.MaxVarintLen16 + binary.MaxVarintLen32 + int(header.Length)
)
r.offset = binary.MaxVarintLen16 + binary.MaxVarintLen32 + int(header.Length)

for i := 0; i < elements; i++ {
t := &TLV{}
t.Tag = Tag(binary.BigEndian.Uint16(data[offset:]))
offset += binary.MaxVarintLen16
t.Length = binary.BigEndian.Uint32(data[offset:])
offset += binary.MaxVarintLen32
if offset+int(t.Length) > len(data) {
t.Tag = Tag(binary.BigEndian.Uint16(r.buf[r.offset:]))
r.offset += binary.MaxVarintLen16
t.Length = binary.BigEndian.Uint32(r.buf[r.offset:])
r.offset += binary.MaxVarintLen32
if r.offset+int(t.Length) > len(r.buf) {
break
}
t.Value = data[offset : offset+int(t.Length)]
offset += int(t.Length)
tlvs = append(tlvs, t)
t.Value = r.buf[r.offset : r.offset+int(t.Length)]
r.offset += int(t.Length)
r.header = append(r.header, t)
}

copy(p, r.buf[r.offset:])
r.buf = nil
n -= r.offset

r.discovered = true
return
}

func (r *Reader) Header() ([]*TLV, error) {
if r.discovered {
return r.header, nil
}

// Seek past our TLVs and at the beginning of our payload
if _, err := s.Seek(int64(offset), io.SeekStart); err != nil {
b := buf.Get(128)
defer buf.Put(b)

_, err := r.Read(b)
if err != nil {
return nil, err
}

return tlvs, nil
r.buf = bytes.Trim(b, "\x00")

return r.header, nil
}
56 changes: 54 additions & 2 deletions pkg/tlv/tlv_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package tlv_test

import (
"bytes"
"io"
"os"
"testing"
Expand Down Expand Up @@ -30,12 +31,63 @@ func TestTLV(t *testing.T) {
require.NoError(t, err)
defer tf.Close()

tlvs, err := tlv.Decode(tf)
r := tlv.NewReader(tf)
tlvs, err := r.Header()
require.NoError(t, err)
require.Equal(t, 1, len(tlvs))
require.Equal(t, v, string(tlvs[0].Value))

data, err := io.ReadAll(tf)
data, err := io.ReadAll(r)
require.NoError(t, err)
require.Equal(t, randomBytes, data)
}

func TestReader(t *testing.T) {
tests := []struct {
Name string
HeaderLen int
SkipHeader bool
}{
{
Name: "single header entry",
HeaderLen: 1,
},
{
Name: "no header",
},
{
Name: "Invoke Read without first invoking Header even though TLVs exist",
HeaderLen: 2,
SkipHeader: true,
},
}
for _, tt := range tests {
t.Run(tt.Name, func(t *testing.T) {
b := []byte(tt.Name)

if tt.HeaderLen != 0 {
var tlvs []*tlv.TLV
for i := 0; i < tt.HeaderLen; i++ {
tlvs = append(tlvs, tlv.New(tlv.Tag(i), []byte(tt.Name)))
}
b = append(tlv.Encode(tlvs...), b...)
}

source := bytes.NewBuffer(b)
r := tlv.NewReader(source)

if !tt.SkipHeader {
h, err := r.Header()
require.NoError(t, err)
require.Equal(t, tt.HeaderLen, len(h))
for _, metadata := range h {
require.Equal(t, tt.Name, string(metadata.Value))
}
}

have, err := io.ReadAll(r)
require.NoError(t, err)
require.Equal(t, tt.Name, string(have))
})
}
}

0 comments on commit b8bf573

Please sign in to comment.