Skip to content

Commit

Permalink
Merge pull request cloudwego#1222 from DMwangnima/feat/fallback-frugal
Browse files Browse the repository at this point in the history
[WIP] feat: implement fallback to frugal feature
  • Loading branch information
Felix021 authored Jan 23, 2024
2 parents 1d6676a + 66178c9 commit ddc46cb
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 32 deletions.
24 changes: 22 additions & 2 deletions pkg/remote/codec/thrift/thrift.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,14 @@ const (
FastReadWrite = FastRead | FastWrite
)

var defaultCodec = NewThriftCodec().(*thriftCodec)
var (
defaultCodec = NewThriftCodec().(*thriftCodec)

errEncodeMismatchMsgType = remote.NewTransErrorWithMsg(remote.InvalidProtocol,
"encode failed, codec msg type not match with thriftCodec")
errDecodeMismatchMsgType = remote.NewTransErrorWithMsg(remote.InvalidProtocol,
"decode failed, codec msg type not match with thriftCodec")
)

// NewThriftCodec creates the thrift binary codec.
func NewThriftCodec() remote.PayloadCodec {
Expand Down Expand Up @@ -109,7 +116,17 @@ func (c thriftCodec) Marshal(ctx context.Context, message remote.Message, out re
}

// fallback to old thrift way (slow)
return encodeBasicThrift(out, ctx, methodName, msgType, seqID, data)
if err = encodeBasicThrift(out, ctx, methodName, msgType, seqID, data); err == nil || err != errEncodeMismatchMsgType {
return err
}

// Basic can be used for disabling frugal, we need to check it
if c.CodecType != Basic && hyperMarshalAvailable(data) {
// fallback to frugal when the generated code is using slim template
return c.hyperMarshal(out, methodName, msgType, seqID, data)
}

return errEncodeMismatchMsgType
}

