Skip to content

Commit

Permalink
Provide "decode" code generation for the streaming variants for all o…
Browse files Browse the repository at this point in the history
…ther types

Provide the streaming read generators for all of the other types.

Doing so meant providing a `StreamGenerator` that plumbs into the existing
generators.  This `StreamGenerator` only provides the "decoding" mechanism,
leveraging #491 for reading the raw Thrift encoding off of the wire.

In addition to providing a `StreamGenerator`, the list, map, set, enum, typdef,
and struct generators all added a `Decoder` method that will appropriately
recurse and iterate to generate the proper, mirrored readers for the raw wire
representation.

The `Decode` and `DecodePtr` methods for `StreamGenerator` then hook into the
templated generator itself, providing the templated `decode` and `decodePtr`
calls that will be used where necessary.

The existing tests were leveraged to make sure that the streaming reads were
compatible with the current binary writes and that no data was lost.  In doing
so, a 'genericized' function that will perform the cross-products of
encoding/writing and decoding/reading.

A benchmark was also added to evaluate the new streaming reads.
```
name \ time/op                                     old.txt      new.txt       stream.txt
RoundTrip/PrimitiveOptionalStruct/Encode-8         1.72µs ± 5%   1.72µs ± 0%   1.76µs ± 6%
RoundTrip/PrimitiveOptionalStruct/Decode-8         2.47µs ± 1%   2.75µs ± 0%   2.68µs ± 1%
RoundTrip/Graph/Encode-8                           3.18µs ± 2%   3.13µs ± 1%   3.13µs ± 1%
RoundTrip/Graph/Decode-8                           5.02µs ± 2%   8.34µs ± 2%   8.35µs ± 2%
RoundTrip/ContainersOfContainers/Encode-8          19.5µs ± 3%   19.0µs ± 2%   19.3µs ± 2%
RoundTrip/ContainersOfContainers/Decode-8          46.8µs ± 5%  104.8µs ± 1%  106.7µs ± 2%
RoundTrip/PrimitiveOptionalStruct/StreamingRead-8                              1.09µs ± 1%
RoundTrip/Graph/StreamingRead-8                                                1.69µs ± 4%
RoundTrip/ContainersOfContainers/StreamingRead-8                               25.3µs ± 2%

name \ alloc/op                                    old.txt      new.txt       stream.txt
RoundTrip/PrimitiveOptionalStruct/Encode-8           704B ± 0%     704B ± 0%     704B ± 0%
RoundTrip/PrimitiveOptionalStruct/Decode-8         1.40kB ± 0%   1.46kB ± 0%   1.46kB ± 0%
RoundTrip/Graph/Encode-8                           1.70kB ± 0%   1.70kB ± 0%   1.70kB ± 0%
RoundTrip/Graph/Decode-8                           2.78kB ± 0%   3.57kB ± 0%   3.57kB ± 0%
RoundTrip/ContainersOfContainers/Encode-8          1.30kB ± 0%   1.30kB ± 0%   1.30kB ± 0%
RoundTrip/ContainersOfContainers/Decode-8          12.3kB ± 0%   28.6kB ± 0%   28.6kB ± 0%
RoundTrip/PrimitiveOptionalStruct/StreamingRead-8                                104B ± 0%
RoundTrip/Graph/StreamingRead-8                                                  216B ± 0%
RoundTrip/ContainersOfContainers/StreamingRead-8                               10.2kB ± 0%

name \ allocs/op                                   old.txt      new.txt       stream.txt
RoundTrip/PrimitiveOptionalStruct/Encode-8           1.00 ± 0%     1.00 ± 0%     1.00 ± 0%
RoundTrip/PrimitiveOptionalStruct/Decode-8           14.0 ± 0%     15.0 ± 0%     15.0 ± 0%
RoundTrip/Graph/Encode-8                             11.0 ± 0%     11.0 ± 0%     11.0 ± 0%
RoundTrip/Graph/Decode-8                             32.0 ± 0%     63.0 ± 0%     63.0 ± 0%
RoundTrip/ContainersOfContainers/Encode-8            18.0 ± 0%     18.0 ± 0%     18.0 ± 0%
RoundTrip/ContainersOfContainers/Decode-8             164 ± 0%      837 ± 0%      837 ± 0%
RoundTrip/PrimitiveOptionalStruct/StreamingRead-8                                11.0 ± 0%
RoundTrip/Graph/StreamingRead-8                                                  11.0 ± 0%
RoundTrip/ContainersOfContainers/StreamingRead-8                                  147 ± 0%
```
"old" represents the original code, "new" represents the binary decoding
utilizing the streaming decoder, and "stream" represents the benchmarks as it
stands at this diff.
  • Loading branch information
