diff --git a/protocol/binary/stream_writer.go b/protocol/binary/stream_writer.go new file mode 100644 index 00000000..76785e64 --- /dev/null +++ b/protocol/binary/stream_writer.go @@ -0,0 +1,250 @@ +// Copyright (c) 2021 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package binary + +import ( + "io" + "math" + "sync" + + "go.uber.org/thriftrw/internal/iface" + "go.uber.org/thriftrw/protocol/stream" +) + +var streamWriterPool = sync.Pool{New: func() interface{} { + writer := &StreamWriter{} + return writer +}} + +// StreamWriter implements basic logic for writing the Thrift Binary Protocol +// to an io.Writer. +type StreamWriter struct { + // Private implementation to disallow custom implementations of + // the Writer interface + iface.Impl + + writer io.Writer + + // This buffer is re-used every time we need a slice of up to 8 bytes. + buffer [8]byte +} + +// BorrowStreamWriter fetches a Writer from the system that will write its +// output to the given io.Writer. +// +// This StreamWriter must be returned back using ReturnStreamWriter. +func BorrowStreamWriter(w io.Writer) *StreamWriter { + writer := streamWriterPool.Get().(*StreamWriter) + writer.writer = w + return writer +} + +// ReturnStreamWriter returns a previously borrowed StreamWriter back to the +// system. +func ReturnStreamWriter(w *StreamWriter) { + w.writer = nil + streamWriterPool.Put(w) +} + +func (bw *StreamWriter) write(bs []byte) error { + _, err := bw.writer.Write(bs) + return err +} + +func (bw *StreamWriter) writeByte(b byte) error { + bs := bw.buffer[0:1] + bs[0] = b + return bw.write(bs) +} + +func (bw *StreamWriter) writeInt16(n int16) error { + bs := bw.buffer[0:2] + bigEndian.PutUint16(bs, uint16(n)) + return bw.write(bs) +} + +func (bw *StreamWriter) writeInt32(n int32) error { + bs := bw.buffer[0:4] + bigEndian.PutUint32(bs, uint32(n)) + return bw.write(bs) +} + +func (bw *StreamWriter) writeInt64(n int64) error { + bs := bw.buffer[0:8] + bigEndian.PutUint64(bs, uint64(n)) + return bw.write(bs) +} + +func (bw *StreamWriter) writeString(s string) error { + if err := bw.writeInt32(int32(len(s))); err != nil { + return err + } + + _, err := io.WriteString(bw.writer, s) + return err +} + +// WriteBool encodes a boolean +func (bw *StreamWriter) WriteBool(b bool) error { + if b { + return bw.writeByte(1) + } + return bw.writeByte(0) +} + +// WriteInt8 encodes an int8 +func (bw *StreamWriter) WriteInt8(i int8) error { + return bw.writeByte(byte(i)) +} + +// WriteInt16 encodes an int16 +func (bw *StreamWriter) WriteInt16(i int16) error { + return bw.writeInt16(i) +} + +// WriteInt32 encodes an int32 +func (bw *StreamWriter) WriteInt32(i int32) error { + return bw.writeInt32(i) +} + +// WriteInt64 encodes an int64 +func (bw *StreamWriter) WriteInt64(i int64) error { + return bw.writeInt64(i) +} + +// WriteString encodes a string +func (bw *StreamWriter) WriteString(s string) error { + return bw.writeString(s) +} + +// WriteDouble encodes a double +func (bw *StreamWriter) WriteDouble(d float64) error { + value := math.Float64bits(d) + return bw.writeInt64(int64(value)) +} + +// WriteBinary encodes binary +func (bw *StreamWriter) WriteBinary(b []byte) error { + if err := bw.writeInt32(int32(len(b))); err != nil { + return err + } + return bw.write(b) +} + +// WriteFieldBegin marks the beginning of a new field in a struct. The first +// byte denotes the type and the next two bytes denote the field id. +func (bw *StreamWriter) WriteFieldBegin(f stream.FieldHeader) error { + // type:1 + if err := bw.writeByte(byte(f.Type)); err != nil { + return err + } + + // id:2 + if err := bw.writeInt16(f.ID); err != nil { + return err + } + + return nil +} + +// WriteFieldEnd denotes the end of a field. No-op. +func (bw *StreamWriter) WriteFieldEnd() error { + return nil +} + +// WriteStructBegin denotes the beginning of a struct. No-op. +func (bw *StreamWriter) WriteStructBegin() error { + return nil +} + +// WriteStructEnd uses the zero byte to mark the end of a struct. +func (bw *StreamWriter) WriteStructEnd() error { + return bw.writeByte(0) // end struct +} + +// WriteListBegin marks the beginning of a new list. The first byte denotes +// the type of the items and the next four bytes denote the length of the list. +func (bw *StreamWriter) WriteListBegin(l stream.ListHeader) error { + // vtype:1 + if err := bw.writeByte(byte(l.Type)); err != nil { + return err + } + + // length:4 + if err := bw.writeInt32(int32(l.Length)); err != nil { + return err + } + + return nil +} + +// WriteListEnd marks the end of a list. No-op. +func (bw *StreamWriter) WriteListEnd() error { + return nil +} + +// WriteSetBegin marks the beginning of a new set. The first byte denotes +// the type of the items and the next four bytes denote the length of the set. +func (bw *StreamWriter) WriteSetBegin(s stream.SetHeader) error { + // vtype:1 + if err := bw.writeByte(byte(s.Type)); err != nil { + return err + } + + // length:4 + if err := bw.writeInt32(int32(s.Length)); err != nil { + return err + } + + return nil +} + +// WriteSetEnd marks the end of a set. No-op. +func (bw *StreamWriter) WriteSetEnd() error { + return nil +} + +// WriteMapBegin marks the beginning of a new map. The first byte denotes +// the type of the keys, the second byte denotes the type of the values, +// and the next four bytes denote the length of the map. +func (bw *StreamWriter) WriteMapBegin(m stream.MapHeader) error { + // ktype:1 + if err := bw.writeByte(byte(m.KeyType)); err != nil { + return err + } + + // vtype:1 + if err := bw.writeByte(byte(m.ValueType)); err != nil { + return err + } + + // length:4 + if err := bw.writeInt32(int32(m.Length)); err != nil { + return err + } + + return nil +} + +// WriteMapEnd marks the end of a map. No-op. +func (bw *StreamWriter) WriteMapEnd() error { + return nil +} diff --git a/protocol/binary_test.go b/protocol/binary_test.go index a01cff5f..612e18fe 100644 --- a/protocol/binary_test.go +++ b/protocol/binary_test.go @@ -23,6 +23,7 @@ package protocol import ( "bytes" "fmt" + "go.uber.org/thriftrw/protocol/stream" "io" "math" "reflect" @@ -40,6 +41,12 @@ type encodeDecodeTest struct { encoded []byte } +type streamEncodeDecodeTest struct { + description string + value testValue + encoded []byte +} + func checkEncodeDecode(t *testing.T, typ wire.Type, tests []encodeDecodeTest) { for _, tt := range tests { buffer := bytes.Buffer{} @@ -68,6 +75,18 @@ func checkEncodeDecode(t *testing.T, typ wire.Type, tests []encodeDecodeTest) { } } +func checkStreamEncode(t *testing.T, val testValue, expected []byte) { + var buff bytes.Buffer + + // Encode with Streaming protocol + w := binary.BorrowStreamWriter(&buff) + err := val.StreamEncode(w) + require.NoError(t, err) + binary.ReturnStreamWriter(w) + + assert.Equal(t, expected, buff.Bytes()) +} + type failureTest []byte func checkDecodeFailure(t *testing.T, typ wire.Type, tests []failureTest) { @@ -132,6 +151,19 @@ func TestBoolEOFFailure(t *testing.T) { checkEOFError(t, wire.TBool, tests) } +func TestStreamBool(t *testing.T) { + tests := []streamEncodeDecodeTest{ + {"false", testvalue(vbool(false)), []byte{0x00}}, + {"true", testvalue(vbool(true)), []byte{0x01}}, + } + + for _, test := range tests { + t.Run(test.description, func(tt *testing.T) { + checkStreamEncode(tt, test.value, test.encoded) + }) + } +} + func TestI8(t *testing.T) { tests := []encodeDecodeTest{ {vi8(0), []byte{0x00}}, @@ -152,6 +184,22 @@ func TestI8EOFFailure(t *testing.T) { checkEOFError(t, wire.TI8, tests) } +func TestStreamI8(t *testing.T) { + tests := []streamEncodeDecodeTest{ + {"0", testvalue(vi8(0)), []byte{0x00}}, + {"1", testvalue(vi8(1)), []byte{0x01}}, + {"-1", testvalue(vi8(-1)), []byte{0xff}}, + {"127", testvalue(vi8(127)), []byte{0x7f}}, + {"-128", testvalue(vi8(-128)), []byte{0x80}}, + } + + for _, test := range tests { + t.Run(test.description, func(tt *testing.T) { + checkStreamEncode(tt, test.value, test.encoded) + }) + } +} + func TestI16(t *testing.T) { tests := []encodeDecodeTest{ {vi16(1), []byte{0x00, 0x01}}, @@ -178,6 +226,27 @@ func TestI16EOFFailure(t *testing.T) { checkEOFError(t, wire.TI16, tests) } +func TestStreamI16(t *testing.T) { + tests := []streamEncodeDecodeTest{ + {"1", testvalue(vi16(1)), []byte{0x00, 0x01}}, + {"255", testvalue(vi16(255)), []byte{0x00, 0xff}}, + {"256", testvalue(vi16(256)), []byte{0x01, 0x00}}, + {"257", testvalue(vi16(257)), []byte{0x01, 0x01}}, + {"32767", testvalue(vi16(32767)), []byte{0x7f, 0xff}}, + {"-1", testvalue(vi16(-1)), []byte{0xff, 0xff}}, + {"-2", testvalue(vi16(-2)), []byte{0xff, 0xfe}}, + {"-256", testvalue(vi16(-256)), []byte{0xff, 0x00}}, + {"-255", testvalue(vi16(-255)), []byte{0xff, 0x01}}, + {"-32768", testvalue(vi16(-32768)), []byte{0x80, 0x00}}, + } + + for _, test := range tests { + t.Run(test.description, func(tt *testing.T) { + checkStreamEncode(tt, test.value, test.encoded) + }) + } +} + func TestI32(t *testing.T) { tests := []encodeDecodeTest{ {vi32(1), []byte{0x00, 0x00, 0x00, 0x01}}, @@ -204,6 +273,27 @@ func TestI32EOFFailure(t *testing.T) { checkEOFError(t, wire.TI32, tests) } +func TestStreamI32(t *testing.T) { + tests := []streamEncodeDecodeTest{ + {"1", testvalue(vi32(1)), []byte{0x00, 0x00, 0x00, 0x01}}, + {"255", testvalue(vi32(255)), []byte{0x00, 0x00, 0x00, 0xff}}, + {"65535", testvalue(vi32(65535)), []byte{0x00, 0x00, 0xff, 0xff}}, + {"16777215", testvalue(vi32(16777215)), []byte{0x00, 0xff, 0xff, 0xff}}, + {"2147483647", testvalue(vi32(2147483647)), []byte{0x7f, 0xff, 0xff, 0xff}}, + {"-1", testvalue(vi32(-1)), []byte{0xff, 0xff, 0xff, 0xff}}, + {"-256", testvalue(vi32(-256)), []byte{0xff, 0xff, 0xff, 0x00}}, + {"-65536", testvalue(vi32(-65536)), []byte{0xff, 0xff, 0x00, 0x00}}, + {"-16777216", testvalue(vi32(-16777216)), []byte{0xff, 0x00, 0x00, 0x00}}, + {"-2147483648", testvalue(vi32(-2147483648)), []byte{0x80, 0x00, 0x00, 0x00}}, + } + + for _, test := range tests { + t.Run(test.description, func(tt *testing.T) { + checkStreamEncode(tt, test.value, test.encoded) + }) + } +} + func TestI64(t *testing.T) { tests := []encodeDecodeTest{ {vi64(1), []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}}, @@ -232,6 +322,29 @@ func TestI64EOFFailure(t *testing.T) { checkEOFError(t, wire.TI64, tests) } +func TestStreamI64(t *testing.T) { + tests := []streamEncodeDecodeTest{ + {"1", testvalue(vi64(1)), []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01}}, + {"4294967295", testvalue(vi64(4294967295)), []byte{0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff}}, + {"1099511627775", testvalue(vi64(1099511627775)), []byte{0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff}}, + {"281474976710655", testvalue(vi64(281474976710655)), []byte{0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}}, + {"72057594037927935", testvalue(vi64(72057594037927935)), []byte{0x00, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}}, + {"9223372036854775807", testvalue(vi64(9223372036854775807)), []byte{0x7f, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}}, + {"-1", testvalue(vi64(-1)), []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}}, + {"-4294967296", testvalue(vi64(-4294967296)), []byte{0xff, 0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00}}, + {"-1099511627776", testvalue(vi64(-1099511627776)), []byte{0xff, 0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00}}, + {"-281474976710656", testvalue(vi64(-281474976710656)), []byte{0xff, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}}, + {"-72057594037927936", testvalue(vi64(-72057594037927936)), []byte{0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}}, + {"-9223372036854775808", testvalue(vi64(-9223372036854775808)), []byte{0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}}, + } + + for _, test := range tests { + t.Run(test.description, func(tt *testing.T) { + checkStreamEncode(tt, test.value, test.encoded) + }) + } +} + func TestDouble(t *testing.T) { tests := []encodeDecodeTest{ {vdouble(0.0), []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}}, @@ -257,6 +370,26 @@ func TestDoubleEOFFailure(t *testing.T) { checkEOFError(t, wire.TDouble, tests) } +func TestStreamDouble(t *testing.T) { + tests := []streamEncodeDecodeTest{ + {"0.0", testvalue(vdouble(0.0)), []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}}, + {"1.0", testvalue(vdouble(1.0)), []byte{0x3f, 0xf0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}}, + {"1.0000000001", testvalue(vdouble(1.0000000001)), []byte{0x3f, 0xf0, 0x0, 0x0, 0x0, 0x6, 0xdf, 0x38}}, + {"1.1", testvalue(vdouble(1.1)), []byte{0x3f, 0xf1, 0x99, 0x99, 0x99, 0x99, 0x99, 0x9a}}, + {"-1.1", testvalue(vdouble(-1.1)), []byte{0xbf, 0xf1, 0x99, 0x99, 0x99, 0x99, 0x99, 0x9a}}, + {"3.141592653589793", testvalue(vdouble(3.141592653589793)), []byte{0x40, 0x9, 0x21, 0xfb, 0x54, 0x44, 0x2d, 0x18}}, + {"-1.0000000001", testvalue(vdouble(-1.0000000001)), []byte{0xbf, 0xf0, 0x0, 0x0, 0x0, 0x6, 0xdf, 0x38}}, + {"0", testvalue(vdouble(math.Inf(0))), []byte{0x7f, 0xf0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}}, + {"-1", testvalue(vdouble(math.Inf(-1))), []byte{0xff, 0xf0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}}, + } + + for _, test := range tests { + t.Run(test.description, func(tt *testing.T) { + checkStreamEncode(tt, test.value, test.encoded) + }) + } +} + func TestDoubleNaN(t *testing.T) { value := vdouble(math.NaN()) encoded := []byte{0x7f, 0xf8, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1} @@ -274,6 +407,13 @@ func TestDoubleNaN(t *testing.T) { } } +func TestStreamDoubleNaN(t *testing.T) { + value := testvalue(vdouble(math.NaN())) + encoded := []byte{0x7f, 0xf8, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1} + + checkStreamEncode(t, value, encoded) +} + func TestBinary(t *testing.T) { tests := []encodeDecodeTest{ {vbinary(""), []byte{0x00, 0x00, 0x00, 0x00}}, @@ -286,6 +426,22 @@ func TestBinary(t *testing.T) { checkEncodeDecode(t, wire.TBinary, tests) } +func TestStreamBinary(t *testing.T) { + tests := []streamEncodeDecodeTest{ + {"empty string", testvalue(vbinary("")), []byte{0x00, 0x00, 0x00, 0x00}}, + {"hello ", testvalue(vbinary("hello")), []byte{ + 0x00, 0x00, 0x00, 0x05, // len:4 = 5 + 0x68, 0x65, 0x6c, 0x6c, 0x6f, // 'h', 'e', 'l', 'l', 'o' + }}, + } + + for _, test := range tests { + t.Run(test.description, func(tt *testing.T) { + checkStreamEncode(tt, test.value, test.encoded) + }) + } +} + func TestBinaryLargeLength(t *testing.T) { // 5 MB + 4 bytes for length data := make([]byte, 5242880+4) @@ -376,6 +532,86 @@ func TestStruct(t *testing.T) { checkEncodeDecode(t, wire.TStruct, tests) } +func TestStructBeginAndEndEncode(t *testing.T) { + var streamBuff bytes.Buffer + var err error + + // Encode with Streaming protocol + w := binary.BorrowStreamWriter(&streamBuff) + err = w.WriteStructBegin() + assert.NoError(t, err) + err = w.WriteStructEnd() + assert.NoError(t, err) + binary.ReturnStreamWriter(w) + + // Assert that encoded bytes are equivalent + assert.Equal(t, []byte{0x0}, streamBuff.Bytes()) +} + +func TestStreamStruct(t *testing.T) { + tests := []streamEncodeDecodeTest{ + {"empty struct", testvalue(vstruct()), []byte{0x00}}, + {"simple struct", testvalue(vstruct(vfield(1, vbool(true)))), []byte{ + 0x02, // type:1 = bool + 0x00, 0x01, // id:2 = 1 + 0x01, // value = true + 0x00, // stop + }}, + { + "complex struct", + testvalue(vstruct( + vfield(1, vi16(42)), + vfield(2, vlist(wire.TBinary, vbinary("foo"), vbinary("bar"))), + vfield(3, vset(wire.TBinary, vbinary("baz"), vbinary("qux"))), + )), []byte{ + 0x06, // type:1 = i16 + 0x00, 0x01, // id:2 = 1 + 0x00, 0x2a, // value = 42 + + 0x0F, // type:1 = list + 0x00, 0x02, // id:2 = 2 + + // + 0x0B, // type:1 = binary + 0x00, 0x00, 0x00, 0x02, // size:4 = 2 + // + 0x00, 0x00, 0x00, 0x03, // len:4 = 3 + 0x66, 0x6f, 0x6f, // 'f', 'o', 'o' + // + // + 0x00, 0x00, 0x00, 0x03, // len:4 = 3 + 0x62, 0x61, 0x72, // 'b', 'a', 'r' + // + // + + 0x0E, // type = set + 0x00, 0x03, // id = 3 + + // + 0x0B, // type:1 = binary + 0x00, 0x00, 0x00, 0x02, // size:4 = 2 + // + 0x00, 0x00, 0x00, 0x03, // len:4 = 3 + 0x62, 0x61, 0x7a, // 'b', 'a', 'z' + // + // + 0x00, 0x00, 0x00, 0x03, // len:4 = 3 + 0x71, 0x75, 0x78, // 'q', 'u', 'x' + // + // + + 0x00, // stop + }, + }, + } + + for _, test := range tests { + t.Run(test.description, func(tt *testing.T) { + checkStreamEncode(tt, test.value, test.encoded) + }) + } +} + func TestStructEOFFailure(t *testing.T) { tests := []failureTest{ {}, @@ -430,6 +666,73 @@ func TestMap(t *testing.T) { checkEncodeDecode(t, wire.TMap, tests) } +func TestMapBeginEncode(t *testing.T) { + var streamBuff bytes.Buffer + var err error + + // Encode with Streaming protocol + w := binary.BorrowStreamWriter(&streamBuff) + err = w.WriteMapBegin(stream.MapHeader{ + KeyType: wire.TBinary, + ValueType: wire.TBool, + Length: 1, + }) + assert.NoError(t, err) + binary.ReturnStreamWriter(w) + + // Assert that encoded bytes are equivalent + assert.Equal(t, []byte{0xb, 0x2, 0x0, 0x0, 0x0, 0x1}, streamBuff.Bytes()) +} + +func TestStreamMap(t *testing.T) { + tests := []streamEncodeDecodeTest{ + {"small map", testvalue(vmap(wire.TI64, wire.TBinary)), []byte{0x0A, 0x0B, 0x00, 0x00, 0x00, 0x00}}, + { + "complex map", + testvalue(vmap( + wire.TBinary, wire.TList, + vitem(vbinary("a"), vlist(wire.TI16, vi16(1))), + vitem(vbinary("b"), vlist(wire.TI16, vi16(2), vi16(3))), + )), []byte{ + 0x0B, // ktype = binary + 0x0F, // vtype = list + 0x00, 0x00, 0x00, 0x02, // count:4 = 2 + + // + // + 0x00, 0x00, 0x00, 0x01, // len:4 = 1 + 0x61, // 'a' + // + // + 0x06, // type:1 = i16 + 0x00, 0x00, 0x00, 0x01, // count:4 = 1 + 0x00, 0x01, // 1 + // + // + + // + // + 0x00, 0x00, 0x00, 0x01, // len:4 = 1 + 0x62, // 'b' + // + // + 0x06, // type:1 = i16 + 0x00, 0x00, 0x00, 0x02, // count:4 = 2 + 0x00, 0x02, // 2 + 0x00, 0x03, // 3 + // + // + }, + }, + } + + for _, test := range tests { + t.Run(test.description, func(tt *testing.T) { + checkStreamEncode(tt, test.value, test.encoded) + }) + } +} + func TestMapDecodeFailure(t *testing.T) { tests := []failureTest{ { @@ -462,6 +765,23 @@ func TestSet(t *testing.T) { checkEncodeDecode(t, wire.TSet, tests) } +func TestSetBeginEncode(t *testing.T) { + var streamBuff bytes.Buffer + var err error + + // Encode with Streaming protocol + w := binary.BorrowStreamWriter(&streamBuff) + err = w.WriteSetBegin(stream.SetHeader{ + Type: wire.TList, + Length: 1, + }) + assert.NoError(t, err) + binary.ReturnStreamWriter(w) + + // Assert that encoded bytes are equivalent + assert.Equal(t, []byte{0xf, 0x0, 0x0, 0x0, 0x1}, streamBuff.Bytes()) +} + func TestSetDecodeFailure(t *testing.T) { tests := []failureTest{ { @@ -473,6 +793,22 @@ func TestSetDecodeFailure(t *testing.T) { checkDecodeFailure(t, wire.TSet, tests) } +func TestStreamSet(t *testing.T) { + tests := []streamEncodeDecodeTest{ + {"small set", testvalue(vset(wire.TBool)), []byte{0x02, 0x00, 0x00, 0x00, 0x00}}, + { + "large set", testvalue(vset(wire.TBool, vbool(true), vbool(false), vbool(true))), + []byte{0x02, 0x00, 0x00, 0x00, 0x03, 0x01, 0x00, 0x01}, + }, + } + + for _, test := range tests { + t.Run(test.description, func(tt *testing.T) { + checkStreamEncode(tt, test.value, test.encoded) + }) + } +} + func TestSetEOFFailure(t *testing.T) { tests := []failureTest{ {}, // empty @@ -531,6 +867,77 @@ func TestList(t *testing.T) { checkEncodeDecode(t, wire.TList, tests) } +func TestListBeginEncode(t *testing.T) { + var streamBuff bytes.Buffer + var err error + + // Encode with Streaming protocol + w := binary.BorrowStreamWriter(&streamBuff) + err = w.WriteListBegin(stream.ListHeader{ + Type: wire.TMap, + Length: 5, + }) + assert.NoError(t, err) + binary.ReturnStreamWriter(w) + + // Assert that encoded bytes are equivalent + assert.Equal(t, []byte{0xd, 0x0, 0x0, 0x0, 0x5}, streamBuff.Bytes()) +} + +func TestStreamList(t *testing.T) { + tests := []streamEncodeDecodeTest{ + {"small list", testvalue(vlist(wire.TStruct)), []byte{0x0C, 0x00, 0x00, 0x00, 0x00}}, + { + "large list", + testvalue(vlist( + wire.TStruct, + vstruct( + vfield(1, vi16(1)), + vfield(2, vi32(2)), + ), + vstruct( + vfield(1, vi16(3)), + vfield(2, vi32(4)), + ), + )), + []byte{ + 0x0C, // vtype:1 = struct + 0x00, 0x00, 0x00, 0x02, // count:4 = 2 + + // + 0x06, // type:1 = i16 + 0x00, 0x01, // id:2 = 1 + 0x00, 0x01, // value = 1 + + 0x08, // type:1 = i32 + 0x00, 0x02, // id:2 = 2 + 0x00, 0x00, 0x00, 0x02, // value = 2 + + 0x00, // stop + // + + // + 0x06, // type:1 = i16 + 0x00, 0x01, // id:2 = 1 + 0x00, 0x03, // value = 3 + + 0x08, // type:1 = i32 + 0x00, 0x02, // id:2 = 2 + 0x00, 0x00, 0x00, 0x04, // value = 4 + + 0x00, // stop + // + }, + }, + } + + for _, test := range tests { + t.Run(test.description, func(tt *testing.T) { + checkStreamEncode(tt, test.value, test.encoded) + }) + } +} + func TestListDecodeFailure(t *testing.T) { tests := []failureTest{ { @@ -642,6 +1049,97 @@ func TestStructOfContainers(t *testing.T) { checkEncodeDecode(t, wire.TStruct, tests) } +func TestStreamStructOfContainers(t *testing.T) { + tests := []streamEncodeDecodeTest{ + { + "struct of containers", + testvalue(vstruct( + vfield(1, vlist( + wire.TMap, + vmap( + wire.TI32, wire.TSet, + vitem(vi32(1), vset( + wire.TBinary, + vbinary("a"), vbinary("b"), vbinary("c"), + )), + vitem(vi32(2), vset(wire.TBinary)), + vitem(vi32(3), vset( + wire.TBinary, + vbinary("d"), vbinary("e"), vbinary("f"), + )), + ), + vmap( + wire.TI32, wire.TSet, + vitem(vi32(4), vset(wire.TBinary, vbinary("g"))), + ), + )), + vfield(2, vlist(wire.TI16, vi16(1), vi16(2), vi16(3))), + )), + []byte{ + 0x0f, // type:list + 0x00, 0x01, // field ID 1 + + 0x0d, // type: map + 0x00, 0x00, 0x00, 0x02, // length: 2 + + // + 0x08, 0x0e, // ktype: i32, vtype: set + 0x00, 0x00, 0x00, 0x03, // length: 3 + + // 1: {"a", "b", "c"} + 0x00, 0x00, 0x00, 0x01, // 1 + 0x0B, // type: binary + 0x00, 0x00, 0x00, 0x03, // length: 3 + 0x00, 0x00, 0x00, 0x01, 0x61, // 'a' + 0x00, 0x00, 0x00, 0x01, 0x62, // 'b' + 0x00, 0x00, 0x00, 0x01, 0x63, // 'c' + + // 2: {} + 0x00, 0x00, 0x00, 0x02, // 2 + 0x0B, // type: binary + 0x00, 0x00, 0x00, 0x00, // length: 0 + + // 3: {"d", "e", "f"} + 0x00, 0x00, 0x00, 0x03, // 3 + 0x0B, // type: binary + 0x00, 0x00, 0x00, 0x03, // length: 3 + 0x00, 0x00, 0x00, 0x01, 0x64, // 'd' + 0x00, 0x00, 0x00, 0x01, 0x65, // 'e' + 0x00, 0x00, 0x00, 0x01, 0x66, // 'f' + + // + + // + 0x08, 0x0e, // ktype: i32, vtype: set + 0x00, 0x00, 0x00, 0x01, // length: 1 + + // 4: {"g"} + 0x00, 0x00, 0x00, 0x04, // 3 + 0x0B, // type: binary + 0x00, 0x00, 0x00, 0x01, // length: 1 + 0x00, 0x00, 0x00, 0x01, 0x67, // 'g' + + // + + 0x0f, // type: list + 0x00, 0x02, // field ID 2 + + 0x06, // type: i16 + 0x00, 0x00, 0x00, 0x03, // length 3 + 0x00, 0x01, 0x00, 0x02, 0x00, 0x03, // [1,2,3] + + 0x00, + }, + }, + } + + for _, test := range tests { + t.Run(test.description, func(tt *testing.T) { + checkStreamEncode(tt, test.value, test.encoded) + }) + } +} + func TestBinaryEnvelopeErrors(t *testing.T) { tests := []struct { encoded []byte diff --git a/protocol/value_test.go b/protocol/value_test.go index bbf9d33d..08477336 100644 --- a/protocol/value_test.go +++ b/protocol/value_test.go @@ -20,7 +20,12 @@ package protocol -import "go.uber.org/thriftrw/wire" +import ( + "fmt" + "go.uber.org/thriftrw/protocol/binary" + "go.uber.org/thriftrw/protocol/stream" + "go.uber.org/thriftrw/wire" +) // This file doesn't actually contain any tests. It just contains helpers for // constructing complex Value objects during protocol. @@ -76,3 +81,166 @@ func vmap(kt, vt wire.Type, items ...wire.MapItem) wire.Value { func vitem(k, v wire.Value) wire.MapItem { return wire.MapItem{Key: k, Value: v} } + +// testValue wraps around wire.Value to allow for stream encoding/decoding +// testing. +type testValue struct { + val wire.Value +} + +func testvalue(val wire.Value) testValue { + return testValue{val: val} +} + +func (v testValue) streamEncodeStruct(sw *binary.StreamWriter) error { + if v.val.Type() != wire.TStruct { + panic(fmt.Sprintf("Cannot call streamEncodeStruct on non-struct type")) + } + + if err := sw.WriteStructBegin(); err != nil { + return err + } + + for _, field := range v.val.GetStruct().Fields { + fm := stream.FieldHeader{ + ID: field.ID, + Type: field.Value.Type(), + } + if err := sw.WriteFieldBegin(fm); err != nil { + return err + } + + tv := testValue{val: field.Value} + if err := tv.StreamEncode(sw); err != nil { + return err + } + + if err := sw.WriteFieldEnd(); err != nil { + return err + } + } + + return sw.WriteStructEnd() +} + +func (v testValue) streamEncodeMap(sw *binary.StreamWriter) error { + if v.val.Type() != wire.TMap { + panic(fmt.Sprintf("Cannot call streamEncodeMap on non-map type")) + } + + mapItemList := v.val.GetMap() + + mh := stream.MapHeader{ + KeyType: mapItemList.KeyType(), + ValueType: mapItemList.ValueType(), + Length: mapItemList.Size(), + } + if err := sw.WriteMapBegin(mh); err != nil { + return err + } + + err := mapItemList.ForEach(func(item wire.MapItem) error { + key := testValue{val: item.Key} + if err := key.StreamEncode(sw); err != nil { + return err + } + + value := testValue{val: item.Value} + if err := value.StreamEncode(sw); err != nil { + return err + } + return nil + }) + if err != nil { + return err + } + + return sw.WriteMapEnd() +} + +func (v testValue) streamEncodeSet(sw *binary.StreamWriter) error { + if v.val.Type() != wire.TSet { + panic(fmt.Sprintf("Cannot call streamEncodeSet on non-set type")) + } + + valueList := v.val.GetSet() + + sh := stream.SetHeader{ + Length: valueList.Size(), + Type: valueList.ValueType(), + } + if err := sw.WriteSetBegin(sh); err != nil { + return err + } + + err := valueList.ForEach(func(value wire.Value) error { + val := testValue{val: value} + if err := val.StreamEncode(sw); err != nil { + return err + } + return nil + }) + if err != nil { + return err + } + + return sw.WriteSetEnd() +} + +func (v testValue) streamEncodeList(sw *binary.StreamWriter) error { + if v.val.Type() != wire.TList { + panic(fmt.Sprintf("Cannot call streamEncodeList on non-list type")) + } + + valueList := v.val.GetList() + + lh := stream.ListHeader{ + Length: valueList.Size(), + Type: valueList.ValueType(), + } + if err := sw.WriteListBegin(lh); err != nil { + return err + } + + err := valueList.ForEach(func(value wire.Value) error { + val := testValue{val: value} + if err := val.StreamEncode(sw); err != nil { + return err + } + return nil + }) + if err != nil { + return err + } + + return sw.WriteListEnd() +} + +func (v testValue) StreamEncode(sw *binary.StreamWriter) error { + switch v.val.Type() { + case wire.TBool: + return sw.WriteBool(v.val.GetBool()) + case wire.TI8: + return sw.WriteInt8(v.val.GetI8()) + case wire.TDouble: + return sw.WriteDouble(v.val.GetDouble()) + case wire.TI16: + return sw.WriteInt16(v.val.GetI16()) + case wire.TI32: + return sw.WriteInt32(v.val.GetI32()) + case wire.TI64: + return sw.WriteInt64(v.val.GetI64()) + case wire.TBinary: + return sw.WriteBinary(v.val.GetBinary()) + case wire.TStruct: + return v.streamEncodeStruct(sw) + case wire.TMap: + return v.streamEncodeMap(sw) + case wire.TSet: + return v.streamEncodeSet(sw) + case wire.TList: + return v.streamEncodeList(sw) + default: + panic(fmt.Sprintf("Unknown value type %v", v.val.Type())) + } +}