diff --git a/protocol/binary/stream_writer.go b/protocol/binary/stream_writer.go
new file mode 100644
index 00000000..76785e64
--- /dev/null
+++ b/protocol/binary/stream_writer.go
@@ -0,0 +1,250 @@
+// Copyright (c) 2021 Uber Technologies, Inc.
+//
+// Permission is hereby granted, free of charge, to any person obtaining a copy
+// of this software and associated documentation files (the "Software"), to deal
+// in the Software without restriction, including without limitation the rights
+// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+// copies of the Software, and to permit persons to whom the Software is
+// furnished to do so, subject to the following conditions:
+//
+// The above copyright notice and this permission notice shall be included in
+// all copies or substantial portions of the Software.
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+// THE SOFTWARE.
+
+package binary
+
+import (
+ "io"
+ "math"
+ "sync"
+
+ "go.uber.org/thriftrw/internal/iface"
+ "go.uber.org/thriftrw/protocol/stream"
+)
+
+var streamWriterPool = sync.Pool{New: func() interface{} {
+ writer := &StreamWriter{}
+ return writer
+}}
+
+// StreamWriter implements basic logic for writing the Thrift Binary Protocol
+// to an io.Writer.
+type StreamWriter struct {
+ // Private implementation to disallow custom implementations of
+ // the Writer interface
+ iface.Impl
+
+ writer io.Writer
+
+ // This buffer is re-used every time we need a slice of up to 8 bytes.
+ buffer [8]byte
+}
+
+// BorrowStreamWriter fetches a Writer from the system that will write its
+// output to the given io.Writer.
+//
+// This StreamWriter must be returned back using ReturnStreamWriter.
+func BorrowStreamWriter(w io.Writer) *StreamWriter {
+ writer := streamWriterPool.Get().(*StreamWriter)
+ writer.writer = w
+ return writer
+}
+
+// ReturnStreamWriter returns a previously borrowed StreamWriter back to the
+// system.
+func ReturnStreamWriter(w *StreamWriter) {
+ w.writer = nil
+ streamWriterPool.Put(w)
+}
+
+func (bw *StreamWriter) write(bs []byte) error {
+ _, err := bw.writer.Write(bs)
+ return err
+}
+
+func (bw *StreamWriter) writeByte(b byte) error {
+ bs := bw.buffer[0:1]
+ bs[0] = b
+ return bw.write(bs)
+}
+
+func (bw *StreamWriter) writeInt16(n int16) error {
+ bs := bw.buffer[0:2]
+ bigEndian.PutUint16(bs, uint16(n))
+ return bw.write(bs)
+}
+
+func (bw *StreamWriter) writeInt32(n int32) error {
+ bs := bw.buffer[0:4]
+ bigEndian.PutUint32(bs, uint32(n))
+ return bw.write(bs)
+}
+
+func (bw *StreamWriter) writeInt64(n int64) error {
+ bs := bw.buffer[0:8]
+ bigEndian.PutUint64(bs, uint64(n))
+ return bw.write(bs)
+}
+
+func (bw *StreamWriter) writeString(s string) error {
+ if err := bw.writeInt32(int32(len(s))); err != nil {
+ return err
+ }
+
+ _, err := io.WriteString(bw.writer, s)
+ return err
+}
+
+// WriteBool encodes a boolean
+func (bw *StreamWriter) WriteBool(b bool) error {
+ if b {
+ return bw.writeByte(1)
+ }
+ return bw.writeByte(0)
+}
+
+// WriteInt8 encodes an int8
+func (bw *StreamWriter) WriteInt8(i int8) error {
+ return bw.writeByte(byte(i))
+}
+
+// WriteInt16 encodes an int16
+func (bw *StreamWriter) WriteInt16(i int16) error {
+ return bw.writeInt16(i)
+}
+
+// WriteInt32 encodes an int32
+func (bw *StreamWriter) WriteInt32(i int32) error {
+ return bw.writeInt32(i)
+}
+
+// WriteInt64 encodes an int64
+func (bw *StreamWriter) WriteInt64(i int64) error {
+ return bw.writeInt64(i)
+}
+
+// WriteString encodes a string
+func (bw *StreamWriter) WriteString(s string) error {
+ return bw.writeString(s)
+}
+
+// WriteDouble encodes a double
+func (bw *StreamWriter) WriteDouble(d float64) error {
+ value := math.Float64bits(d)
+ return bw.writeInt64(int64(value))
+}
+
+// WriteBinary encodes binary
+func (bw *StreamWriter) WriteBinary(b []byte) error {
+ if err := bw.writeInt32(int32(len(b))); err != nil {
+ return err
+ }
+ return bw.write(b)
+}
+
+// WriteFieldBegin marks the beginning of a new field in a struct. The first
+// byte denotes the type and the next two bytes denote the field id.
+func (bw *StreamWriter) WriteFieldBegin(f stream.FieldHeader) error {
+ // type:1
+ if err := bw.writeByte(byte(f.Type)); err != nil {
+ return err
+ }
+
+ // id:2
+ if err := bw.writeInt16(f.ID); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// WriteFieldEnd denotes the end of a field. No-op.
+func (bw *StreamWriter) WriteFieldEnd() error {
+ return nil
+}
+
+// WriteStructBegin denotes the beginning of a struct. No-op.
+func (bw *StreamWriter) WriteStructBegin() error {
+ return nil
+}
+
+// WriteStructEnd uses the zero byte to mark the end of a struct.
+func (bw *StreamWriter) WriteStructEnd() error {
+ return bw.writeByte(0) // end struct
+}
+
+// WriteListBegin marks the beginning of a new list. The first byte denotes
+// the type of the items and the next four bytes denote the length of the list.
+func (bw *StreamWriter) WriteListBegin(l stream.ListHeader) error {
+ // vtype:1
+ if err := bw.writeByte(byte(l.Type)); err != nil {
+ return err
+ }
+
+ // length:4
+ if err := bw.writeInt32(int32(l.Length)); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// WriteListEnd marks the end of a list. No-op.
+func (bw *StreamWriter) WriteListEnd() error {
+ return nil
+}
+
+// WriteSetBegin marks the beginning of a new set. The first byte denotes
+// the type of the items and the next four bytes denote the length of the set.
+func (bw *StreamWriter) WriteSetBegin(s stream.SetHeader) error {
+ // vtype:1
+ if err := bw.writeByte(byte(s.Type)); err != nil {
+ return err
+ }
+
+ // length:4
+ if err := bw.writeInt32(int32(s.Length)); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// WriteSetEnd marks the end of a set. No-op.
+func (bw *StreamWriter) WriteSetEnd() error {
+ return nil
+}
+
+// WriteMapBegin marks the beginning of a new map. The first byte denotes
+// the type of the keys, the second byte denotes the type of the values,
+// and the next four bytes denote the length of the map.
+func (bw *StreamWriter) WriteMapBegin(m stream.MapHeader) error {
+ // ktype:1
+ if err := bw.writeByte(byte(m.KeyType)); err != nil {
+ return err
+ }
+
+ // vtype:1
+ if err := bw.writeByte(byte(m.ValueType)); err != nil {
+ return err
+ }
+
+ // length:4
+ if err := bw.writeInt32(int32(m.Length)); err != nil {
+ return err
+ }
+
+ return nil
+}
+
+// WriteMapEnd marks the end of a map. No-op.
+func (bw *StreamWriter) WriteMapEnd() error {
+ return nil
+}
diff --git a/protocol/binary_test.go b/protocol/binary_test.go
index a01cff5f..612e18fe 100644
--- a/protocol/binary_test.go
+++ b/protocol/binary_test.go
@@ -23,6 +23,7 @@ package protocol
import (
"bytes"
"fmt"
+ "go.uber.org/thriftrw/protocol/stream"
"io"
"math"
"reflect"
@@ -40,6 +41,12 @@ type encodeDecodeTest struct {
encoded []byte
}
+type streamEncodeDecodeTest struct {
+ description string
+ value testValue
+ encoded []byte
+}
+
func checkEncodeDecode(t *testing.T, typ wire.Type, tests []encodeDecodeTest) {
for _, tt := range tests {
buffer := bytes.Buffer{}
@@ -68,6 +75,18 @@ func checkEncodeDecode(t *testing.T, typ wire.Type, tests []encodeDecodeTest) {
}
}
+func checkStreamEncode(t *testing.T, val testValue, expected []byte) {
+ var buff bytes.Buffer
+
+ // Encode with Streaming protocol
+ w := binary.BorrowStreamWriter(&buff)
+ err := val.StreamEncode(w)
+ require.NoError(t, err)
+ binary.ReturnStreamWriter(w)
+
+ assert.Equal(t, expected, buff.Bytes())
+}
+
type failureTest []byte
func checkDecodeFailure(t *testing.T, typ wire.Type, tests []failureTest) {
@@ -132,6 +151,19 @@ func TestBoolEOFFailure(t *testing.T) {
checkEOFError(t, wire.TBool, tests)
}
+func TestStreamBool(t *testing.T) {
+ tests := []streamEncodeDecodeTest{
+ {"false", testvalue(vbool(false)), []byte{0x00}},
+ {"true", testvalue(vbool(true)), []byte{0x01}},
+ }
+
+ for _, test := range tests {
+ t.Run(test.description, func(tt *testing.T) {
+ checkStreamEncode(tt, test.value, test.encoded)
+ })
+ }
+}
+
func TestI8(t *testing.T) {
tests := []encodeDecodeTest{
{vi8(0), []byte{0x00}},
@@ -152,6 +184,22 @@ func TestI8EOFFailure(t *testing.T) {
checkEOFError(t, wire.TI8, tests)
}
+func TestStreamI8(t *testing.T) {
+ tests := []streamEncodeDecodeTest{
+ {"0", testvalue(vi8(0)), []byte{0x00}},
+ {"1", testvalue(vi8(1)), []byte{0x01}},
+ {"-1", testvalue(vi8(-1)), []byte{0xff}},
+ {"127", testvalue(vi8(127)), []byte{0x7f}},
+ {"-128", testvalue(vi8(-128)), []byte{0x80}},
+ }
+
+ for _, test := range tests {
+ t.Run(test.description, func(tt *testing.T) {
+ checkStreamEncode(tt, test.value, test.encoded)
+ })
+ }
+}
+
func TestI16(t *testing.T) {
tests := []encodeDecodeTest{
{vi16(1), []byte{0x00, 0x01}},
@@ -178,6 +226,27 @@ func TestI16EOFFailure(t *testing.T) {
checkEOFError(t, wire.TI16, tests)
}
+func TestStreamI16(t *testing.T) {
+ tests := []streamEncodeDecodeTest{
+ {"1", testvalue(vi16(1)), []byte{0x00, 0x01}},
+ {"255", testvalue(vi16(255)), []byte{0x00, 0xff}},
+ {"256", testvalue(vi16(256)), []byte{0x01, 0x00}},
+ {"257", testvalue(vi16(257)), []byte{0x01, 0x01}},
+ {"32767", testvalue(vi16(32767)), []byte{0x7f, 0xff}},
+ {"-1", testvalue(vi16(-1)), []byte{0xff, 0xff}},
+ {"-2", testvalue(vi16(-2)), []byte{0xff, 0xfe}},
+ {"-256", testvalue(vi16(-256)), []byte{0xff, 0x00}},
+ {"-255", testvalue(vi16(-255)), []byte{0xff, 0x01}},
+ {"-32768", testvalue(vi16(-32768)), []byte{0x80, 0x00}},
+ }
+
+ for _, test := range tests {
+ t.Run(test.description, func(tt *testing.T) {
+ checkStreamEncode(tt, test.value, test.encoded)
+ })
+ }
+}
+
func TestI32(t *testing.T) {
tests := []encodeDecodeTest{
{vi32(1), []byte{0x00, 0x00, 0x00, 0x01}},
@@ -204,6 +273,27 @@ func TestI32EOFFailure(t *testing.T) {
checkEOFError(t, wire.TI32, tests)
}
+func TestStreamI32(t *testing.T) {
+ tests := []streamEncodeDecodeTest{
+ {"1", testvalue(vi32(1)), []byte{0x00, 0x00, 0x00, 0x01}},
+ {"255", testvalue(vi32(255)), []byte{0x00, 0x00, 0x00, 0xff}},
+ {"65535", testvalue(vi32(65535)), []byte{0x00, 0x00, 0xff, 0xff}},
+ {"16777215", testvalue(vi32(16777215)), []byte{0x00, 0xff, 0xff, 0xff}},
+ {"2147483647", testvalue(vi32(2147483647)), []byte{0x7f, 0xff, 0xff, 0xff}},
+ {"-1", testvalue(vi32(-1)), []byte{0xff, 0xff, 0xff, 0xff}},
+ {"-256", testvalue(vi32(-256)), []byte{0xff, 0xff, 0xff, 0x00}},
+ {"-65536", testvalue(vi32(-65536)), []byte{0xff, 0xff, 0x00, 0x00}},
+ {"-16777216", testvalue(vi32(-16777216)), []byte{0xff, 0x00, 0x00, 0x00}},
+ {"-2147483648", testvalue(vi32(-2147483648)), []byte{0x80, 0x00, 0x00, 0x00}},
+ }
+
+ for _, test := range tests {
+ t.Run(test.description, func(tt *testing.T) {
+ checkStreamEncode(tt, test.value, test.encoded)
+ })
+ }
+}
+
func TestI64(t *testing.T) {
tests := []encodeDecodeTest{
{vi64(1), []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}},
@@ -232,6 +322,29 @@ func TestI64EOFFailure(t *testing.T) {
checkEOFError(t, wire.TI64, tests)
}
+func TestStreamI64(t *testing.T) {
+ tests := []streamEncodeDecodeTest{
+ {"1", testvalue(vi64(1)), []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}},
+ {"4294967295", testvalue(vi64(4294967295)), []byte{0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff}},
+ {"1099511627775", testvalue(vi64(1099511627775)), []byte{0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff}},
+ {"281474976710655", testvalue(vi64(281474976710655)), []byte{0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}},
+ {"72057594037927935", testvalue(vi64(72057594037927935)), []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}},
+ {"9223372036854775807", testvalue(vi64(9223372036854775807)), []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}},
+ {"-1", testvalue(vi64(-1)), []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}},
+ {"-4294967296", testvalue(vi64(-4294967296)), []byte{0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00}},
+ {"-1099511627776", testvalue(vi64(-1099511627776)), []byte{0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00}},
+ {"-281474976710656", testvalue(vi64(-281474976710656)), []byte{0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}},
+ {"-72057594037927936", testvalue(vi64(-72057594037927936)), []byte{0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}},
+ {"-9223372036854775808", testvalue(vi64(-9223372036854775808)), []byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}},
+ }
+
+ for _, test := range tests {
+ t.Run(test.description, func(tt *testing.T) {
+ checkStreamEncode(tt, test.value, test.encoded)
+ })
+ }
+}
+
func TestDouble(t *testing.T) {
tests := []encodeDecodeTest{
{vdouble(0.0), []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}},
@@ -257,6 +370,26 @@ func TestDoubleEOFFailure(t *testing.T) {
checkEOFError(t, wire.TDouble, tests)
}
+func TestStreamDouble(t *testing.T) {
+ tests := []streamEncodeDecodeTest{
+ {"0.0", testvalue(vdouble(0.0)), []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}},
+ {"1.0", testvalue(vdouble(1.0)), []byte{0x3f, 0xf0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}},
+ {"1.0000000001", testvalue(vdouble(1.0000000001)), []byte{0x3f, 0xf0, 0x0, 0x0, 0x0, 0x6, 0xdf, 0x38}},
+ {"1.1", testvalue(vdouble(1.1)), []byte{0x3f, 0xf1, 0x99, 0x99, 0x99, 0x99, 0x99, 0x9a}},
+ {"-1.1", testvalue(vdouble(-1.1)), []byte{0xbf, 0xf1, 0x99, 0x99, 0x99, 0x99, 0x99, 0x9a}},
+ {"3.141592653589793", testvalue(vdouble(3.141592653589793)), []byte{0x40, 0x9, 0x21, 0xfb, 0x54, 0x44, 0x2d, 0x18}},
+ {"-1.0000000001", testvalue(vdouble(-1.0000000001)), []byte{0xbf, 0xf0, 0x0, 0x0, 0x0, 0x6, 0xdf, 0x38}},
+ {"0", testvalue(vdouble(math.Inf(0))), []byte{0x7f, 0xf0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}},
+ {"-1", testvalue(vdouble(math.Inf(-1))), []byte{0xff, 0xf0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}},
+ }
+
+ for _, test := range tests {
+ t.Run(test.description, func(tt *testing.T) {
+ checkStreamEncode(tt, test.value, test.encoded)
+ })
+ }
+}
+
func TestDoubleNaN(t *testing.T) {
value := vdouble(math.NaN())
encoded := []byte{0x7f, 0xf8, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1}
@@ -274,6 +407,13 @@ func TestDoubleNaN(t *testing.T) {
}
}
+func TestStreamDoubleNaN(t *testing.T) {
+ value := testvalue(vdouble(math.NaN()))
+ encoded := []byte{0x7f, 0xf8, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1}
+
+ checkStreamEncode(t, value, encoded)
+}
+
func TestBinary(t *testing.T) {
tests := []encodeDecodeTest{
{vbinary(""), []byte{0x00, 0x00, 0x00, 0x00}},
@@ -286,6 +426,22 @@ func TestBinary(t *testing.T) {
checkEncodeDecode(t, wire.TBinary, tests)
}
+func TestStreamBinary(t *testing.T) {
+ tests := []streamEncodeDecodeTest{
+ {"empty string", testvalue(vbinary("")), []byte{0x00, 0x00, 0x00, 0x00}},
+ {"hello ", testvalue(vbinary("hello")), []byte{
+ 0x00, 0x00, 0x00, 0x05, // len:4 = 5
+ 0x68, 0x65, 0x6c, 0x6c, 0x6f, // 'h', 'e', 'l', 'l', 'o'
+ }},
+ }
+
+ for _, test := range tests {
+ t.Run(test.description, func(tt *testing.T) {
+ checkStreamEncode(tt, test.value, test.encoded)
+ })
+ }
+}
+
func TestBinaryLargeLength(t *testing.T) {
// 5 MB + 4 bytes for length
data := make([]byte, 5242880+4)
@@ -376,6 +532,86 @@ func TestStruct(t *testing.T) {
checkEncodeDecode(t, wire.TStruct, tests)
}
+func TestStructBeginAndEndEncode(t *testing.T) {
+ var streamBuff bytes.Buffer
+ var err error
+
+ // Encode with Streaming protocol
+ w := binary.BorrowStreamWriter(&streamBuff)
+ err = w.WriteStructBegin()
+ assert.NoError(t, err)
+ err = w.WriteStructEnd()
+ assert.NoError(t, err)
+ binary.ReturnStreamWriter(w)
+
+ // Assert that encoded bytes are equivalent
+ assert.Equal(t, []byte{0x0}, streamBuff.Bytes())
+}
+
+func TestStreamStruct(t *testing.T) {
+ tests := []streamEncodeDecodeTest{
+ {"empty struct", testvalue(vstruct()), []byte{0x00}},
+ {"simple struct", testvalue(vstruct(vfield(1, vbool(true)))), []byte{
+ 0x02, // type:1 = bool
+ 0x00, 0x01, // id:2 = 1
+ 0x01, // value = true
+ 0x00, // stop
+ }},
+ {
+ "complex struct",
+ testvalue(vstruct(
+ vfield(1, vi16(42)),
+ vfield(2, vlist(wire.TBinary, vbinary("foo"), vbinary("bar"))),
+ vfield(3, vset(wire.TBinary, vbinary("baz"), vbinary("qux"))),
+ )), []byte{
+ 0x06, // type:1 = i16
+ 0x00, 0x01, // id:2 = 1
+ 0x00, 0x2a, // value = 42
+
+ 0x0F, // type:1 = list
+ 0x00, 0x02, // id:2 = 2
+
+ //
+ 0x0B, // type:1 = binary
+ 0x00, 0x00, 0x00, 0x02, // size:4 = 2
+ //
+ 0x00, 0x00, 0x00, 0x03, // len:4 = 3
+ 0x66, 0x6f, 0x6f, // 'f', 'o', 'o'
+ //
+ //
+ 0x00, 0x00, 0x00, 0x03, // len:4 = 3
+ 0x62, 0x61, 0x72, // 'b', 'a', 'r'
+ //
+ //
+
+ 0x0E, // type = set
+ 0x00, 0x03, // id = 3
+
+ //
+ 0x0B, // type:1 = binary
+ 0x00, 0x00, 0x00, 0x02, // size:4 = 2
+ //
+ 0x00, 0x00, 0x00, 0x03, // len:4 = 3
+ 0x62, 0x61, 0x7a, // 'b', 'a', 'z'
+ //
+ //
+ 0x00, 0x00, 0x00, 0x03, // len:4 = 3
+ 0x71, 0x75, 0x78, // 'q', 'u', 'x'
+ //
+ //
+
+ 0x00, // stop
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.description, func(tt *testing.T) {
+ checkStreamEncode(tt, test.value, test.encoded)
+ })
+ }
+}
+
func TestStructEOFFailure(t *testing.T) {
tests := []failureTest{
{},
@@ -430,6 +666,73 @@ func TestMap(t *testing.T) {
checkEncodeDecode(t, wire.TMap, tests)
}
+func TestMapBeginEncode(t *testing.T) {
+ var streamBuff bytes.Buffer
+ var err error
+
+ // Encode with Streaming protocol
+ w := binary.BorrowStreamWriter(&streamBuff)
+ err = w.WriteMapBegin(stream.MapHeader{
+ KeyType: wire.TBinary,
+ ValueType: wire.TBool,
+ Length: 1,
+ })
+ assert.NoError(t, err)
+ binary.ReturnStreamWriter(w)
+
+ // Assert that encoded bytes are equivalent
+ assert.Equal(t, []byte{0xb, 0x2, 0x0, 0x0, 0x0, 0x1}, streamBuff.Bytes())
+}
+
+func TestStreamMap(t *testing.T) {
+ tests := []streamEncodeDecodeTest{
+ {"small map", testvalue(vmap(wire.TI64, wire.TBinary)), []byte{0x0A, 0x0B, 0x00, 0x00, 0x00, 0x00}},
+ {
+ "complex map",
+ testvalue(vmap(
+ wire.TBinary, wire.TList,
+ vitem(vbinary("a"), vlist(wire.TI16, vi16(1))),
+ vitem(vbinary("b"), vlist(wire.TI16, vi16(2), vi16(3))),
+ )), []byte{
+ 0x0B, // ktype = binary
+ 0x0F, // vtype = list
+ 0x00, 0x00, 0x00, 0x02, // count:4 = 2
+
+ // -
+ //
+ 0x00, 0x00, 0x00, 0x01, // len:4 = 1
+ 0x61, // 'a'
+ //
+ //
+ 0x06, // type:1 = i16
+ 0x00, 0x00, 0x00, 0x01, // count:4 = 1
+ 0x00, 0x01, // 1
+ //
+ //
+
+ // -
+ //
+ 0x00, 0x00, 0x00, 0x01, // len:4 = 1
+ 0x62, // 'b'
+ //
+ //
+ 0x06, // type:1 = i16
+ 0x00, 0x00, 0x00, 0x02, // count:4 = 2
+ 0x00, 0x02, // 2
+ 0x00, 0x03, // 3
+ //
+ //
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.description, func(tt *testing.T) {
+ checkStreamEncode(tt, test.value, test.encoded)
+ })
+ }
+}
+
func TestMapDecodeFailure(t *testing.T) {
tests := []failureTest{
{
@@ -462,6 +765,23 @@ func TestSet(t *testing.T) {
checkEncodeDecode(t, wire.TSet, tests)
}
+func TestSetBeginEncode(t *testing.T) {
+ var streamBuff bytes.Buffer
+ var err error
+
+ // Encode with Streaming protocol
+ w := binary.BorrowStreamWriter(&streamBuff)
+ err = w.WriteSetBegin(stream.SetHeader{
+ Type: wire.TList,
+ Length: 1,
+ })
+ assert.NoError(t, err)
+ binary.ReturnStreamWriter(w)
+
+ // Assert that encoded bytes are equivalent
+ assert.Equal(t, []byte{0xf, 0x0, 0x0, 0x0, 0x1}, streamBuff.Bytes())
+}
+
func TestSetDecodeFailure(t *testing.T) {
tests := []failureTest{
{
@@ -473,6 +793,22 @@ func TestSetDecodeFailure(t *testing.T) {
checkDecodeFailure(t, wire.TSet, tests)
}
+func TestStreamSet(t *testing.T) {
+ tests := []streamEncodeDecodeTest{
+ {"small set", testvalue(vset(wire.TBool)), []byte{0x02, 0x00, 0x00, 0x00, 0x00}},
+ {
+ "large set", testvalue(vset(wire.TBool, vbool(true), vbool(false), vbool(true))),
+ []byte{0x02, 0x00, 0x00, 0x00, 0x03, 0x01, 0x00, 0x01},
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.description, func(tt *testing.T) {
+ checkStreamEncode(tt, test.value, test.encoded)
+ })
+ }
+}
+
func TestSetEOFFailure(t *testing.T) {
tests := []failureTest{
{}, // empty
@@ -531,6 +867,77 @@ func TestList(t *testing.T) {
checkEncodeDecode(t, wire.TList, tests)
}
+func TestListBeginEncode(t *testing.T) {
+ var streamBuff bytes.Buffer
+ var err error
+
+ // Encode with Streaming protocol
+ w := binary.BorrowStreamWriter(&streamBuff)
+ err = w.WriteListBegin(stream.ListHeader{
+ Type: wire.TMap,
+ Length: 5,
+ })
+ assert.NoError(t, err)
+ binary.ReturnStreamWriter(w)
+
+ // Assert that encoded bytes are equivalent
+ assert.Equal(t, []byte{0xd, 0x0, 0x0, 0x0, 0x5}, streamBuff.Bytes())
+}
+
+func TestStreamList(t *testing.T) {
+ tests := []streamEncodeDecodeTest{
+ {"small list", testvalue(vlist(wire.TStruct)), []byte{0x0C, 0x00, 0x00, 0x00, 0x00}},
+ {
+ "large list",
+ testvalue(vlist(
+ wire.TStruct,
+ vstruct(
+ vfield(1, vi16(1)),
+ vfield(2, vi32(2)),
+ ),
+ vstruct(
+ vfield(1, vi16(3)),
+ vfield(2, vi32(4)),
+ ),
+ )),
+ []byte{
+ 0x0C, // vtype:1 = struct
+ 0x00, 0x00, 0x00, 0x02, // count:4 = 2
+
+ //
+ 0x06, // type:1 = i16
+ 0x00, 0x01, // id:2 = 1
+ 0x00, 0x01, // value = 1
+
+ 0x08, // type:1 = i32
+ 0x00, 0x02, // id:2 = 2
+ 0x00, 0x00, 0x00, 0x02, // value = 2
+
+ 0x00, // stop
+ //
+
+ //
+ 0x06, // type:1 = i16
+ 0x00, 0x01, // id:2 = 1
+ 0x00, 0x03, // value = 3
+
+ 0x08, // type:1 = i32
+ 0x00, 0x02, // id:2 = 2
+ 0x00, 0x00, 0x00, 0x04, // value = 4
+
+ 0x00, // stop
+ //
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.description, func(tt *testing.T) {
+ checkStreamEncode(tt, test.value, test.encoded)
+ })
+ }
+}
+
func TestListDecodeFailure(t *testing.T) {
tests := []failureTest{
{
@@ -642,6 +1049,97 @@ func TestStructOfContainers(t *testing.T) {
checkEncodeDecode(t, wire.TStruct, tests)
}
+func TestStreamStructOfContainers(t *testing.T) {
+ tests := []streamEncodeDecodeTest{
+ {
+ "struct of containers",
+ testvalue(vstruct(
+ vfield(1, vlist(
+ wire.TMap,
+ vmap(
+ wire.TI32, wire.TSet,
+ vitem(vi32(1), vset(
+ wire.TBinary,
+ vbinary("a"), vbinary("b"), vbinary("c"),
+ )),
+ vitem(vi32(2), vset(wire.TBinary)),
+ vitem(vi32(3), vset(
+ wire.TBinary,
+ vbinary("d"), vbinary("e"), vbinary("f"),
+ )),
+ ),
+ vmap(
+ wire.TI32, wire.TSet,
+ vitem(vi32(4), vset(wire.TBinary, vbinary("g"))),
+ ),
+ )),
+ vfield(2, vlist(wire.TI16, vi16(1), vi16(2), vi16(3))),
+ )),
+ []byte{
+ 0x0f, // type:list
+ 0x00, 0x01, // field ID 1
+
+ 0x0d, // type: map
+ 0x00, 0x00, 0x00, 0x02, // length: 2
+
+ //
+ 0x08, 0x0e, // ktype: i32, vtype: set
+ 0x00, 0x00, 0x00, 0x03, // length: 3
+
+ // 1: {"a", "b", "c"}
+ 0x00, 0x00, 0x00, 0x01, // 1
+ 0x0B, // type: binary
+ 0x00, 0x00, 0x00, 0x03, // length: 3
+ 0x00, 0x00, 0x00, 0x01, 0x61, // 'a'
+ 0x00, 0x00, 0x00, 0x01, 0x62, // 'b'
+ 0x00, 0x00, 0x00, 0x01, 0x63, // 'c'
+
+ // 2: {}
+ 0x00, 0x00, 0x00, 0x02, // 2
+ 0x0B, // type: binary
+ 0x00, 0x00, 0x00, 0x00, // length: 0
+
+ // 3: {"d", "e", "f"}
+ 0x00, 0x00, 0x00, 0x03, // 3
+ 0x0B, // type: binary
+ 0x00, 0x00, 0x00, 0x03, // length: 3
+ 0x00, 0x00, 0x00, 0x01, 0x64, // 'd'
+ 0x00, 0x00, 0x00, 0x01, 0x65, // 'e'
+ 0x00, 0x00, 0x00, 0x01, 0x66, // 'f'
+
+ //
+
+ //
+ 0x08, 0x0e, // ktype: i32, vtype: set
+ 0x00, 0x00, 0x00, 0x01, // length: 1
+
+ // 4: {"g"}
+ 0x00, 0x00, 0x00, 0x04, // 3
+ 0x0B, // type: binary
+ 0x00, 0x00, 0x00, 0x01, // length: 1
+ 0x00, 0x00, 0x00, 0x01, 0x67, // 'g'
+
+ //
+
+ 0x0f, // type: list
+ 0x00, 0x02, // field ID 2
+
+ 0x06, // type: i16
+ 0x00, 0x00, 0x00, 0x03, // length 3
+ 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, // [1,2,3]
+
+ 0x00,
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.description, func(tt *testing.T) {
+ checkStreamEncode(tt, test.value, test.encoded)
+ })
+ }
+}
+
func TestBinaryEnvelopeErrors(t *testing.T) {
tests := []struct {
encoded []byte
diff --git a/protocol/value_test.go b/protocol/value_test.go
index bbf9d33d..08477336 100644
--- a/protocol/value_test.go
+++ b/protocol/value_test.go
@@ -20,7 +20,12 @@
package protocol
-import "go.uber.org/thriftrw/wire"
+import (
+ "fmt"
+ "go.uber.org/thriftrw/protocol/binary"
+ "go.uber.org/thriftrw/protocol/stream"
+ "go.uber.org/thriftrw/wire"
+)
// This file doesn't actually contain any tests. It just contains helpers for
// constructing complex Value objects during protocol.
@@ -76,3 +81,166 @@ func vmap(kt, vt wire.Type, items ...wire.MapItem) wire.Value {
func vitem(k, v wire.Value) wire.MapItem {
return wire.MapItem{Key: k, Value: v}
}
+
+// testValue wraps around wire.Value to allow for stream encoding/decoding
+// testing.
+type testValue struct {
+ val wire.Value
+}
+
+func testvalue(val wire.Value) testValue {
+ return testValue{val: val}
+}
+
+func (v testValue) streamEncodeStruct(sw *binary.StreamWriter) error {
+ if v.val.Type() != wire.TStruct {
+ panic(fmt.Sprintf("Cannot call streamEncodeStruct on non-struct type"))
+ }
+
+ if err := sw.WriteStructBegin(); err != nil {
+ return err
+ }
+
+ for _, field := range v.val.GetStruct().Fields {
+ fm := stream.FieldHeader{
+ ID: field.ID,
+ Type: field.Value.Type(),
+ }
+ if err := sw.WriteFieldBegin(fm); err != nil {
+ return err
+ }
+
+ tv := testValue{val: field.Value}
+ if err := tv.StreamEncode(sw); err != nil {
+ return err
+ }
+
+ if err := sw.WriteFieldEnd(); err != nil {
+ return err
+ }
+ }
+
+ return sw.WriteStructEnd()
+}
+
+func (v testValue) streamEncodeMap(sw *binary.StreamWriter) error {
+ if v.val.Type() != wire.TMap {
+ panic(fmt.Sprintf("Cannot call streamEncodeMap on non-map type"))
+ }
+
+ mapItemList := v.val.GetMap()
+
+ mh := stream.MapHeader{
+ KeyType: mapItemList.KeyType(),
+ ValueType: mapItemList.ValueType(),
+ Length: mapItemList.Size(),
+ }
+ if err := sw.WriteMapBegin(mh); err != nil {
+ return err
+ }
+
+ err := mapItemList.ForEach(func(item wire.MapItem) error {
+ key := testValue{val: item.Key}
+ if err := key.StreamEncode(sw); err != nil {
+ return err
+ }
+
+ value := testValue{val: item.Value}
+ if err := value.StreamEncode(sw); err != nil {
+ return err
+ }
+ return nil
+ })
+ if err != nil {
+ return err
+ }
+
+ return sw.WriteMapEnd()
+}
+
+func (v testValue) streamEncodeSet(sw *binary.StreamWriter) error {
+ if v.val.Type() != wire.TSet {
+ panic(fmt.Sprintf("Cannot call streamEncodeSet on non-set type"))
+ }
+
+ valueList := v.val.GetSet()
+
+ sh := stream.SetHeader{
+ Length: valueList.Size(),
+ Type: valueList.ValueType(),
+ }
+ if err := sw.WriteSetBegin(sh); err != nil {
+ return err
+ }
+
+ err := valueList.ForEach(func(value wire.Value) error {
+ val := testValue{val: value}
+ if err := val.StreamEncode(sw); err != nil {
+ return err
+ }
+ return nil
+ })
+ if err != nil {
+ return err
+ }
+
+ return sw.WriteSetEnd()
+}
+
+func (v testValue) streamEncodeList(sw *binary.StreamWriter) error {
+ if v.val.Type() != wire.TList {
+ panic(fmt.Sprintf("Cannot call streamEncodeList on non-list type"))
+ }
+
+ valueList := v.val.GetList()
+
+ lh := stream.ListHeader{
+ Length: valueList.Size(),
+ Type: valueList.ValueType(),
+ }
+ if err := sw.WriteListBegin(lh); err != nil {
+ return err
+ }
+
+ err := valueList.ForEach(func(value wire.Value) error {
+ val := testValue{val: value}
+ if err := val.StreamEncode(sw); err != nil {
+ return err
+ }
+ return nil
+ })
+ if err != nil {
+ return err
+ }
+
+ return sw.WriteListEnd()
+}
+
+func (v testValue) StreamEncode(sw *binary.StreamWriter) error {
+ switch v.val.Type() {
+ case wire.TBool:
+ return sw.WriteBool(v.val.GetBool())
+ case wire.TI8:
+ return sw.WriteInt8(v.val.GetI8())
+ case wire.TDouble:
+ return sw.WriteDouble(v.val.GetDouble())
+ case wire.TI16:
+ return sw.WriteInt16(v.val.GetI16())
+ case wire.TI32:
+ return sw.WriteInt32(v.val.GetI32())
+ case wire.TI64:
+ return sw.WriteInt64(v.val.GetI64())
+ case wire.TBinary:
+ return sw.WriteBinary(v.val.GetBinary())
+ case wire.TStruct:
+ return v.streamEncodeStruct(sw)
+ case wire.TMap:
+ return v.streamEncodeMap(sw)
+ case wire.TSet:
+ return v.streamEncodeSet(sw)
+ case wire.TList:
+ return v.streamEncodeList(sw)
+ default:
+ panic(fmt.Sprintf("Unknown value type %v", v.val.Type()))
+ }
+}