Skip to content

Commit

Permalink
feat: crc32 check
Browse files Browse the repository at this point in the history
  • Loading branch information
ppzqh committed Jan 30, 2024
1 parent 1167bb5 commit 364b4df
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 4 deletions.
78 changes: 77 additions & 1 deletion pkg/remote/codec/default_codec.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@ import (
"context"
"encoding/binary"
"fmt"
"hash/crc32"
"sync/atomic"

"github.com/cloudwego/kitex/pkg/kerrors"
"github.com/cloudwego/kitex/pkg/remote"
"github.com/cloudwego/kitex/pkg/remote/codec/perrors"
"github.com/cloudwego/kitex/pkg/remote/transmeta"
"github.com/cloudwego/kitex/pkg/retry"
"github.com/cloudwego/kitex/pkg/rpcinfo"
"github.com/cloudwego/kitex/pkg/serviceinfo"
Expand All @@ -37,6 +39,8 @@ const (
Size16 = 2
)

var crcTable = crc32.MakeTable(0xD5828281)

const (
// ThriftV1Magic is the magic code for thrift.VERSION_1
ThriftV1Magic = 0x80010000
Expand Down Expand Up @@ -71,9 +75,15 @@ func NewDefaultCodecWithSizeLimit(maxSize int) remote.Codec {
}
}

// NewDefaultCodecWithCRC32 creates the default protocol sniffing codec supporting thrift and protobuf with crc32 check.
func NewDefaultCodecWithCRC32() remote.Codec {
return &defaultCodec{crc32Check: true}
}

type defaultCodec struct {
// maxSize limits the max size of the payload
maxSize int
maxSize int
crc32Check bool
}

// EncodePayload encode payload
Expand Down Expand Up @@ -150,8 +160,54 @@ func (c *defaultCodec) EncodeMetaAndPayload(ctx context.Context, message remote.
return nil
}

func (c *defaultCodec) EncodeMetaAndPayloadWithCRC32C(ctx context.Context, message remote.Message, out remote.ByteBuffer, me remote.MetaEncoder) error {
var err error
var totalLenField []byte
tp := message.ProtocolInfo().TransProto

// 1. encode payload and calculate crc32 checksum
newPayloadOut := remote.NewWriterBuffer(0)
if err = me.EncodePayload(ctx, message, newPayloadOut); err != nil {
return err
}
pb, err := newPayloadOut.Bytes()
if err != nil {
return err
}
cs := crc32.Checksum(pb, crcTable)
strinfo := message.TransInfo().TransStrInfo()
csb := make([]byte, Size32)
binary.BigEndian.PutUint32(csb, cs)
strinfo[transmeta.HeaderCRC32C] = string(csb)

// 2. encode header and return totalLenField if needed
// totalLenField will be filled after payload encoded
if tp&transport.TTHeader == transport.TTHeader {
if totalLenField, err = ttHeaderCodec.encode(ctx, message, out); err != nil {
return err
}
}

// 3. write payload to the buffer after ttheader
out.WriteBinary(pb)

// 4. fill totalLen field for header if needed
if tp&transport.TTHeader == transport.TTHeader {
if totalLenField == nil {
return perrors.NewProtocolErrorWithMsg("no buffer allocated for the header length field")
}
payloadLen := out.MallocLen() - Size32
binary.BigEndian.PutUint32(totalLenField, uint32(payloadLen))
}
return nil
}

// Encode implements the remote.Codec interface, it does complete message encode include header and payload.
func (c *defaultCodec) Encode(ctx context.Context, message remote.Message, out remote.ByteBuffer) (err error) {
c.crc32Check = true
if c.crc32Check {
return c.EncodeMetaAndPayloadWithCRC32C(ctx, message, out, c)
}
return c.EncodeMetaAndPayload(ctx, message, out, c)
}

Expand All @@ -176,6 +232,9 @@ func (c *defaultCodec) DecodeMeta(ctx context.Context, message remote.Message, i
if flagBuf, err = in.Peek(2 * Size32); err != nil {
return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("ttheader read payload first 8 byte failed: %s", err.Error()))
}
if err = c.checkCRC32C(message, in); err != nil {
return err
}
} else if isMeshHeader(flagBuf) {
message.Tags()[remote.MeshHeader] = true
// MeshHeader
Expand Down Expand Up @@ -238,6 +297,23 @@ func (c *defaultCodec) encodePayload(ctx context.Context, message remote.Message
return pCodec.Marshal(ctx, message, out)
}

func (c *defaultCodec) checkCRC32C(message remote.Message, in remote.ByteBuffer) error {
crc32byte := message.TransInfo().TransStrInfo()[transmeta.HeaderCRC32C]
if len(crc32byte) != 0 {
expectedChecksum := binary.BigEndian.Uint32([]byte(crc32byte))
payloadLen := message.PayloadLen()
payload, err := in.Peek(payloadLen)
if err != nil {
return err
}
realChecksum := crc32.Checksum(payload, crcTable)
if realChecksum != expectedChecksum {
return perrors.NewProtocolErrorWithMsg("crc32c check failed")
}
}
return nil
}

/**
* +------------------------------------------------------------+
* | 4Byte | 2Byte |
Expand Down
36 changes: 36 additions & 0 deletions pkg/remote/codec/default_codec_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,42 @@ func TestDefaultSizedCodec_Encode_Decode(t *testing.T) {
test.Assert(t, err == nil, err)
}

func TestDefaultCodecWithCRC32_Encode_Decode(t *testing.T) {
remote.PutPayloadCode(serviceinfo.Thrift, mpc)

dc := NewDefaultCodecWithCRC32()
ctx := context.Background()
intKVInfo := prepareIntKVInfo()
strKVInfo := prepareStrKVInfo()
sendMsg := initClientSendMsg(transport.TTHeader)
sendMsg.TransInfo().PutTransIntInfo(intKVInfo)
sendMsg.TransInfo().PutTransStrInfo(strKVInfo)

// test encode err
out := remote.NewReaderBuffer([]byte{})
err := dc.Encode(ctx, sendMsg, out)
test.Assert(t, err != nil)

// encode
out = remote.NewWriterBuffer(256)
err = dc.Encode(ctx, sendMsg, out)
test.Assert(t, err == nil, err)

// decode
recvMsg := initServerRecvMsg()
buf, err := out.Bytes()
test.Assert(t, err == nil, err)
in := remote.NewReaderBuffer(buf)
err = dc.Decode(ctx, recvMsg, in)
test.Assert(t, err == nil, err)

intKVInfoRecv := recvMsg.TransInfo().TransIntInfo()
strKVInfoRecv := recvMsg.TransInfo().TransStrInfo()
test.DeepEqual(t, intKVInfoRecv, intKVInfo)
test.DeepEqual(t, strKVInfoRecv, strKVInfo)
test.Assert(t, sendMsg.RPCInfo().Invocation().SeqID() == recvMsg.RPCInfo().Invocation().SeqID())
}

func TestCodecTypeNotMatchWithServiceInfoPayloadCodec(t *testing.T) {
var req interface{}
remote.PutPayloadCode(serviceinfo.Thrift, mpc)
Expand Down
27 changes: 27 additions & 0 deletions pkg/remote/connpool/long_pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,33 @@ var (
mockAddr1 = "127.0.0.1:8001"
)

func TestPoolReuseRate(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()

var (
minIdle = 0
maxIdle = 10000
maxIdleTimeout = time.Hour
)

p := newPool(minIdle, maxIdle, maxIdleTimeout)
num := 1000
for i := 0; i < num; i++ {
conn := newLongConnForTest(ctrl, mockAddr0)
recycled := p.Put(conn)
test.Assert(t, recycled == true)
}
evictTime := 10
for i := 0; i < evictTime; i++ {
test.Assert(t, p.Len() == num)
p.Evict()
test.Assert(t, p.Len() == num)
time.Sleep(100 * time.Millisecond)
}

}

func TestPoolReuse(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
Expand Down
7 changes: 4 additions & 3 deletions pkg/remote/default_bytebuf.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,9 +274,10 @@ func (b *defaultByteBuffer) Bytes() (buf []byte, err error) {
if b.status&BitWritable == 0 {
return nil, errors.New("unwritable buffer, cannot support Bytes")
}
buf = make([]byte, b.writeIdx)
copy(buf, b.buff[:b.writeIdx])
return buf, nil
return b.buff[:b.writeIdx], nil
// buf = make([]byte, b.writeIdx)
// copy(buf, b.buff[:b.writeIdx])
// return buf, nil
}

// NewBuffer returns a new writable remote.ByteBuffer.
Expand Down
1 change: 1 addition & 0 deletions pkg/remote/transmeta/metakey.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ const (
// the connection peer will shutdown later,so it send back the header to tell client to close the connection.
HeaderConnectionReadyToReset = "crrst"
HeaderProcessAtTime = "K_ProcessAtTime"
HeaderCRC32C = "crc32c"
)

// key of acl token
Expand Down

0 comments on commit 364b4df

Please sign in to comment.