Skip to content

Commit

Permalink
fix: revise
Browse files Browse the repository at this point in the history
  • Loading branch information
ppzqh committed Jan 31, 2024
1 parent dc802ab commit 7a11b36
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 7 deletions.
1 change: 1 addition & 0 deletions pkg/remote/codec/default_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ func (c *defaultCodec) EncodeMetaAndPayloadWithCRC32C(ctx context.Context, messa
strInfo[transmeta.HeaderCRC32C] = string(csb)
}

message.SetPayloadLen(len(pb))
// 2. encode header and return totalLenField if needed
// totalLenField will be filled after payload encoded
if totalLenField, err = ttHeaderCodec.encode(ctx, message, out); err != nil {
Expand Down
67 changes: 66 additions & 1 deletion pkg/remote/codec/default_codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"context"
"encoding/binary"
"errors"
"fmt"
"testing"

"github.com/bytedance/mockey"
Expand Down Expand Up @@ -228,7 +229,7 @@ func TestDefaultCodecWithCRC32_Encode_Decode(t *testing.T) {
ctx := context.Background()
intKVInfo := prepareIntKVInfo()
strKVInfo := prepareStrKVInfo()
sendMsg := initClientSendMsg(transport.TTHeader)
sendMsg := initClientSendMsg(transport.TTHeader, 3*1024)
sendMsg.TransInfo().PutTransIntInfo(intKVInfo)
sendMsg.TransInfo().PutTransStrInfo(strKVInfo)

Expand Down Expand Up @@ -286,6 +287,41 @@ func TestCodecTypeNotMatchWithServiceInfoPayloadCodec(t *testing.T) {
test.Assert(t, err == nil)
}

func BenchmarkTestEncodeDecodeWithCRC32C(b *testing.B) {
ctx := context.Background()
remote.PutPayloadCode(serviceinfo.Thrift, mpc)
type factory func() remote.Codec
testCases := map[string]factory{"normal": NewDefaultCodec} //, "crc32c": NewDefaultCodecWithCRC32}

for name, f := range testCases {
b.Run(name, func(b *testing.B) {
for i := 3; i < 4; i++ {
msgLen := i * 1024
b.ReportAllocs()
b.ResetTimer()
b.Run(fmt.Sprintf("%dsize", msgLen), func(b *testing.B) {
for j := 0; j < b.N; j++ {
codec := f()
sendMsg := initClientSendMsg(transport.TTHeader, msgLen)
// encode
out := remote.NewWriterBuffer(256)
err := codec.Encode(ctx, sendMsg, out)
test.Assert(b, err == nil, err)

// decode
recvMsg := initServerRecvMsg()
buf, err := out.Bytes()
test.Assert(b, err == nil, err)
in := remote.NewReaderBuffer(buf)
err = codec.Decode(ctx, recvMsg, in)
test.Assert(b, err == nil, err)
}
})
}
})
}
}

var mpc remote.PayloadCodec = mockPayloadCodec{}

type mockPayloadCodec struct{}
Expand All @@ -294,6 +330,23 @@ func (m mockPayloadCodec) Marshal(ctx context.Context, message remote.Message, o
WriteUint32(ThriftV1Magic+uint32(message.MessageType()), out)
WriteString(message.RPCInfo().Invocation().MethodName(), out)
WriteUint32(uint32(message.RPCInfo().Invocation().SeqID()), out)
var (
dataLen uint32
dataStr string
)
// write data
if data := message.Data(); data != nil {
if mm, ok := data.(*mockMsg); ok {
if len(mm.msg) != 0 {
dataStr = mm.msg
dataLen = uint32(len(mm.msg))
}
}
}
WriteUint32(dataLen, out)
if dataLen > 0 {
WriteString(dataStr, out)
}
return nil
}

Expand Down Expand Up @@ -324,6 +377,18 @@ func (m mockPayloadCodec) Unmarshal(ctx context.Context, message remote.Message,
if err = SetOrCheckSeqID(int32(seqID), message); err != nil && msgType != uint32(remote.Exception) {
return err
}
// read data
dataLen, err := PeekUint32(in)
if err != nil {
return err
}
if dataLen == 0 {
// no data
return nil
}
if _, _, err = ReadString(in); err != nil {
return err
}
return nil
}

Expand Down
5 changes: 5 additions & 0 deletions pkg/remote/codec/header_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,11 @@ func (t ttHeader) encode(ctx context.Context, message remote.Message, out remote
return nil, perrors.NewProtocolErrorWithMsg(fmt.Sprintf("invalid header length[%d]", headerInfoSize))
}
binary.BigEndian.PutUint16(headerInfoSizeField, uint16(headerInfoSize/4))
if message.PayloadLen() != 0 {
// payload encoded before
totalLen := message.PayloadLen() + out.MallocLen() - Size32
binary.BigEndian.PutUint32(totalLenField, uint32(totalLen))
}
return totalLenField, err
}

Expand Down
23 changes: 17 additions & 6 deletions pkg/remote/codec/header_codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,29 +305,40 @@ var (
rpcinfo.NewRPCConfig(), rpcinfo.NewRPCStats())
)

type mockMsg struct {
msg string
}

func initServerRecvMsg() remote.Message {
var req interface{}
req := &mockMsg{}
msg := remote.NewMessage(req, mocks.ServiceInfo(), mockSvrRPCInfo, remote.Call, remote.Server)
return msg
}

func initClientSendMsg(tp transport.Protocol) remote.Message {
var req interface{}
func initClientSendMsg(tp transport.Protocol, payloadLen ...int) remote.Message {
req := &mockMsg{}
if len(payloadLen) != 0 {
req.msg = string(make([]byte, payloadLen[0]))
}

svcInfo := mocks.ServiceInfo()
msg := remote.NewMessage(req, svcInfo, mockCliRPCInfo, remote.Call, remote.Client)
msg.SetProtocolInfo(remote.NewProtocolInfo(tp, svcInfo.PayloadCodec))
return msg
}

func initServerSendMsg(tp transport.Protocol) remote.Message {
var resp interface{}
func initServerSendMsg(tp transport.Protocol, payloadLen ...int) remote.Message {
resp := &mockMsg{}
if len(payloadLen) != 0 {
resp.msg = string(make([]byte, payloadLen[0]))
}
msg := remote.NewMessage(resp, mocks.ServiceInfo(), mockSvrRPCInfo, remote.Reply, remote.Server)
msg.SetProtocolInfo(remote.NewProtocolInfo(tp, mocks.ServiceInfo().PayloadCodec))
return msg
}

func initClientRecvMsg() remote.Message {
var resp interface{}
resp := &mockMsg{}
svcInfo := mocks.ServiceInfo()
msg := remote.NewMessage(resp, svcInfo, mockCliRPCInfo, remote.Reply, remote.Client)
return msg
Expand Down

0 comments on commit 7a11b36

Please sign in to comment.