From 48996cfc901248508cf6f2ac486ddde25d0981f6 Mon Sep 17 00:00:00 2001 From: Bright Chen Date: Mon, 26 Aug 2024 10:27:21 +0800 Subject: [PATCH] Support rpc protobuf message factory interface (#2718) * Support rpc protobuf message factory interface * Update cn/server.md --- docs/cn/server.md | 28 ++++++++ src/brpc/policy/baidu_rpc_protocol.cpp | 99 +++++++++++++++++--------- src/brpc/rpc_pb_message_factory.cpp | 51 +++++++++++++ src/brpc/rpc_pb_message_factory.h | 54 ++++++++++++++ src/brpc/server.cpp | 6 +- src/brpc/server.h | 10 +++ test/brpc_channel_unittest.cpp | 27 ++++--- test/brpc_server_unittest.cpp | 61 ++++++++++++++++ 8 files changed, 289 insertions(+), 47 deletions(-) create mode 100644 src/brpc/rpc_pb_message_factory.cpp create mode 100644 src/brpc/rpc_pb_message_factory.h diff --git a/docs/cn/server.md b/docs/cn/server.md index 070adf978e..5e3fc42a6f 100644 --- a/docs/cn/server.md +++ b/docs/cn/server.md @@ -1013,6 +1013,34 @@ public: ... ``` +## RPC Protobuf message factory + +Server默认使用`DefaultRpcPBMessageFactory`。它是一个简单的工厂类,通过`new`来创建请求/响应message和`delete`来销毁请求/响应message。 + +如果用户希望自定义创建销毁机制,可以实现`RpcPBMessages`(请求/响应message的封装)和`RpcPBMessageFactory`(工厂类),并通过`ServerOptions.rpc_pb_message_factory`。 + +接口如下: + +```c++ +// Inherit this class to customize rpc protobuf messages, +// include request and response. +class RpcPBMessages { +public: + virtual ~RpcPBMessages() = default; + virtual google::protobuf::Message* Request() = 0; + virtual google::protobuf::Message* Response() = 0; +}; + +// Factory to manage `RpcPBMessages'. +class RpcPBMessageFactory { +public: + virtual ~RpcPBMessageFactory() = default; + virtual RpcPBMessages* Get(const ::google::protobuf::Service& service, + const ::google::protobuf::MethodDescriptor& method) = 0; + virtual void Return(RpcPBMessages* protobuf_message) = 0; +}; +``` + # FAQ ### Q: Fail to write into fd=1865 SocketId=8905@10.208.245.43:54742@8230: Got EOF是什么意思 diff --git a/src/brpc/policy/baidu_rpc_protocol.cpp b/src/brpc/policy/baidu_rpc_protocol.cpp index 6ce76467a3..0fb439a82d 100644 --- a/src/brpc/policy/baidu_rpc_protocol.cpp +++ b/src/brpc/policy/baidu_rpc_protocol.cpp @@ -24,6 +24,7 @@ #include "butil/time.h" #include "butil/iobuf.h" // butil::IOBuf #include "butil/raw_pack.h" // RawPacker RawUnpacker +#include "butil/memory/scope_guard.h" #include "brpc/controller.h" // Controller #include "brpc/socket.h" // Socket #include "brpc/server.h" // Server @@ -31,6 +32,7 @@ #include "brpc/compress.h" // ParseFromCompressedData #include "brpc/stream_impl.h" #include "brpc/rpc_dump.h" // SampledRequest +#include "brpc/rpc_pb_message_factory.h" #include "brpc/policy/baidu_rpc_meta.pb.h" // RpcRequestMeta #include "brpc/policy/baidu_rpc_protocol.h" #include "brpc/policy/most_common_message.h" @@ -157,11 +159,34 @@ static bool SerializeResponse(const google::protobuf::Message& res, return true; } +namespace { +struct BaiduProxyPBMessages : public RpcPBMessages { + static BaiduProxyPBMessages* Get() { + return butil::get_object(); + } + + static void Return(BaiduProxyPBMessages* messages) { + messages->Clear(); + butil::return_object(messages); + } + + void Clear() { + request.Clear(); + response.Clear(); + } + + ::google::protobuf::Message* Request() override { return &request; } + ::google::protobuf::Message* Response() override { return &response; } + + SerializedRequest request; + SerializedResponse response; +}; +} + // Used by UT, can't be static. void SendRpcResponse(int64_t correlation_id, - Controller* cntl, - const google::protobuf::Message* req, - const google::protobuf::Message* res, + Controller* cntl, + RpcPBMessages* messages, const Server* server, MethodStatus* method_status, int64_t received_us) { @@ -172,13 +197,24 @@ void SendRpcResponse(int64_t correlation_id, } Socket* sock = accessor.get_sending_socket(); - std::unique_ptr recycle_req(req); - std::unique_ptr recycle_res(res); - std::unique_ptr recycle_cntl(cntl); ConcurrencyRemover concurrency_remover(method_status, cntl, received_us); - ClosureGuard guard(brpc::NewCallback(cntl, &Controller::CallAfterRpcResp, req, res)); + auto messages_guard = butil::MakeScopeGuard([server, messages] { + if (NULL == messages) { + return; + } + if (NULL != server->options().baidu_master_service) { + BaiduProxyPBMessages::Return(static_cast(messages)); + } else { + server->options().rpc_pb_message_factory->Return(messages); + } + }); + + const google::protobuf::Message* req = NULL == messages ? NULL : messages->Request(); + const google::protobuf::Message* res = NULL == messages ? NULL : messages->Response(); + ClosureGuard guard(brpc::NewCallback( + cntl, &Controller::CallAfterRpcResp, req, res)); StreamId response_stream_id = accessor.response_stream(); @@ -375,8 +411,7 @@ void ProcessRpcRequest(InputMessageBase* msg_base) { return; } - std::unique_ptr req; - std::unique_ptr res; + RpcPBMessages* messages = NULL; ServerPrivateAccessor server_accessor(server); ControllerPrivateAccessor accessor(cntl.get()); @@ -496,13 +531,10 @@ void ProcessRpcRequest(InputMessageBase* msg_base) { span->ResetServerSpanName(sampled_request->meta.method_name()); } - auto serialized_request = (SerializedRequest*) - svc->GetRequestPrototype(NULL).New(); - req.reset(serialized_request); - res.reset(svc->GetResponsePrototype(NULL).New()); - - msg->payload.cutn(&serialized_request->serialized_data(), - req_size - meta.attachment_size()); + messages = BaiduProxyPBMessages::Get(); + msg->payload.cutn( + &((SerializedRequest*)messages->Request())->serialized_data(), + req_size - meta.attachment_size()); if (!msg->payload.empty()) { cntl->request_attachment().swap(msg->payload); } @@ -568,26 +600,25 @@ void ProcessRpcRequest(InputMessageBase* msg_base) { } auto req_cmp_type = static_cast(meta.compress_type()); - req.reset(svc->GetRequestPrototype(method).New()); - if (!ParseFromCompressedData(req_buf, req.get(), req_cmp_type)) { + messages = server->options().rpc_pb_message_factory->Get(*svc, *method); + if (!ParseFromCompressedData(req_buf, messages->Request(), req_cmp_type)) { cntl->SetFailed(EREQUEST, "Fail to parse request message, " "CompressType=%s, request_size=%d", CompressTypeToCStr(req_cmp_type), req_size); + server->options().rpc_pb_message_factory->Return(messages); break; } - - res.reset(svc->GetResponsePrototype(method).New()); req_buf.clear(); } // `socket' will be held until response has been sent google::protobuf::Closure* done = ::brpc::NewCallback< - int64_t, Controller*, const google::protobuf::Message*, - const google::protobuf::Message*, const Server*, - MethodStatus*, int64_t>( - &SendRpcResponse, meta.correlation_id(), cntl.get(), - req.get(), res.get(), server, - method_status, msg->received_us()); + int64_t, Controller*, RpcPBMessages*, + const Server*, MethodStatus*, int64_t>(&SendRpcResponse, + meta.correlation_id(), + cntl.get(), messages, + server, method_status, + msg->received_us()); // optional, just release resource ASAP msg.reset(); @@ -598,24 +629,28 @@ void ProcessRpcRequest(InputMessageBase* msg_base) { } if (!FLAGS_usercode_in_pthread) { return svc->CallMethod(method, cntl.release(), - req.release(), res.release(), done); + messages->Request(), + messages->Response(), done); } if (BeginRunningUserCode()) { svc->CallMethod(method, cntl.release(), - req.release(), res.release(), done); + messages->Request(), + messages->Response(), done); return EndRunningUserCodeInPlace(); } else { return EndRunningCallMethodInPool( svc, method, cntl.release(), - req.release(), res.release(), done); + messages->Request(), + messages->Response(), done); } } while (false); // `cntl', `req' and `res' will be deleted inside `SendRpcResponse' // `socket' will be held until response has been sent - SendRpcResponse(meta.correlation_id(), cntl.release(), - req.release(), res.release(), server, - method_status, msg->received_us()); + SendRpcResponse(meta.correlation_id(), + cntl.release(), messages, + server, method_status, + msg->received_us()); } bool VerifyRpcRequest(const InputMessageBase* msg_base) { diff --git a/src/brpc/rpc_pb_message_factory.cpp b/src/brpc/rpc_pb_message_factory.cpp new file mode 100644 index 0000000000..828a289d8d --- /dev/null +++ b/src/brpc/rpc_pb_message_factory.cpp @@ -0,0 +1,51 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "brpc/rpc_pb_message_factory.h" +#include "butil/object_pool.h" + +namespace brpc { + +struct DefaultRpcPBMessages : public RpcPBMessages { + DefaultRpcPBMessages() : request(NULL), response(NULL) {} + ::google::protobuf::Message* Request() override { return request; } + ::google::protobuf::Message* Response() override { return response; } + + ::google::protobuf::Message* request; + ::google::protobuf::Message* response; +}; + + +RpcPBMessages* DefaultRpcPBMessageFactory::Get( + const ::google::protobuf::Service& service, + const ::google::protobuf::MethodDescriptor& method) { + auto messages = butil::get_object(); + messages->request = service.GetRequestPrototype(&method).New(); + messages->response = service.GetResponsePrototype(&method).New(); + return messages; +} + +void DefaultRpcPBMessageFactory::Return(RpcPBMessages* messages) { + auto default_messages = static_cast(messages); + delete default_messages->request; + delete default_messages->response; + default_messages->request = NULL; + default_messages->response = NULL; + butil::return_object(default_messages); +} + +} // namespace brpc \ No newline at end of file diff --git a/src/brpc/rpc_pb_message_factory.h b/src/brpc/rpc_pb_message_factory.h new file mode 100644 index 0000000000..0da0ff2a67 --- /dev/null +++ b/src/brpc/rpc_pb_message_factory.h @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef BRPC_RPC_PB_MESSAGE_FACTORY_H +#define BRPC_RPC_PB_MESSAGE_FACTORY_H + +#include +#include +#include + +namespace brpc { + +// Inherit this class to customize rpc protobuf messages, +// include request and response. +class RpcPBMessages { +public: + virtual ~RpcPBMessages() = default; + virtual google::protobuf::Message* Request() = 0; + virtual google::protobuf::Message* Response() = 0; +}; + +// Factory to manage `RpcPBMessages'. +class RpcPBMessageFactory { +public: + virtual ~RpcPBMessageFactory() = default; + virtual RpcPBMessages* Get(const ::google::protobuf::Service& service, + const ::google::protobuf::MethodDescriptor& method) = 0; + virtual void Return(RpcPBMessages* protobuf_message) = 0; +}; + +class DefaultRpcPBMessageFactory : public RpcPBMessageFactory { +public: + RpcPBMessages* Get(const ::google::protobuf::Service& service, + const ::google::protobuf::MethodDescriptor& method) override; + void Return(RpcPBMessages* messages) override; +}; + +} // namespace brpc + +#endif //BRPC_RPC_PB_MESSAGE_FACTORY_H diff --git a/src/brpc/server.cpp b/src/brpc/server.cpp index 00cdec79af..740873f126 100644 --- a/src/brpc/server.cpp +++ b/src/brpc/server.cpp @@ -151,7 +151,8 @@ ServerOptions::ServerOptions() , health_reporter(NULL) , rtmp_service(NULL) , redis_service(NULL) - , bthread_tag(BTHREAD_TAG_INVALID) { + , bthread_tag(BTHREAD_TAG_INVALID) + , rpc_pb_message_factory(new DefaultRpcPBMessageFactory()) { if (s_ncore > 0) { num_threads = s_ncore + 1; } @@ -449,6 +450,9 @@ Server::~Server() { delete _options.http_master_service; _options.http_master_service = NULL; + delete _options.rpc_pb_message_factory; + _options.rpc_pb_message_factory = NULL; + delete _am; _am = NULL; delete _internal_am; diff --git a/src/brpc/server.h b/src/brpc/server.h index fdcba68f77..d65e13f0b2 100644 --- a/src/brpc/server.h +++ b/src/brpc/server.h @@ -44,6 +44,7 @@ #include "brpc/interceptor.h" #include "brpc/concurrency_limiter.h" #include "brpc/baidu_master_service.h" +#include "brpc/rpc_pb_message_factory.h" namespace brpc { @@ -277,6 +278,15 @@ struct ServerOptions { // Default: BTHREAD_TAG_DEFAULT bthread_tag_t bthread_tag; + // [CAUTION] This option is for implementing specialized rpc protobuf + // message factory, most users don't need it. Don't change this option + // unless you fully understand the description below. + // If this option is set, all baidu-std rpc request message and response + // message will be created by this factory. + // + // Owned by Server and deleted in server's destructor. + RpcPBMessageFactory* rpc_pb_message_factory; + private: // SSLOptions is large and not often used, allocate it on heap to // prevent ServerOptions from being bloated in most cases. diff --git a/test/brpc_channel_unittest.cpp b/test/brpc_channel_unittest.cpp index d43a0f4b95..6c189f18ed 100644 --- a/test/brpc_channel_unittest.cpp +++ b/test/brpc_channel_unittest.cpp @@ -53,10 +53,11 @@ DECLARE_int32(max_connection_pool_size); class Server; class MethodStatus; namespace policy { -void SendRpcResponse(int64_t correlation_id, Controller* cntl, - const google::protobuf::Message* req, - const google::protobuf::Message* res, - const Server* server_raw, MethodStatus *, int64_t); +void SendRpcResponse(int64_t correlation_id, + Controller* cntl, + RpcPBMessages* messages, + const Server* server_raw, + MethodStatus *, int64_t); } // policy } // brpc @@ -255,8 +256,10 @@ class ChannelTest : public ::testing::Test{ ASSERT_EQ(ts->_svc.descriptor()->full_name(), req_meta.service_name()); const google::protobuf::MethodDescriptor* method = ts->_svc.descriptor()->FindMethodByName(req_meta.method_name()); - google::protobuf::Message* req = - ts->_svc.GetRequestPrototype(method).New(); + brpc::RpcPBMessages* messages = + ts->_dummy.options().rpc_pb_message_factory->Get(ts->_svc, *method); + google::protobuf::Message* req = messages->Request(); + google::protobuf::Message* res = messages->Response(); if (meta.attachment_size() != 0) { butil::IOBuf req_buf; msg->payload.cutn(&req_buf, msg->payload.size() - meta.attachment_size()); @@ -271,18 +274,14 @@ class ChannelTest : public ::testing::Test{ cntl->_current_call.sending_sock.reset(ptr.release()); cntl->_server = &ts->_dummy; - google::protobuf::Message* res = - ts->_svc.GetResponsePrototype(method).New(); google::protobuf::Closure* done = brpc::NewCallback< int64_t, brpc::Controller*, - const google::protobuf::Message*, - const google::protobuf::Message*, + brpc::RpcPBMessages*, const brpc::Server*, - brpc::MethodStatus*, int64_t>( - &brpc::policy::SendRpcResponse, - meta.correlation_id(), cntl, req, res, - &ts->_dummy, NULL, -1); + brpc::MethodStatus*, int64_t>(&brpc::policy::SendRpcResponse, + meta.correlation_id(), cntl, + messages, &ts->_dummy, NULL, -1); ts->_svc.CallMethod(method, cntl, req, res, done); } diff --git a/test/brpc_server_unittest.cpp b/test/brpc_server_unittest.cpp index cc98f11154..5f06887a52 100644 --- a/test/brpc_server_unittest.cpp +++ b/test/brpc_server_unittest.cpp @@ -29,6 +29,7 @@ #include "butil/fd_guard.h" #include "butil/files/scoped_file.h" #include "brpc/socket.h" +#include "butil/object_pool.h" #include "brpc/builtin/version_service.h" #include "brpc/builtin/health_service.h" #include "brpc/builtin/list_service.h" @@ -1767,4 +1768,64 @@ TEST_F(ServerTest, generic_call) { ASSERT_EQ(0, server.Join()); } +struct DefaultRpcPBMessages : public brpc::RpcPBMessages { + DefaultRpcPBMessages() : request(NULL), response(NULL) {} + ::google::protobuf::Message* Request() override { return request; } + ::google::protobuf::Message* Response() override { return response; } + + ::google::protobuf::Message* request; + ::google::protobuf::Message* response; +}; + +class TestRpcPBMessageFactory : public brpc::RpcPBMessageFactory { +public: + brpc::RpcPBMessages* Get(const google::protobuf::Service& service, + const google::protobuf::MethodDescriptor& method) override { + auto messages = butil::get_object(); + auto request = butil::get_object(); + auto response = butil::get_object(); + request->clear_message(); + response->clear_message(); + messages->request = request; + messages->response = response; + return messages; + } + + void Return(brpc::RpcPBMessages* messages) override { + auto test_messages = static_cast(messages); + butil::return_object(static_cast(test_messages->request)); + butil::return_object(static_cast(test_messages->response)); + test_messages->request = NULL; + test_messages->response = NULL; + butil::return_object(test_messages); + } +}; + +TEST_F(ServerTest, rpc_pb_message_factory) { + butil::EndPoint ep; + ASSERT_EQ(0, str2endpoint("127.0.0.1:8613", &ep)); + brpc::Server server; + EchoServiceImpl service; + ASSERT_EQ(0, server.AddService(&service, brpc::SERVER_DOESNT_OWN_SERVICE)); + brpc::ServerOptions opt; + opt.rpc_pb_message_factory = new TestRpcPBMessageFactory; + ASSERT_EQ(0, server.Start(ep, &opt)); + + brpc::Channel chan; + brpc::ChannelOptions copt; + copt.protocol = "baidu_std"; + ASSERT_EQ(0, chan.Init(ep, &copt)); + brpc::Controller cntl; + test::EchoRequest req; + test::EchoResponse res; + req.set_message(EXP_REQUEST); + test::EchoService_Stub stub(&chan); + stub.Echo(&cntl, &req, &res, NULL); + ASSERT_FALSE(cntl.Failed()) << cntl.ErrorText(); + ASSERT_EQ(EXP_RESPONSE, res.message()); + + ASSERT_EQ(0, server.Stop(0)); + ASSERT_EQ(0, server.Join()); +} + } //namespace