From 20b9045ff06289f32e96014b5c275a388fe5518e Mon Sep 17 00:00:00 2001 From: "qiheng.zhou" Date: Wed, 31 Jan 2024 13:44:58 +0800 Subject: [PATCH] fix: set total length when encode ttheader --- pkg/remote/codec/default_codec.go | 67 +++++++++++++------------- pkg/remote/codec/default_codec_test.go | 11 +++-- pkg/remote/codec/header_codec.go | 7 ++- 3 files changed, 45 insertions(+), 40 deletions(-) diff --git a/pkg/remote/codec/default_codec.go b/pkg/remote/codec/default_codec.go index 39501b11f0..0020f3170f 100644 --- a/pkg/remote/codec/default_codec.go +++ b/pkg/remote/codec/default_codec.go @@ -166,44 +166,37 @@ func (c *defaultCodec) EncodeMetaAndPayload(ctx context.Context, message remote. func (c *defaultCodec) EncodeMetaAndPayloadWithCRC32C(ctx context.Context, message remote.Message, out remote.ByteBuffer, me remote.MetaEncoder) error { var err error - var totalLenField []byte // 1. encode payload and calculate crc32c checksum newPayloadOut := remote.NewWriterBuffer(0) + defer newPayloadOut.Release(nil) + if err = me.EncodePayload(ctx, message, newPayloadOut); err != nil { return err } - pb, err := newPayloadOut.Bytes() + payload, err := newPayloadOut.Bytes() if err != nil { return err } - cs := crc32.Checksum(pb, crc32cTable) - csb := make([]byte, Size32) - binary.BigEndian.PutUint32(csb, cs) + crc32c := getCRC32C(payload) if strInfo := message.TransInfo().TransStrInfo(); strInfo != nil { - strInfo[transmeta.HeaderCRC32C] = string(csb) + strInfo[transmeta.HeaderCRC32C] = string(crc32c) } + // set payload length before encode TTHeader. + message.SetPayloadLen(len(payload)) - message.SetPayloadLen(len(pb)) // 2. encode header and return totalLenField if needed // totalLenField will be filled after payload encoded - if totalLenField, err = ttHeaderCodec.encode(ctx, message, out); err != nil { + if _, err = ttHeaderCodec.encode(ctx, message, out); err != nil { return err } // 3. write payload to the buffer after TTHeader - _, err = out.WriteBinary(pb) + _, err = out.WriteBinary(payload) if err != nil { return err } - // 4. fill totalLen field for header if needed - if totalLenField == nil { - return perrors.NewProtocolErrorWithMsg("no buffer allocated for the header length field") - } - payloadLen := out.MallocLen() - Size32 - binary.BigEndian.PutUint32(totalLenField, uint32(payloadLen)) - return nil } @@ -304,23 +297,6 @@ func (c *defaultCodec) encodePayload(ctx context.Context, message remote.Message return pCodec.Marshal(ctx, message, out) } -func checkCRC32C(message remote.Message, in remote.ByteBuffer) error { - crc32byte := message.TransInfo().TransStrInfo()[transmeta.HeaderCRC32C] - if len(crc32byte) != 0 { - expectedChecksum := binary.BigEndian.Uint32([]byte(crc32byte)) - payloadLen := message.PayloadLen() // total length - payload, err := in.Peek(payloadLen) - if err != nil { - return err - } - realChecksum := crc32.Checksum(payload, crc32cTable) - if realChecksum != expectedChecksum { - return perrors.NewProtocolErrorWithType(perrors.InvalidData, "crc32c payload check failed") - } - } - return nil -} - /** * +------------------------------------------------------------+ * | 4Byte | 2Byte | @@ -456,3 +432,28 @@ func checkPayloadSize(payloadLen, maxSize int) error { } return nil } + +// getCRC32C calculates the crc32c checksum of the input bytes +func getCRC32C(payload []byte) []byte { + csb := make([]byte, Size32) + binary.BigEndian.PutUint32(csb, crc32.Checksum(payload, crc32cTable)) + return csb +} + +// checkCRC32C validates the crc32c checksum in the header +func checkCRC32C(message remote.Message, in remote.ByteBuffer) error { + crc32byte := message.TransInfo().TransStrInfo()[transmeta.HeaderCRC32C] + if len(crc32byte) != 0 { + expectedChecksum := binary.BigEndian.Uint32([]byte(crc32byte)) + payloadLen := message.PayloadLen() // total length + payload, err := in.Peek(payloadLen) + if err != nil { + return err + } + realChecksum := crc32.Checksum(payload, crc32cTable) + if realChecksum != expectedChecksum { + return perrors.NewProtocolErrorWithType(perrors.InvalidData, "crc32c payload check failed") + } + } + return nil +} diff --git a/pkg/remote/codec/default_codec_test.go b/pkg/remote/codec/default_codec_test.go index 3657b2d1a8..1028c9ef3c 100644 --- a/pkg/remote/codec/default_codec_test.go +++ b/pkg/remote/codec/default_codec_test.go @@ -291,20 +291,20 @@ func BenchmarkTestEncodeDecodeWithCRC32C(b *testing.B) { ctx := context.Background() remote.PutPayloadCode(serviceinfo.Thrift, mpc) type factory func() remote.Codec - testCases := map[string]factory{"normal": NewDefaultCodec} //, "crc32c": NewDefaultCodecWithCRC32} + testCases := map[string]factory{"normal": NewDefaultCodec, "crc32c": NewDefaultCodecWithCRC32} for name, f := range testCases { b.Run(name, func(b *testing.B) { - for i := 3; i < 4; i++ { - msgLen := i * 1024 + msgLen := 1 + for i := 0; i < 6; i++ { b.ReportAllocs() b.ResetTimer() - b.Run(fmt.Sprintf("%dsize", msgLen), func(b *testing.B) { + b.Run(fmt.Sprintf("payload-%d", msgLen), func(b *testing.B) { for j := 0; j < b.N; j++ { codec := f() sendMsg := initClientSendMsg(transport.TTHeader, msgLen) // encode - out := remote.NewWriterBuffer(256) + out := remote.NewWriterBuffer((i + 1) * 1024) err := codec.Encode(ctx, sendMsg, out) test.Assert(b, err == nil, err) @@ -317,6 +317,7 @@ func BenchmarkTestEncodeDecodeWithCRC32C(b *testing.B) { test.Assert(b, err == nil, err) } }) + msgLen *= 10 } }) } diff --git a/pkg/remote/codec/header_codec.go b/pkg/remote/codec/header_codec.go index df57c4afff..cac94b2fe2 100644 --- a/pkg/remote/codec/header_codec.go +++ b/pkg/remote/codec/header_codec.go @@ -117,6 +117,8 @@ const ( type ttHeader struct{} func (t ttHeader) encode(ctx context.Context, message remote.Message, out remote.ByteBuffer) (totalLenField []byte, err error) { + mallocLenBefore := out.MallocLen() + // 1. header meta var headerMeta []byte headerMeta, err = out.Malloc(TTHeaderMetaSize) @@ -154,8 +156,9 @@ func (t ttHeader) encode(ctx context.Context, message remote.Message, out remote } binary.BigEndian.PutUint16(headerInfoSizeField, uint16(headerInfoSize/4)) if message.PayloadLen() != 0 { - // payload encoded before - totalLen := message.PayloadLen() + out.MallocLen() - Size32 + // payload encoded before. set total length here. + headerLen := out.MallocLen() - mallocLenBefore + totalLen := message.PayloadLen() + headerLen - Size32 binary.BigEndian.PutUint32(totalLenField, uint32(totalLen)) } return totalLenField, err