// encodeFastThrift encode with the FastCodec way
Expand All @@ -129,6 +146,9 @@ func encodeFastThrift(out remote.ByteBuffer, methodName string, msgType remote.M

// encodeBasicThrift encode with the old thrift way (slow)
func encodeBasicThrift(out remote.ByteBuffer, ctx context.Context, method string, msgType remote.MessageType, seqID int32, data interface{}) error {
if err := verifyMarshalBasicThriftDataType(data); err != nil {
return err
}
tProt := NewBinaryProtocol(out)
if err := tProt.WriteMessageBegin(method, thrift.TMessageType(msgType), seqID); err != nil {
return perrors.NewProtocolErrorWithMsg(fmt.Sprintf("thrift marshal, WriteMessageBegin failed: %s", err.Error()))
Expand Down
59 changes: 51 additions & 8 deletions pkg/remote/codec/thrift/thrift_data.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,15 @@ func (c thriftCodec) marshalThriftData(ctx context.Context, data interface{}) ([
}
}

if err := verifyMarshalBasicThriftDataType(data); err != nil {
// Basic can be used for disabling frugal, we need to check it
if c.CodecType != Basic && hyperMarshalAvailable(data) {
// fallback to frugal when the generated code is using slim template
return c.hyperMarshalBody(data)
}
return nil, err
}

// fallback to old thrift way (slow)
transport := thrift.NewTMemoryBufferLen(marshalThriftBufferSize)
tProt := thrift.NewTBinaryProtocol(transport, true, true)
Expand All @@ -68,6 +77,17 @@ func (c thriftCodec) marshalThriftData(ctx context.Context, data interface{}) ([
return transport.Bytes(), nil
}

// verifyMarshalBasicThriftDataType verifies whether data could be marshaled by old thrift way
func verifyMarshalBasicThriftDataType(data interface{}) error {
switch data.(type) {
case MessageWriter:
case MessageWriterWithContext:
default:
return errEncodeMismatchMsgType
}
return nil
}

// marshalBasicThriftData only encodes the data (without the prepending method, msgType, seqId)
// It uses the old thrift way which is much slower than FastCodec and Frugal
func marshalBasicThriftData(ctx context.Context, tProt thrift.TProtocol, data interface{}) error {
Expand All @@ -81,7 +101,7 @@ func marshalBasicThriftData(ctx context.Context, tProt thrift.TProtocol, data in
return perrors.NewProtocolErrorWithErrMsg(err, fmt.Sprintf("thrift marshal, Write failed: %s", err.Error()))
}
default:
return remote.NewTransErrorWithMsg(remote.InvalidProtocol, "encode failed, codec msg type not match with thriftCodec")
return errEncodeMismatchMsgType
}
return nil
}
Expand Down Expand Up @@ -122,11 +142,7 @@ func UnmarshalThriftData(ctx context.Context, codec remote.PayloadCodec, method
func (c thriftCodec) unmarshalThriftData(ctx context.Context, tProt *BinaryProtocol, method string, data interface{}, dataLen int) error {
// decode with hyper unmarshal
if c.hyperMessageUnmarshalEnabled() && hyperMessageUnmarshalAvailable(data, dataLen) {
buf, err := tProt.next(dataLen - bthrift.Binary.MessageEndLength())
if err != nil {
return remote.NewTransError(remote.ProtocolError, err)
}
return c.hyperMessageUnmarshal(buf, data)
return c.hyperUnmarshal(tProt, data, dataLen)
}

// decode with FastRead
Expand All @@ -141,10 +157,38 @@ func (c thriftCodec) unmarshalThriftData(ctx context.Context, tProt *BinaryProto
}
}

if err := verifyUnmarshalBasicThriftDataType(data); err != nil {
// Basic can be used for disabling frugal, we need to check it
if c.CodecType != Basic && hyperMessageUnmarshalAvailable(data, dataLen) {
// fallback to frugal when the generated code is using slim template
return c.hyperUnmarshal(tProt, data, dataLen)
}
return err
}

// fallback to old thrift way (slow)
return decodeBasicThriftData(ctx, tProt, method, data)
}

func (c thriftCodec) hyperUnmarshal(tProt *BinaryProtocol, data interface{}, dataLen int) error {
buf, err := tProt.next(dataLen - bthrift.Binary.MessageEndLength())
if err != nil {
return remote.NewTransError(remote.ProtocolError, err)
}
return c.hyperMessageUnmarshal(buf, data)
}

// verifyUnmarshalBasicThriftDataType verifies whether data could be unmarshal by old thrift way
func verifyUnmarshalBasicThriftDataType(data interface{}) error {
switch data.(type) {
case MessageReader:
case MessageReaderWithMethodWithContext:
default:
return errDecodeMismatchMsgType
}
return nil
}

// decodeBasicThriftData decode thrift body the old way (slow)
func decodeBasicThriftData(ctx context.Context, tProt thrift.TProtocol, method string, data interface{}) error {
var err error
Expand All @@ -159,8 +203,7 @@ func decodeBasicThriftData(ctx context.Context, tProt thrift.TProtocol, method s
return remote.NewTransError(remote.ProtocolError, err)
}
default:
return remote.NewTransErrorWithMsg(remote.InvalidProtocol,
"decode failed, codec msg type not match with thriftCodec")
return errDecodeMismatchMsgType
}
return nil
}
14 changes: 13 additions & 1 deletion pkg/remote/codec/thrift/thrift_data_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ var (
func TestMarshalBasicThriftData(t *testing.T) {
t.Run("invalid-data", func(t *testing.T) {
err := marshalBasicThriftData(context.Background(), nil, 0)
test.Assert(t, err != nil, err)
test.Assert(t, err == errEncodeMismatchMsgType, err)
})
t.Run("valid-data", func(t *testing.T) {
transport := thrift.NewTMemoryBufferLen(1024)
Expand Down Expand Up @@ -138,3 +138,15 @@ func TestUnmarshalThriftException(t *testing.T) {
test.Assert(t, transErr.TypeID() == thrift.INVALID_PROTOCOL, transErr)
test.Assert(t, transErr.Error() == errMessage, transErr)
}

func Test_verifyMarshalBasicThriftDataType(t *testing.T) {
err := verifyMarshalBasicThriftDataType(&mockWithContext{})
test.Assert(t, err == nil, err)
// data that is not part of basic thrift: in thrift_frugal_amd64_test.go: Test_verifyMarshalThriftDataFrugal
}

func Test_verifyUnmarshalBasicThriftDataType(t *testing.T) {
err := verifyUnmarshalBasicThriftDataType(&mockWithContext{})
test.Assert(t, err == nil, err)
// data that is not part of basic thrift: in thrift_frugal_amd64_test.go: Test_verifyUnmarshalThriftDataFrugal
}
108 changes: 87 additions & 21 deletions pkg/remote/codec/thrift/thrift_frugal_amd64_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,20 +101,45 @@ func TestHyperCodecCheck(t *testing.T) {
}

func TestFrugalCodec(t *testing.T) {
ctx := context.Background()
frugalCodec := &thriftCodec{FrugalRead | FrugalWrite}
t.Run("configure frugal but data has not tag", func(t *testing.T) {
ctx := context.Background()
codec := &thriftCodec{FrugalRead | FrugalWrite}

// MockNoTagArgs cannot be marshaled
sendMsg := initNoTagSendMsg(transport.TTHeader)
out := remote.NewWriterBuffer(256)
err := codec.Marshal(ctx, sendMsg, out)
test.Assert(t, err != nil)
})
t.Run("configure frugal and data has tag", func(t *testing.T) {
ctx := context.Background()
codec := &thriftCodec{FrugalRead | FrugalWrite}

testFrugalDataConversion(t, ctx, codec)
})
t.Run("fallback to frugal and data has tag", func(t *testing.T) {
ctx := context.Background()
codec := NewThriftCodec()

testFrugalDataConversion(t, ctx, codec)
})
t.Run("configure BasicCodec to disable frugal fallback", func(t *testing.T) {
ctx := context.Background()
codec := NewThriftCodecWithConfig(Basic)

// MockNoTagArgs cannot be marshaled
sendMsg := initNoTagSendMsg(transport.TTHeader)
out := remote.NewWriterBuffer(256)
err := codec.Marshal(ctx, sendMsg, out)
test.Assert(t, err != nil)
})
}

func testFrugalDataConversion(t *testing.T, ctx context.Context, codec remote.PayloadCodec) {
// encode client side
// MockNoTagArgs cannot be marshaled
sendMsg := initNoTagSendMsg(transport.TTHeader)
sendMsg := initFrugalTagSendMsg(transport.TTHeader)
out := remote.NewWriterBuffer(256)
err := frugalCodec.Marshal(ctx, sendMsg, out)
test.Assert(t, err != nil)

// MockFrugalTagArgs can be marshaled by frugal
sendMsg = initFrugalTagSendMsg(transport.TTHeader)
out = remote.NewWriterBuffer(256)
err = frugalCodec.Marshal(ctx, sendMsg, out)
err := codec.Marshal(ctx, sendMsg, out)
test.Assert(t, err == nil, err)

// decode server side
Expand All @@ -123,7 +148,7 @@ func TestFrugalCodec(t *testing.T) {
recvMsg.SetPayloadLen(len(buf))
test.Assert(t, err == nil, err)
in := remote.NewReaderBuffer(buf)
err = frugalCodec.Unmarshal(ctx, recvMsg, in)
err = codec.Unmarshal(ctx, recvMsg, in)
test.Assert(t, err == nil, err)

// compare Args
Expand All @@ -144,17 +169,58 @@ func TestMarshalThriftDataFrugal(t *testing.T) {
mockReqFrugal := &MockFrugalTagReq{
Msg: "hello",
}
buf, err := MarshalThriftData(context.Background(), NewThriftCodecWithConfig(FrugalWrite), mockReqFrugal)
test.Assert(t, err == nil, err)
test.Assert(t, reflect.DeepEqual(buf, mockReqThrift), buf)
successfulCodecs := []remote.PayloadCodec{
NewThriftCodecWithConfig(FrugalWrite),
// fallback to frugal
nil,
// fallback to frugal
NewThriftCodec(),
}
for _, codec := range successfulCodecs {
buf, err := MarshalThriftData(context.Background(), codec, mockReqFrugal)
test.Assert(t, err == nil, err)
test.Assert(t, reflect.DeepEqual(buf, mockReqThrift), buf)
}

// Basic can be used for disabling frugal
_, err := MarshalThriftData(context.Background(), NewThriftCodecWithConfig(Basic), mockReqFrugal)
test.Assert(t, err != nil, err)
}

func TestUnmarshalThriftDataFrugal(t *testing.T) {
req := &MockFrugalTagReq{}
err := UnmarshalThriftData(context.Background(), NewThriftCodecWithConfig(FrugalRead), "mock", mockReqThrift, req)
checkDecodeResult(t, err, &fast.MockReq{
Msg: req.Msg,
StrList: req.StrList,
StrMap: req.StrMap,
})
successfulCodecs := []remote.PayloadCodec{
NewThriftCodecWithConfig(FrugalRead),
// fallback to frugal
nil,
// fallback to frugal
NewThriftCodec(),
}
for _, codec := range successfulCodecs {
err := UnmarshalThriftData(context.Background(), codec, "mock", mockReqThrift, req)
checkDecodeResult(t, err, &fast.MockReq{
Msg: req.Msg,
StrList: req.StrList,
StrMap: req.StrMap,
})

}

// Basic can be used for disabling frugal
err := UnmarshalThriftData(context.Background(), NewThriftCodecWithConfig(Basic), "mock", mockReqThrift, req)
test.Assert(t, err != nil, err)
}

func Test_verifyMarshalThriftDataFrugal(t *testing.T) {
err := verifyMarshalBasicThriftDataType(&MockFrugalTagArgs{})
test.Assert(t, err == errEncodeMismatchMsgType, err)
err = verifyMarshalBasicThriftDataType(&MockNoTagArgs{})
test.Assert(t, err == errEncodeMismatchMsgType, err)
}

func Test_verifyUnmarshalThriftDataFrugal(t *testing.T) {
err := verifyUnmarshalBasicThriftDataType(&MockFrugalTagArgs{})
test.Assert(t, err == errDecodeMismatchMsgType, err)
err = verifyUnmarshalBasicThriftDataType(&MockNoTagArgs{})
test.Assert(t, err == errDecodeMismatchMsgType, err)
}

0 comments on commit ddc46cb

Please sign in to comment.