witriew committed Jun 16, 2021
1 parent a7893a9 commit 6fa1b7c
Show file tree
Hide file tree
Showing 35 changed files with 9,448 additions and 379 deletions.
31 changes: 31 additions & 0 deletions gen/benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
tc "go.uber.org/thriftrw/gen/internal/tests/containers"
ts "go.uber.org/thriftrw/gen/internal/tests/structs"
"go.uber.org/thriftrw/protocol"
"go.uber.org/thriftrw/protocol/stream"
"go.uber.org/thriftrw/ptr"
"go.uber.org/thriftrw/wire"
)
Expand All @@ -18,6 +19,13 @@ type thriftType interface {
FromWire(wire.Value) error
}

type streamingThriftType interface {
thriftType

// Encode(stream.Writer) error
Decode(stream.Reader) error
}

func BenchmarkRoundTrip(b *testing.B) {
type benchCase struct {
name string
Expand Down Expand Up @@ -138,6 +146,25 @@ func BenchmarkRoundTrip(b *testing.B) {
}
}

benchmarkStreamingRead := func(b *testing.B, bb benchCase) {
var buff bytes.Buffer
w, err := bb.give.ToWire()
require.NoError(b, err, "ToWire")
require.NoError(b, protocol.Binary.Encode(w, &buff), "Encode")

r := bytes.NewReader(buff.Bytes())
give, ok := bb.give.(streamingThriftType)
require.True(b, ok)

b.ResetTimer()
for i := 0; i < b.N; i++ {
r.Seek(0, 0)

reader := protocol.BinaryStreamer.Reader(r)
require.NoError(b, give.Decode(reader), "Decode")
}
}

for _, bb := range benchmarks {
b.Run(bb.name, func(b *testing.B) {
b.Run("Encode", func(b *testing.B) {
Expand All @@ -147,6 +174,10 @@ func BenchmarkRoundTrip(b *testing.B) {
b.Run("Decode", func(b *testing.B) {
benchmarkDecode(b, bb)
})

b.Run("StreamingRead", func(b *testing.B) {
benchmarkStreamingRead(b, bb)
})
})
}
}
Expand Down
12 changes: 12 additions & 0 deletions gen/container_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,9 @@ func TestCollectionsOfPrimitives(t *testing.T) {
for _, tt := range tests {
assertRoundTrip(t, &tt.p, tt.v, tt.desc)
assert.True(t, tt.p.Equals(&tt.p), tt.desc)

testRoundTripCombos(t, &tt.p, tt.v, tt.desc)
assert.True(t, tt.p.Equals(&tt.p), tt.desc)
}
}

Expand Down Expand Up @@ -351,6 +354,9 @@ func TestEnumContainers(t *testing.T) {
for _, tt := range tests {
assertRoundTrip(t, &tt.r, tt.v, "EnumContainers")
assert.True(t, tt.r.Equals(&tt.r), "EnumContainers equal")

testRoundTripCombos(t, &tt.r, tt.v, "EnumContainers")
assert.True(t, tt.r.Equals(&tt.r), "EnumContainers equal")
}
}

Expand Down Expand Up @@ -506,6 +512,9 @@ func TestListOfStructs(t *testing.T) {
for _, tt := range tests {
assertRoundTrip(t, &tt.r, tt.v, "Graph")
assert.True(t, tt.r.Equals(&tt.r), "Graph equal")

testRoundTripCombos(t, &tt.r, tt.v, "Graph")
assert.True(t, tt.r.Equals(&tt.r), "Graph equal")
}
}

Expand Down Expand Up @@ -949,6 +958,9 @@ func TestCrazyTown(t *testing.T) {
for _, tt := range tests {
assertRoundTrip(t, &tt.x, tt.v, tt.desc)
assert.True(t, tt.x.Equals(&tt.x), tt.desc)

testRoundTripCombos(t, &tt.x, tt.v, tt.desc)
assert.True(t, tt.x.Equals(&tt.x), tt.desc)
}
}

Expand Down
23 changes: 23 additions & 0 deletions gen/enum.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,29 @@ func (e *enumGenerator) Reader(g Generator, spec *compile.EnumSpec) (string, err
return name, wrapGenerateError(spec.ThriftName(), err)
}

func (e *enumGenerator) Decoder(g Generator, spec *compile.EnumSpec) (string, error) {
name := decoderFuncName(g, spec)
err := g.EnsureDeclared(
`
<$stream := import "go.uber.org/thriftrw/protocol/stream">
<$v := newVar "v">
<$sr := newVar "sr">
func <.Name>(<$sr> <$stream>.Reader) (<typeName .Spec>, error) {
var <$v> <typeName .Spec>
err := <$v>.Decode(<$sr>)
return <$v>, err
}
`,
struct {
Name string
Spec *compile.EnumSpec
}{Name: name, Spec: spec},
)

return name, wrapGenerateError(spec.ThriftName(), err)
}

func enum(g Generator, spec *compile.EnumSpec) error {
if err := verifyUniqueEnumItemLabels(spec); err != nil {
return err
Expand Down
11 changes: 7 additions & 4 deletions gen/enum_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ func TestEnumDefaultWire(t *testing.T) {

for _, tt := range tests {
assertRoundTrip(t, &tt.e, tt.v, "EnumDefault")
assertStreamingRoundTrip(t, &tt.e, tt.v, "EnumDefault")

testRoundTripCombos(t, &tt.e, tt.v, "EnumDefault")
}
}

Expand Down Expand Up @@ -155,7 +156,8 @@ func TestEnumWithDuplicateValuesWire(t *testing.T) {

for _, tt := range tests {
assertRoundTrip(t, &tt.e, tt.v, "EnumWithDuplicateValues")
assertStreamingRoundTrip(t, &tt.e, tt.v, "EnumWithDuplicateValues")

testRoundTripCombos(t, &tt.e, tt.v, "EnumWithDuplicateValues")
}
}

Expand Down Expand Up @@ -187,6 +189,7 @@ func TestOptionalEnum(t *testing.T) {

for _, tt := range tests {
assertRoundTrip(t, &tt.s, tt.v, "StructWithOptionalEnum")
testRoundTripCombos(t, &tt.s, tt.v, "StructWithOptionalEnum")
}
}

Expand Down Expand Up @@ -498,9 +501,9 @@ func TestEnumLabelValid(t *testing.T) {
t.Run("wire", func(t *testing.T) {
assertRoundTrip(t, &tt.item, wire.NewValueI32(int32(tt.item)),
"%v", tt.item)
assertStreamingRoundTrip(t, &tt.item, wire.NewValueI32(int32(tt.item)),
"%v", tt.item)

testRoundTripCombos(t, &tt.item, wire.NewValueI32(int32(tt.item)),
tt.item.String())
})
})
}
Expand Down
112 changes: 112 additions & 0 deletions gen/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ const (
var reservedIdentifiers = map[string]struct{}{
"ToWire": {},
"FromWire": {},
"Decode": {},
"String": {},
"Equals": {},
}
Expand Down Expand Up @@ -95,6 +96,10 @@ func (f fieldGroupGenerator) Generate(g Generator) error {
return err
}

if err := f.Decode(g); err != nil {
return err
}

if err := f.String(g); err != nil {
return err
}
Expand Down Expand Up @@ -467,6 +472,113 @@ func (f fieldGroupGenerator) FromWire(g Generator) error {
`, f, TemplateFunc("constantValuePtr", ConstantValuePtr))
}

func (f fieldGroupGenerator) Decode(g Generator) error {
return g.DeclareFromTemplate(
`
<$stream := import "go.uber.org/thriftrw/protocol/stream">
<$v := newVar "v">
<$sr := newVar "sr">
// Decode deserializes a <.Name> struct directly from its Thrift-level
// representation, without going through an intemediary type.
//
// An error is returned if a <.Name> struct could not be generated from the wire
// representation.
func (<$v> *<.Name>) Decode(<$sr> <$stream>.Reader) error {
<$isSet := newNamespace>
<range .Fields>
<- if .Required ->
<$isSet.NewName (printf "%sIsSet" .Name)> := false
<- end>
<end>
if err := <$sr>.ReadStructBegin(); err != nil {
return err
}
<$fh := newVar "fh">
<$ok := newVar "ok">
<$fh>, <$ok>, err := <$sr>.ReadFieldBegin()
if err != nil {
return err
}
for <$ok> {
switch <$fh>.ID {
<range .Fields ->
case <.ID>:
if <$fh>.Type == <typeCode .Type> {
<- $lhs := printf "%s.%s" $v (goName .) ->
<- if .Required ->
<$lhs>, err = <decode .Type $sr>
<- else ->
<decodePtr .Type $lhs $sr>
<- end>
if err != nil {
return err
}
<if .Required ->
<$isSet.Rotate (printf "%sIsSet" .Name)> = true
<- end>
}
<end ->
}
<$fh>, <$ok>, err = <$sr>.ReadFieldBegin()
if err != nil {
return err
}
}
if err := <$sr>.ReadFieldEnd(); err != nil {
return err
}
if err := <$sr>.ReadStructEnd(); err != nil {
return err
}
<$structName := .Name>
<range .Fields>
<$fname := goName .>
<$f := printf "%s.%s" $v $fname>
<if isNotNil .Default>
if <$f> == nil {
<$f> = <constantValuePtr .Default .Type>
}
<else>
<if .Required>
if !<$isSet.Rotate (printf "%sIsSet" .Name)> {
return <import "errors">.New("field <$fname> of <$structName> is required")
}
<end>
<end>
<end>
<if and .IsUnion (len .Fields)>
<$fmt := import "fmt">
<$count := newVar "count">
<$count> := 0
<range .Fields ->
if <$v>.<goName .> != nil {
<$count>++
}
<end>
<- if .AllowEmptyUnion ->
if <$count> > 1 {
return <$fmt>.Errorf( "<.Name> should have at most one field: got %v fields", <$count>)
}
<- else ->
if <$count> != 1 {
return <$fmt>.Errorf( "<.Name> should have exactly one field: got %v fields", <$count>)
}
<- end>
<end>
return nil
}
`, f, TemplateFunc("constantValuePtr", ConstantValuePtr))
}

func (f fieldGroupGenerator) String(g Generator) error {
return g.DeclareFromTemplate(
`
Expand Down
3 changes: 3 additions & 0 deletions gen/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ type generator struct {
ImportPath string

w WireGenerator
s StreamGenerator
e equalsGenerator
z zapGenerator
noZap bool
Expand Down Expand Up @@ -235,6 +236,8 @@ func (g *generator) TextTemplate(s string, data interface{}, opts ...TemplateOpt
"fromWirePtr": curryGenerator(g.w.FromWirePtr, g),
"toWire": curryGenerator(g.w.ToWire, g),
"toWirePtr": curryGenerator(g.w.ToWirePtr, g),
"decode": curryGenerator(g.s.Decode, g),
"decodePtr": curryGenerator(g.s.DecodePtr, g),
"typeCode": curryGenerator(TypeCode, g),
"equals": curryGenerator(g.e.Equals, g),
"equalsPtr": curryGenerator(g.e.EqualsPtr, g),
Expand Down
Loading

0 comments on commit 6fa1b7c

Please sign in to comment.