Skip to content

Commit

Permalink
Add code generation for all wire types for stream encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
dianale31 committed Jun 22, 2021
1 parent 5075744 commit d19e892
Show file tree
Hide file tree
Showing 46 changed files with 6,228 additions and 47 deletions.
3 changes: 2 additions & 1 deletion ast/mock_visitor_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 25 additions & 0 deletions gen/benchmark_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
tc "go.uber.org/thriftrw/gen/internal/tests/containers"
ts "go.uber.org/thriftrw/gen/internal/tests/structs"
"go.uber.org/thriftrw/protocol"
"go.uber.org/thriftrw/protocol/binary"
"go.uber.org/thriftrw/protocol/stream"
"go.uber.org/thriftrw/ptr"
"go.uber.org/thriftrw/wire"
)
Expand All @@ -18,6 +20,12 @@ type thriftType interface {
FromWire(wire.Value) error
}

type streamingThriftType interface {
thriftType

Encode(stream.Writer) error
}

func BenchmarkRoundTrip(b *testing.B) {
type benchCase struct {
name string
Expand Down Expand Up @@ -119,6 +127,19 @@ func BenchmarkRoundTrip(b *testing.B) {
}
}

benchmarkStreamEncode := func(b *testing.B, bb benchCase) {
var buff bytes.Buffer

b.ResetTimer()
for i := 0; i < b.N; i++ {
buff.Reset()

writer := binary.BorrowStreamWriter(&buff)
require.NoError(b, bb.give.(streamingThriftType).Encode(writer), "StreamEncode")
binary.ReturnStreamWriter(writer)
}
}

benchmarkDecode := func(b *testing.B, bb benchCase) {
var buff bytes.Buffer
w, err := bb.give.ToWire()
Expand All @@ -144,6 +165,10 @@ func BenchmarkRoundTrip(b *testing.B) {
benchmarkEncode(b, bb)
})

b.Run("Stream Encode", func(b *testing.B) {
benchmarkStreamEncode(b, bb)
})

b.Run("Decode", func(b *testing.B) {
benchmarkDecode(b, bb)
})
Expand Down
12 changes: 12 additions & 0 deletions gen/container_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,9 @@ func TestCollectionsOfPrimitives(t *testing.T) {
for _, tt := range tests {
assertRoundTrip(t, &tt.p, tt.v, tt.desc)
assert.True(t, tt.p.Equals(&tt.p), tt.desc)

testRoundTripCombos(t, &tt.p, tt.v, tt.desc)
assert.True(t, tt.p.Equals(&tt.p), tt.desc)
}
}

Expand Down Expand Up @@ -351,6 +354,9 @@ func TestEnumContainers(t *testing.T) {
for _, tt := range tests {
assertRoundTrip(t, &tt.r, tt.v, "EnumContainers")
assert.True(t, tt.r.Equals(&tt.r), "EnumContainers equal")

testRoundTripCombos(t, &tt.r, tt.v, "EnumContainers")
assert.True(t, tt.r.Equals(&tt.r), "EnumContainers equal")
}
}

Expand Down Expand Up @@ -506,6 +512,9 @@ func TestListOfStructs(t *testing.T) {
for _, tt := range tests {
assertRoundTrip(t, &tt.r, tt.v, "Graph")
assert.True(t, tt.r.Equals(&tt.r), "Graph equal")

testRoundTripCombos(t, &tt.r, tt.v, "Graph")
assert.True(t, tt.r.Equals(&tt.r), "Graph equal")
}
}

Expand Down Expand Up @@ -949,6 +958,9 @@ func TestCrazyTown(t *testing.T) {
for _, tt := range tests {
assertRoundTrip(t, &tt.x, tt.v, tt.desc)
assert.True(t, tt.x.Equals(&tt.x), tt.desc)

testRoundTripCombos(t, &tt.x, tt.v, tt.desc)
assert.True(t, tt.x.Equals(&tt.x), tt.desc)
}
}

Expand Down
15 changes: 15 additions & 0 deletions gen/enum.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ func enum(g Generator, spec *compile.EnumSpec) error {
<$math := import "math">
<$strconv := import "strconv">
<$stream := import "go.uber.org/thriftrw/protocol/stream">
<$wire := import "go.uber.org/thriftrw/wire">
<$enumName := goName .Spec>
Expand Down Expand Up @@ -162,6 +163,20 @@ func enum(g Generator, spec *compile.EnumSpec) error {
return &<$v>
}
<$sw := newVar "sw">
// Encode encodes <$enumName> directly to the wire.
//
// sWriter := BinaryStreamer.Writer(writer)
//
// var <$v> <$enumName>
// if err := <$v>.Encode(sWriter); err != nil {
// return err
// }
// return nil
func (<$v> <$enumName>) Encode(<$sw> <$stream>.Writer) error {
return <$sw>.WriteInt32(int32(<$v>))
}
// ToWire translates <$enumName> into a Thrift-level intermediate
// representation. This intermediate representation may be serialized
// into bytes using a ThriftRW protocol implementation.
Expand Down
3 changes: 3 additions & 0 deletions gen/enum_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ func TestEnumDefaultWire(t *testing.T) {

for _, tt := range tests {
assertRoundTrip(t, &tt.e, tt.v, "EnumDefault")
testRoundTripCombos(t, &tt.e, tt.v, "EnumDefault")
}
}

Expand Down Expand Up @@ -154,6 +155,7 @@ func TestEnumWithDuplicateValuesWire(t *testing.T) {

for _, tt := range tests {
assertRoundTrip(t, &tt.e, tt.v, "EnumWithDuplicateValues")
testRoundTripCombos(t, &tt.e, tt.v, "EnumWithDuplicateValues")
}
}

Expand Down Expand Up @@ -185,6 +187,7 @@ func TestOptionalEnum(t *testing.T) {

for _, tt := range tests {
assertRoundTrip(t, &tt.s, tt.v, "StructWithOptionalEnum")
testRoundTripCombos(t, &tt.s, tt.v, "StructWithOptionalEnum")
}
}

Expand Down
105 changes: 105 additions & 0 deletions gen/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ var reservedIdentifiers = map[string]struct{}{
"FromWire": {},
"String": {},
"Equals": {},
"Encode": {},
}

// fieldGroupGenerator is responsible for generating code for FieldGroups.
Expand Down Expand Up @@ -95,6 +96,10 @@ func (f fieldGroupGenerator) Generate(g Generator) error {
return err
}

if err := f.Encode(g); err != nil {
return err
}

if err := f.String(g); err != nil {
return err
}
Expand Down Expand Up @@ -467,6 +472,106 @@ func (f fieldGroupGenerator) FromWire(g Generator) error {
`, f, TemplateFunc("constantValuePtr", ConstantValuePtr))
}

func (f fieldGroupGenerator) Encode(g Generator) error {
return g.DeclareFromTemplate(
`
<$stream := import "go.uber.org/thriftrw/protocol/stream">
<$v := newVar "v">
<$sw := newVar "sw">
// Encode serializes a <.Name> struct directly into bytes, without going
// through an intemediary type.
//
// An error is returned if a <.Name> struct could not be encoded.
func (<$v> *<.Name>) Encode(<$sw> <$stream>.Writer) error {
<- $i := newVar "i" ->
var (
<if len .Fields ->
<$i> int = 0
err error
fh <$stream>.FieldHeader
<- end>
)
if err := <$sw>.WriteStructBegin(); err != nil {
return err
}
<$structName := .Name>
<range .Fields>
<- $fname := goName . ->
<- $f := printf "%s.%s" $v $fname ->
<$t := typeCode .Type>
<- if .Required ->
<- if and (not (isPrimitiveType .Type)) (not (isListType .Type)) ->
if <$f> == nil {
return <import "errors">.New("field <$fname> of <$structName> is required")
}
<- end>
fh = <$stream>.FieldHeader{ID: <.ID>, Type: <$t>,}
if err := <$sw>.WriteFieldBegin(fh); err != nil {
return err
}
if err := <encode .Type $f $sw>; err != nil {
return err
}
if err := <$sw>.WriteFieldEnd(); err != nil {
return err
}
<$i>++
<- else ->
<- if isNotNil .Default ->
<- $fval := printf "%s%s" $v $fname ->
<$fval> := <$f>
if <$fval> == nil {
<$fval> = <constantValuePtr .Default .Type>
}
{
fh = <$stream>.FieldHeader{ID: <.ID>, Type: <$t>,}
if err := <$sw>.WriteFieldBegin(fh); err != nil {
return err
}
if err := <encodePtr .Type $fval $sw>; err != nil {
return err
}
<- else ->
if <$f> != nil {
fh = <$stream>.FieldHeader{ID: <.ID>, Type: <$t>,}
if err := <$sw>.WriteFieldBegin(fh); err != nil {
return err
}
if err := <encodePtr .Type $f $sw>; err != nil {
return err
}
<- end>
if err := <$sw>.WriteFieldEnd(); err != nil {
return err
}
<$i>++
}
<- end>
<end>
<if and .IsUnion (len .Fields)>
<$fmt := import "fmt">
<if .AllowEmptyUnion>
if <$i> > 1 {
return <$fmt>.Errorf("<.Name> should have at most one field: got %v fields", <$i>)
}
<else>
if <$i> != 1 {
return <$fmt>.Errorf("<.Name> should have exactly one field: got %v fields", <$i>)
}
<end>
<end>
return <$sw>.WriteStructEnd()
}
`, f, TemplateFunc("constantValuePtr", ConstantValuePtr))
}

func (f fieldGroupGenerator) String(g Generator) error {
return g.DeclareFromTemplate(
`
Expand Down
3 changes: 3 additions & 0 deletions gen/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ type generator struct {
w WireGenerator
e equalsGenerator
z zapGenerator
s StreamGenerator
noZap bool
decls []ast.Decl
thriftImporter ThriftPackageImporter
Expand Down Expand Up @@ -233,6 +234,8 @@ func (g *generator) TextTemplate(s string, data interface{}, opts ...TemplateOpt
"typeReferencePtr": curryGenerator(typeReferencePtr, g),
"fromWire": curryGenerator(g.w.FromWire, g),
"fromWirePtr": curryGenerator(g.w.FromWirePtr, g),
"encode": curryGenerator(g.s.Encode, g),
"encodePtr": curryGenerator(g.s.EncodePtr, g),
"toWire": curryGenerator(g.w.ToWire, g),
"toWirePtr": curryGenerator(g.w.ToWirePtr, g),
"typeCode": curryGenerator(TypeCode, g),
Expand Down
Loading

0 comments on commit d19e892

Please sign in to comment.