Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(baseapp): Add Hybrid Protobuf handlers to MsgServiceRouter #18071

Merged
merged 10 commits into from
Oct 12, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ Ref: https://keepachangelog.com/en/1.0.0/
* (client) [#17513](https://github.com/cosmos/cosmos-sdk/pull/17513) Allow overwritting `client.toml`. Use `client.CreateClientConfig` in place of `client.ReadFromClientConfig` and provide a custom template and a custom config.
* (x/bank) [#17569](https://github.com/cosmos/cosmos-sdk/pull/17569) Introduce a new message type, `MsgBurn `, to burn coins.
* (server) [#17094](https://github.com/cosmos/cosmos-sdk/pull/17094) Add duration `shutdown-grace` for resource clean up (closing database handles) before exit.
* (baseapp) [#18071](https://github.com/cosmos/cosmos-sdk/pull/18071) Add hybrid handlers to `MsgServiceRouter`.

### Improvements

Expand Down
30 changes: 11 additions & 19 deletions baseapp/grpcrouter.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@ import (

abci "github.com/cometbft/cometbft/abci/types"
gogogrpc "github.com/cosmos/gogoproto/grpc"
"github.com/cosmos/gogoproto/proto"
"google.golang.org/grpc"
"google.golang.org/grpc/encoding"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/runtime/protoiface"

"github.com/cosmos/cosmos-sdk/baseapp/internal/protocompat"
Expand All @@ -23,9 +21,9 @@ import (
type GRPCQueryRouter struct {
// routes maps query handlers used in ABCIQuery.
routes map[string]GRPCQueryHandler
// handlerByMessageName maps the request name to the handler. It is a hybrid handler which seamlessly
// hybridHandlers maps the request name to the handler. It is a hybrid handler which seamlessly
// handles both gogo and protov2 messages.
handlerByMessageName map[string][]func(ctx context.Context, req, resp protoiface.MessageV1) error
hybridHandlers map[string][]func(ctx context.Context, req, resp protoiface.MessageV1) error
// binaryCodec is used to encode/decode binary protobuf messages.
binaryCodec codec.BinaryCodec
// cdc is the gRPC codec used by the router to correctly unmarshal messages.
Expand All @@ -45,8 +43,8 @@ var _ gogogrpc.Server = &GRPCQueryRouter{}
// NewGRPCQueryRouter creates a new GRPCQueryRouter
func NewGRPCQueryRouter() *GRPCQueryRouter {
return &GRPCQueryRouter{
routes: map[string]GRPCQueryHandler{},
handlerByMessageName: map[string][]func(ctx context.Context, req, resp protoiface.MessageV1) error{},
routes: map[string]GRPCQueryHandler{},
hybridHandlers: map[string][]func(ctx context.Context, req, resp protoiface.MessageV1) error{},
}
}

Expand Down Expand Up @@ -76,7 +74,7 @@ func (qrt *GRPCQueryRouter) RegisterService(sd *grpc.ServiceDesc, handler interf
if err != nil {
panic(err)
}
err = qrt.registerHandlerByMessageName(sd, method, handler)
err = qrt.registerHybridHandler(sd, method, handler)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -131,27 +129,21 @@ func (qrt *GRPCQueryRouter) registerABCIQueryHandler(sd *grpc.ServiceDesc, metho
return nil
}

func (qrt *GRPCQueryRouter) HandlersByRequestName(name string) []func(ctx context.Context, req, resp protoiface.MessageV1) error {
return qrt.handlerByMessageName[name]
func (qrt *GRPCQueryRouter) HybridHandlerByRequestName(name string) []func(ctx context.Context, req, resp protoiface.MessageV1) error {
return qrt.hybridHandlers[name]
}

func (qrt *GRPCQueryRouter) registerHandlerByMessageName(sd *grpc.ServiceDesc, method grpc.MethodDesc, handler interface{}) error {
func (qrt *GRPCQueryRouter) registerHybridHandler(sd *grpc.ServiceDesc, method grpc.MethodDesc, handler interface{}) error {
// extract message name from method descriptor
methodFullName := protoreflect.FullName(fmt.Sprintf("%s.%s", sd.ServiceName, method.MethodName))
desc, err := proto.HybridResolver.FindDescriptorByName(methodFullName)
inputName, err := protocompat.RequestFullNameFromMethodDesc(sd, method)
if err != nil {
return fmt.Errorf("cannot find method descriptor %s", methodFullName)
}
methodDesc, ok := desc.(protoreflect.MethodDescriptor)
if !ok {
return fmt.Errorf("invalid method descriptor %s", methodFullName)
return err
}
inputName := methodDesc.Input().FullName()
methodHandler, err := protocompat.MakeHybridHandler(qrt.binaryCodec, sd, method, handler)
if err != nil {
return err
}
qrt.handlerByMessageName[string(inputName)] = append(qrt.handlerByMessageName[string(inputName)], methodHandler)
qrt.hybridHandlers[string(inputName)] = append(qrt.hybridHandlers[string(inputName)], methodHandler)
return nil
}

Expand Down
2 changes: 1 addition & 1 deletion baseapp/grpcrouter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func TestGRPCQueryRouter(t *testing.T) {
func TestGRPCRouterHybridHandlers(t *testing.T) {
assertRouterBehaviour := func(helper *baseapp.QueryServiceTestHelper) {
// test getting the handler by name
handlers := helper.GRPCQueryRouter.HandlersByRequestName("testpb.EchoRequest")
handlers := helper.GRPCQueryRouter.HybridHandlerByRequestName("testpb.EchoRequest")
require.NotNil(t, handlers)
require.Len(t, handlers, 1)
handler := handlers[0]
Expand Down
14 changes: 14 additions & 0 deletions baseapp/internal/protocompat/protocompat.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,17 @@ func isProtov2(md grpc.MethodDesc) (isV2Type bool, err error) {
_, _ = md.Handler(nil, nil, pullRequestType, doNotExecute)
return
}

// RequestFullNameFromMethodDesc returns the fully-qualified name of the request message of the provided service's method.
func RequestFullNameFromMethodDesc(sd *grpc.ServiceDesc, method grpc.MethodDesc) (protoreflect.FullName, error) {
methodFullName := protoreflect.FullName(fmt.Sprintf("%s.%s", sd.ServiceName, method.MethodName))
tac0turtle marked this conversation as resolved.
Show resolved Hide resolved
desc, err := gogoproto.HybridResolver.FindDescriptorByName(methodFullName)
if err != nil {
return "", fmt.Errorf("cannot find method descriptor %s", methodFullName)
}
methodDesc, ok := desc.(protoreflect.MethodDescriptor)
if !ok {
return "", fmt.Errorf("invalid method descriptor %s", methodFullName)
}
return methodDesc.Input().FullName(), nil
}
224 changes: 136 additions & 88 deletions baseapp/msg_service_router.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,12 @@ import (
gogogrpc "github.com/cosmos/gogoproto/grpc"
"github.com/cosmos/gogoproto/proto"
"google.golang.org/grpc"
"google.golang.org/protobuf/runtime/protoiface"

errorsmod "cosmossdk.io/errors"

"github.com/cosmos/cosmos-sdk/baseapp/internal/protocompat"
"github.com/cosmos/cosmos-sdk/codec"
codectypes "github.com/cosmos/cosmos-sdk/codec/types"
sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
Expand All @@ -27,6 +30,7 @@ type MessageRouter interface {
type MsgServiceRouter struct {
interfaceRegistry codectypes.InterfaceRegistry
routes map[string]MsgServiceHandler
hybridHandlers map[string]func(ctx context.Context, req, resp protoiface.MessageV1) error
circuitBreaker CircuitBreaker
}

Expand All @@ -35,7 +39,8 @@ var _ gogogrpc.Server = &MsgServiceRouter{}
// NewMsgServiceRouter creates a new MsgServiceRouter.
func NewMsgServiceRouter() *MsgServiceRouter {
return &MsgServiceRouter{
routes: map[string]MsgServiceHandler{},
routes: map[string]MsgServiceHandler{},
hybridHandlers: map[string]func(ctx context.Context, req, resp protoiface.MessageV1) error{},
}
}

Expand Down Expand Up @@ -67,113 +72,156 @@ func (msr *MsgServiceRouter) HandlerByTypeURL(typeURL string) MsgServiceHandler
func (msr *MsgServiceRouter) RegisterService(sd *grpc.ServiceDesc, handler interface{}) {
// Adds a top-level query handler based on the gRPC service name.
for _, method := range sd.Methods {
fqMethod := fmt.Sprintf("/%s/%s", sd.ServiceName, method.MethodName)
methodHandler := method.Handler

var requestTypeName string

// NOTE: This is how we pull the concrete request type for each handler for registering in the InterfaceRegistry.
// This approach is maybe a bit hacky, but less hacky than reflecting on the handler object itself.
// We use a no-op interceptor to avoid actually calling into the handler itself.
_, _ = methodHandler(nil, context.Background(), func(i interface{}) error {
msg, ok := i.(sdk.Msg)
if !ok {
// We panic here because there is no other alternative and the app cannot be initialized correctly
// this should only happen if there is a problem with code generation in which case the app won't
// work correctly anyway.
panic(fmt.Errorf("unable to register service method %s: %T does not implement sdk.Msg", fqMethod, i))
}
err := msr.registerMsgServiceHandler(sd, method, handler)
if err != nil {
panic(err)
}
err = msr.registerHybridHandler(sd, method, handler)
if err != nil {
panic(err)
}
}
}

func (msr *MsgServiceRouter) HybridHandlerByMsgName(msgName string) func(ctx context.Context, req, resp protoiface.MessageV1) error {
return msr.hybridHandlers[msgName]
}

requestTypeName = sdk.MsgTypeURL(msg)
return nil
}, noopInterceptor)

// Check that the service Msg fully-qualified method name has already
// been registered (via RegisterInterfaces). If the user registers a
// service without registering according service Msg type, there might be
// some unexpected behavior down the road. Since we can't return an error
// (`Server.RegisterService` interface restriction) we panic (at startup).
reqType, err := msr.interfaceRegistry.Resolve(requestTypeName)
if err != nil || reqType == nil {
panic(
fmt.Errorf(
"type_url %s has not been registered yet. "+
"Before calling RegisterService, you must register all interfaces by calling the `RegisterInterfaces` "+
"method on module.BasicManager. Each module should call `msgservice.RegisterMsgServiceDesc` inside its "+
"`RegisterInterfaces` method with the `_Msg_serviceDesc` generated by proto-gen",
requestTypeName,
),
)
func (msr *MsgServiceRouter) registerHybridHandler(sd *grpc.ServiceDesc, method grpc.MethodDesc, handler interface{}) error {
inputName, err := protocompat.RequestFullNameFromMethodDesc(sd, method)
if err != nil {
return err
}
cdc := codec.NewProtoCodec(msr.interfaceRegistry)
hybridHandler, err := protocompat.MakeHybridHandler(cdc, sd, method, handler)
if err != nil {
return err
}
// if circuit breaker is not nil, then we decorate the hybrid handler with the circuit breaker
if msr.circuitBreaker == nil {
msr.hybridHandlers[string(inputName)] = hybridHandler
return nil
}
// decorate the hybrid handler with the circuit breaker
circuitBreakerHybridHandler := func(ctx context.Context, req, resp protoiface.MessageV1) error {
messageName := codectypes.MsgTypeURL(req)
allowed, err := msr.circuitBreaker.IsAllowed(ctx, messageName)
if err != nil {
return err
}
if !allowed {
return fmt.Errorf("circuit breaker disallows execution of message %s", messageName)
}
return hybridHandler(ctx, req, resp)
}
msr.hybridHandlers[string(inputName)] = circuitBreakerHybridHandler
return nil
}

// Check that each service is only registered once. If a service is
// registered more than once, then we should error. Since we can't
// return an error (`Server.RegisterService` interface restriction) we
// panic (at startup).
_, found := msr.routes[requestTypeName]
if found {
panic(
fmt.Errorf(
"msg service %s has already been registered. Please make sure to only register each service once. "+
"This usually means that there are conflicting modules registering the same msg service",
fqMethod,
),
)
func (msr *MsgServiceRouter) registerMsgServiceHandler(sd *grpc.ServiceDesc, method grpc.MethodDesc, handler interface{}) error {
fqMethod := fmt.Sprintf("/%s/%s", sd.ServiceName, method.MethodName)
methodHandler := method.Handler

var requestTypeName string

// NOTE: This is how we pull the concrete request type for each handler for registering in the InterfaceRegistry.
// This approach is maybe a bit hacky, but less hacky than reflecting on the handler object itself.
// We use a no-op interceptor to avoid actually calling into the handler itself.
_, _ = methodHandler(nil, context.Background(), func(i interface{}) error {
msg, ok := i.(sdk.Msg)
if !ok {
// We panic here because there is no other alternative and the app cannot be initialized correctly
// this should only happen if there is a problem with code generation in which case the app won't
// work correctly anyway.
panic(fmt.Errorf("unable to register service method %s: %T does not implement sdk.Msg", fqMethod, i))
}

msr.routes[requestTypeName] = func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) {
ctx = ctx.WithEventManager(sdk.NewEventManager())
interceptor := func(goCtx context.Context, _ interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
goCtx = context.WithValue(goCtx, sdk.SdkContextKey, ctx)
return handler(goCtx, msg)
}
requestTypeName = sdk.MsgTypeURL(msg)
return nil
}, noopInterceptor)

// Check that the service Msg fully-qualified method name has already
// been registered (via RegisterInterfaces). If the user registers a
// service without registering according service Msg type, there might be
// some unexpected behavior down the road. Since we can't return an error
// (`Server.RegisterService` interface restriction) we panic (at startup).
reqType, err := msr.interfaceRegistry.Resolve(requestTypeName)
if err != nil || reqType == nil {
return fmt.Errorf(
"type_url %s has not been registered yet. "+
"Before calling RegisterService, you must register all interfaces by calling the `RegisterInterfaces` "+
"method on module.BasicManager. Each module should call `msgservice.RegisterMsgServiceDesc` inside its "+
"`RegisterInterfaces` method with the `_Msg_serviceDesc` generated by proto-gen",
requestTypeName,
)
}

if m, ok := msg.(sdk.HasValidateBasic); ok {
if err := m.ValidateBasic(); err != nil {
return nil, err
}
}
// Check that each service is only registered once. If a service is
// registered more than once, then we should error. Since we can't
// return an error (`Server.RegisterService` interface restriction) we
// panic (at startup).
_, found := msr.routes[requestTypeName]
if found {
return fmt.Errorf(
"msg service %s has already been registered. Please make sure to only register each service once. "+
"This usually means that there are conflicting modules registering the same msg service",
fqMethod,
)
}

if msr.circuitBreaker != nil {
msgURL := sdk.MsgTypeURL(msg)
isAllowed, err := msr.circuitBreaker.IsAllowed(ctx, msgURL)
if err != nil {
return nil, err
}
msr.routes[requestTypeName] = func(ctx sdk.Context, msg sdk.Msg) (*sdk.Result, error) {
ctx = ctx.WithEventManager(sdk.NewEventManager())
interceptor := func(goCtx context.Context, _ interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
goCtx = context.WithValue(goCtx, sdk.SdkContextKey, ctx)
return handler(goCtx, msg)
}

if !isAllowed {
return nil, fmt.Errorf("circuit breaker disables execution of this message: %s", msgURL)
}
if m, ok := msg.(sdk.HasValidateBasic); ok {
if err := m.ValidateBasic(); err != nil {
return nil, err
}
}

// Call the method handler from the service description with the handler object.
// We don't do any decoding here because the decoding was already done.
res, err := methodHandler(handler, ctx, noopDecoder, interceptor)
if msr.circuitBreaker != nil {
msgURL := sdk.MsgTypeURL(msg)
isAllowed, err := msr.circuitBreaker.IsAllowed(ctx, msgURL)
if err != nil {
return nil, err
}

resMsg, ok := res.(proto.Message)
if !ok {
return nil, errorsmod.Wrapf(sdkerrors.ErrInvalidType, "Expecting proto.Message, got %T", resMsg)
if !isAllowed {
return nil, fmt.Errorf("circuit breaker disables execution of this message: %s", msgURL)
}
}

anyResp, err := codectypes.NewAnyWithValue(resMsg)
if err != nil {
return nil, err
}
// Call the method handler from the service description with the handler object.
// We don't do any decoding here because the decoding was already done.
res, err := methodHandler(handler, ctx, noopDecoder, interceptor)
if err != nil {
return nil, err
}

var events []abci.Event
if evtMgr := ctx.EventManager(); evtMgr != nil {
events = evtMgr.ABCIEvents()
}
resMsg, ok := res.(proto.Message)
if !ok {
return nil, errorsmod.Wrapf(sdkerrors.ErrInvalidType, "Expecting proto.Message, got %T", resMsg)
}

return &sdk.Result{
Events: events,
MsgResponses: []*codectypes.Any{anyResp},
}, nil
anyResp, err := codectypes.NewAnyWithValue(resMsg)
if err != nil {
return nil, err
}

var events []abci.Event
if evtMgr := ctx.EventManager(); evtMgr != nil {
events = evtMgr.ABCIEvents()
}

return &sdk.Result{
Events: events,
MsgResponses: []*codectypes.Any{anyResp},
}, nil
}
return nil
}

// SetInterfaceRegistry sets the interface registry for the router.
Expand Down
Loading
Loading