Skip to content

Commit

Permalink
import rpcmetadata and delete copy of generated code
Browse files Browse the repository at this point in the history
Summary: import rpcmetadata and delete copy of generated code

Reviewed By: echistyakov

Differential Revision: D63016020

fbshipit-source-id: 31ced0ae225a41629b6e4f483844c1e6aad0be52
  • Loading branch information
Walter Schulze authored and facebook-github-bot committed Sep 20, 2024
1 parent 2cbee70 commit 16d1f00
Show file tree
Hide file tree
Showing 7 changed files with 75 additions and 10,839 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,41 +19,42 @@ package thrift
import (
"testing"

"github.com/facebook/fbthrift/thrift/lib/thrift/rpcmetadata"
"github.com/stretchr/testify/assert"
)

func TestNewRocketException(t *testing.T) {
kind := ErrorKind_PERMANENT
blame := ErrorBlame_CLIENT
safety := ErrorSafety_SAFE
class := &ErrorClassification{
kind := rpcmetadata.ErrorKind_PERMANENT
blame := rpcmetadata.ErrorBlame_CLIENT
safety := rpcmetadata.ErrorSafety_SAFE
class := &rpcmetadata.ErrorClassification{
Kind: &kind,
Blame: &blame,
Safety: &safety,
}
declaredException := &PayloadExceptionMetadataBase{
Metadata: &PayloadExceptionMetadata{
DeclaredException: &PayloadDeclaredExceptionMetadata{
declaredException := &rpcmetadata.PayloadExceptionMetadataBase{
Metadata: &rpcmetadata.PayloadExceptionMetadata{
DeclaredException: &rpcmetadata.PayloadDeclaredExceptionMetadata{
ErrorClassification: class,
},
},
}
err := newRocketException(declaredException)
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "DeclaredException")
appUnknownException := &PayloadExceptionMetadataBase{
Metadata: &PayloadExceptionMetadata{
AppUnknownException: &PayloadAppUnknownExceptionMetdata{
appUnknownException := &rpcmetadata.PayloadExceptionMetadataBase{
Metadata: &rpcmetadata.PayloadExceptionMetadata{
AppUnknownException: &rpcmetadata.PayloadAppUnknownExceptionMetdata{
ErrorClassification: class,
},
},
}
err = newRocketException(appUnknownException)
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "AppUnknownException")
appUnknownException2 := &PayloadExceptionMetadataBase{
Metadata: &PayloadExceptionMetadata{
AppUnknownException: &PayloadAppUnknownExceptionMetdata{},
appUnknownException2 := &rpcmetadata.PayloadExceptionMetadataBase{
Metadata: &rpcmetadata.PayloadExceptionMetadata{
AppUnknownException: &rpcmetadata.PayloadAppUnknownExceptionMetdata{},
},
}
err = newRocketException(appUnknownException2)
Expand Down
42 changes: 22 additions & 20 deletions third-party/thrift/src/thrift/lib/go/thrift/rocket_exception.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ package thrift

import (
"encoding/json"

"github.com/facebook/fbthrift/thrift/lib/thrift/rpcmetadata"
)

type rocketExceptionType int16
Expand Down Expand Up @@ -54,29 +56,29 @@ type rocketException struct {
Name string
What string
ExceptionType rocketExceptionType
Kind ErrorKind
Blame ErrorBlame
Safety ErrorSafety
Kind rpcmetadata.ErrorKind
Blame rpcmetadata.ErrorBlame
Safety rpcmetadata.ErrorSafety
}

var _ error = (*rocketException)(nil)

func newRocketException(exception *PayloadExceptionMetadataBase) *rocketException {
func newRocketException(exception *rpcmetadata.PayloadExceptionMetadataBase) *rocketException {
err := &rocketException{
Name: "unknown",
What: "unknown",
ExceptionType: rocketExceptionUnknown,
Kind: ErrorKind_UNSPECIFIED,
Blame: ErrorBlame_UNSPECIFIED,
Safety: ErrorSafety_UNSPECIFIED,
Kind: rpcmetadata.ErrorKind_UNSPECIFIED,
Blame: rpcmetadata.ErrorBlame_UNSPECIFIED,
Safety: rpcmetadata.ErrorSafety_UNSPECIFIED,
}
if exception.NameUTF8 != nil {
err.Name = *exception.NameUTF8
}
if exception.WhatUTF8 != nil {
err.What = *exception.WhatUTF8
}
var class *ErrorClassification
var class *rpcmetadata.ErrorClassification
if exception.Metadata != nil {
if exception.Metadata.DeclaredException != nil {
err.ExceptionType = rocketExceptionDeclared
Expand Down Expand Up @@ -108,39 +110,39 @@ func newRocketException(exception *PayloadExceptionMetadataBase) *rocketExceptio
return err
}

func newUnknownPayloadExceptionMetadataBase(name string, what string) *PayloadExceptionMetadataBase {
func newUnknownPayloadExceptionMetadataBase(name string, what string) *rpcmetadata.PayloadExceptionMetadataBase {
return newPayloadExceptionMetadataBase(&rocketException{
Name: name,
What: what,
ExceptionType: rocketExceptionUnknown,
Safety: ErrorSafety_SAFE,
Kind: ErrorKind_TRANSIENT,
Blame: ErrorBlame_SERVER,
Safety: rpcmetadata.ErrorSafety_SAFE,
Kind: rpcmetadata.ErrorKind_TRANSIENT,
Blame: rpcmetadata.ErrorBlame_SERVER,
})
}

func newPayloadExceptionMetadataBase(err *rocketException) *PayloadExceptionMetadataBase {
base := NewPayloadExceptionMetadataBase()
func newPayloadExceptionMetadataBase(err *rocketException) *rpcmetadata.PayloadExceptionMetadataBase {
base := rpcmetadata.NewPayloadExceptionMetadataBase()
base.SetNameUTF8(&err.Name)
base.SetWhatUTF8(&err.What)
class := NewErrorClassification()
class := rpcmetadata.NewErrorClassification()
class.SetKind(&err.Kind)
class.SetBlame(&err.Blame)
class.SetSafety(&err.Safety)
metadata := NewPayloadExceptionMetadata()
metadata := rpcmetadata.NewPayloadExceptionMetadata()
switch err.ExceptionType {
case rocketExceptionDeclared:
declared := NewPayloadDeclaredExceptionMetadata()
declared := rpcmetadata.NewPayloadDeclaredExceptionMetadata()
declared.SetErrorClassification(class)
metadata.SetDeclaredException(declared)
case rocketExceptionAppUnknown:
appUnknown := NewPayloadAppUnknownExceptionMetdata()
appUnknown := rpcmetadata.NewPayloadAppUnknownExceptionMetdata()
appUnknown.SetErrorClassification(class)
metadata.SetAppUnknownException(appUnknown)
case rocketExceptionAny:
metadata.SetAnyException(NewPayloadAnyExceptionMetadata())
metadata.SetAnyException(rpcmetadata.NewPayloadAnyExceptionMetadata())
case rocketExceptionDeprecatedProxy:
metadata.SetDEPRECATEDProxyException(NewPayloadProxyExceptionMetadata())
metadata.SetDEPRECATEDProxyException(rpcmetadata.NewPayloadProxyExceptionMetadata())
case rocketExceptionUnknown:
default:
panic("unreachable")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package thrift
import (
"fmt"

"github.com/facebook/fbthrift/thrift/lib/thrift/rpcmetadata"
"github.com/rsocket/rsocket-go/payload"
)

Expand All @@ -39,7 +40,7 @@ func decodeServerMetadataPushVersion8(msg payload.Payload) (*serverMetadataPaylo
return nil, fmt.Errorf("no metadata in server metadata push")
}
// Use ServerPushMetadata{} and do not use &ServerPushMetadata{} to ensure stack and avoid heap allocation.
metadata := ServerPushMetadata{}
metadata := rpcmetadata.ServerPushMetadata{}
if err := deserializeCompact(metadataBytes, &metadata); err != nil {
panic(fmt.Errorf("unable to deserialize metadata push into ServerPushMetadata %w", err))
}
Expand All @@ -62,8 +63,8 @@ func decodeServerMetadataPushVersion8(msg payload.Payload) (*serverMetadataPaylo

func encodeServerMetadataPushVersion8(zstdSupported bool) (payload.Payload, error) {
version := int32(8)
res := NewServerPushMetadata().
SetSetupResponse(NewSetupResponse().
res := rpcmetadata.NewServerPushMetadata().
SetSetupResponse(rpcmetadata.NewSetupResponse().
SetVersion(&version).
SetZstdSupported(&zstdSupported))
metadataBytes, err := serializeCompact(res)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,19 @@ import (
"maps"

"github.com/facebook/fbthrift/thrift/lib/go/thrift/types"
"github.com/facebook/fbthrift/thrift/lib/thrift/rpcmetadata"
"github.com/rsocket/rsocket-go/payload"
)

type requestPayload struct {
metadata *RequestRpcMetadata
metadata *rpcmetadata.RequestRpcMetadata
data []byte
typeID types.MessageType
protoID types.ProtocolID
}

func encodeRequestPayload(name string, protoID types.ProtocolID, typeID types.MessageType, headers map[string]string, zstd bool, dataBytes []byte) (payload.Payload, error) {
metadata := NewRequestRpcMetadata()
metadata := rpcmetadata.NewRequestRpcMetadata()
metadata.SetName(&name)
rpcProtocolID, err := protocolIDToRPCProtocolID(protoID)
if err != nil {
Expand All @@ -45,7 +46,7 @@ func encodeRequestPayload(name string, protoID types.ProtocolID, typeID types.Me
}
metadata.SetKind(&kind)
if zstd {
compression := CompressionAlgorithm_ZSTD
compression := rpcmetadata.CompressionAlgorithm_ZSTD
metadata.SetCompression(&compression)
}
metadata.OtherMetadata = make(map[string]string)
Expand All @@ -70,7 +71,7 @@ func decodeRequestPayload(msg payload.Payload) (*requestPayload, error) {
var err error
metadataBytes, ok := msg.Metadata()
if ok {
metadata := &RequestRpcMetadata{}
metadata := &rpcmetadata.RequestRpcMetadata{}
if err := deserializeCompact(metadataBytes, metadata); err != nil {
return nil, err
}
Expand Down Expand Up @@ -117,7 +118,7 @@ func (r *requestPayload) ProtoID() types.ProtocolID {
}

func (r *requestPayload) Zstd() bool {
return r.metadata != nil && r.metadata.GetCompression() == CompressionAlgorithm_ZSTD
return r.metadata != nil && r.metadata.GetCompression() == rpcmetadata.CompressionAlgorithm_ZSTD
}

func (r *requestPayload) Headers() map[string]string {
Expand All @@ -127,41 +128,41 @@ func (r *requestPayload) Headers() map[string]string {
return r.metadata.GetOtherMetadata()
}

func protocolIDToRPCProtocolID(protocolID types.ProtocolID) (ProtocolId, error) {
func protocolIDToRPCProtocolID(protocolID types.ProtocolID) (rpcmetadata.ProtocolId, error) {
switch protocolID {
case types.ProtocolIDBinary:
return ProtocolId_BINARY, nil
return rpcmetadata.ProtocolId_BINARY, nil
case types.ProtocolIDCompact:
return ProtocolId_COMPACT, nil
return rpcmetadata.ProtocolId_COMPACT, nil
}
return 0, fmt.Errorf("unsupported ProtocolID %v", protocolID)
}

func rpcProtocolIDToProtocolID(protocolID ProtocolId) (types.ProtocolID, error) {
func rpcProtocolIDToProtocolID(protocolID rpcmetadata.ProtocolId) (types.ProtocolID, error) {
switch protocolID {
case ProtocolId_BINARY:
case rpcmetadata.ProtocolId_BINARY:
return types.ProtocolIDBinary, nil
case ProtocolId_COMPACT:
case rpcmetadata.ProtocolId_COMPACT:
return types.ProtocolIDCompact, nil
}
return 0, fmt.Errorf("unsupported ProtocolId %v", protocolID)
}

func messageTypeToRPCKind(typeID types.MessageType) (RpcKind, error) {
func messageTypeToRPCKind(typeID types.MessageType) (rpcmetadata.RpcKind, error) {
switch typeID {
case types.CALL:
return RpcKind_SINGLE_REQUEST_SINGLE_RESPONSE, nil
return rpcmetadata.RpcKind_SINGLE_REQUEST_SINGLE_RESPONSE, nil
case types.ONEWAY:
return RpcKind_SINGLE_REQUEST_NO_RESPONSE, nil
return rpcmetadata.RpcKind_SINGLE_REQUEST_NO_RESPONSE, nil
}
return 0, fmt.Errorf("unsupported MessageType %v", typeID)
}

func rpcKindToMessageType(kind RpcKind) (types.MessageType, error) {
func rpcKindToMessageType(kind rpcmetadata.RpcKind) (types.MessageType, error) {
switch kind {
case RpcKind_SINGLE_REQUEST_SINGLE_RESPONSE:
case rpcmetadata.RpcKind_SINGLE_REQUEST_SINGLE_RESPONSE:
return types.CALL, nil
case RpcKind_SINGLE_REQUEST_NO_RESPONSE:
case rpcmetadata.RpcKind_SINGLE_REQUEST_NO_RESPONSE:
return types.ONEWAY, nil
}
return 0, fmt.Errorf("unsupported RpcKind %v", kind)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@ package thrift

import (
"github.com/facebook/fbthrift/thrift/lib/go/thrift/types"
"github.com/facebook/fbthrift/thrift/lib/thrift/rpcmetadata"
"github.com/rsocket/rsocket-go/payload"
)

type responsePayload struct {
metadata *ResponseRpcMetadata
metadata *rpcmetadata.ResponseRpcMetadata
exception *rocketException
data []byte
}
Expand All @@ -39,7 +40,7 @@ func (r *responsePayload) Data() []byte {
}

func (r *responsePayload) Zstd() bool {
return r.metadata != nil && r.metadata.GetCompression() == CompressionAlgorithm_ZSTD
return r.metadata != nil && r.metadata.GetCompression() == rpcmetadata.CompressionAlgorithm_ZSTD
}

func (r *responsePayload) Error() error {
Expand All @@ -52,9 +53,9 @@ func (r *responsePayload) Error() error {
func decodeResponsePayload(msg payload.Payload) (*responsePayload, error) {
msg = payload.Clone(msg)
if msg == nil {
return &responsePayload{metadata: &ResponseRpcMetadata{}, data: []byte{}}, nil
return &responsePayload{metadata: &rpcmetadata.ResponseRpcMetadata{}, data: []byte{}}, nil
}
res := &responsePayload{metadata: &ResponseRpcMetadata{}, data: msg.Data()}
res := &responsePayload{metadata: &rpcmetadata.ResponseRpcMetadata{}, data: msg.Data()}
var err error
metadataBytes, ok := msg.Metadata()
if ok {
Expand All @@ -75,15 +76,15 @@ func decodeResponsePayload(msg payload.Payload) (*responsePayload, error) {
}

func encodeResponsePayload(name string, messageType types.MessageType, headers map[string]string, zstd bool, dataBytes []byte) (payload.Payload, error) {
metadata := NewResponseRpcMetadata()
metadata := rpcmetadata.NewResponseRpcMetadata()
metadata.SetOtherMetadata(headers)
if zstd {
compression := CompressionAlgorithm_ZSTD
compression := rpcmetadata.CompressionAlgorithm_ZSTD
metadata.SetCompression(&compression)
}
if messageType == types.EXCEPTION {
excpetionMetadata := newUnknownPayloadExceptionMetadataBase(name, string(dataBytes))
metadata.SetPayloadMetadata(NewPayloadMetadata().SetExceptionMetadata(excpetionMetadata))
metadata.SetPayloadMetadata(rpcmetadata.NewPayloadMetadata().SetExceptionMetadata(excpetionMetadata))
}
metadataBytes, err := serializeCompact(metadata)
if err != nil {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"encoding/binary"
"fmt"

"github.com/facebook/fbthrift/thrift/lib/thrift/rpcmetadata"
"github.com/rsocket/rsocket-go/payload"
)

Expand All @@ -40,10 +41,10 @@ func checkRequestSetupMetadata8(pay payload.Payload) error {
if err != nil {
return err
}
if int64(key) != KRocketProtocolKey {
return fmt.Errorf("expected key %d, got %d", KRocketProtocolKey, key)
if int64(key) != rpcmetadata.KRocketProtocolKey {
return fmt.Errorf("expected key %d, got %d", rpcmetadata.KRocketProtocolKey, key)
}
req := RequestSetupMetadata{}
req := rpcmetadata.RequestSetupMetadata{}
if err := deserializeCompact(metdataBytes[4:], &req); err != nil {
return err
}
Expand All @@ -55,8 +56,8 @@ func checkRequestSetupMetadata8(pay payload.Payload) error {
return nil
}

func newRequestSetupMetadataVersion8() *RequestSetupMetadata {
res := NewRequestSetupMetadata()
func newRequestSetupMetadataVersion8() *rpcmetadata.RequestSetupMetadata {
res := rpcmetadata.NewRequestSetupMetadata()
version := int32(8)
res.SetMaxVersion(&version)
res.SetMinVersion(&version)
Expand All @@ -67,7 +68,7 @@ func newRequestSetupMetadataVersion8() *RequestSetupMetadata {
func newRequestSetupMetadataVersion8Bytes() ([]byte, error) {
// write key first rpcmetadata.KRocketProtocolKey
buf := new(bytes.Buffer)
key := uint32(KRocketProtocolKey)
key := uint32(rpcmetadata.KRocketProtocolKey)
err := binary.Write(buf, binary.BigEndian, key)
if err != nil {
return nil, err
Expand All @@ -91,8 +92,8 @@ func newRequestSetupPayloadVersion8() (payload.Payload, error) {
}

// If connection establishment was successful, the server MUST respond with a SetupResponse control message.
func newSetupResponseVersion8() *SetupResponse {
res := NewSetupResponse()
func newSetupResponseVersion8() *rpcmetadata.SetupResponse {
res := rpcmetadata.NewSetupResponse()
version := int32(8)
res.SetVersion(&version)
ztsdSupported := true
Expand Down
Loading

0 comments on commit 16d1f00

Please sign in to comment.