Skip to content

Commit

Permalink
feat: RPCHandler passed to register_x functions can now have a `req…
Browse files Browse the repository at this point in the history
…uest_message_factory()` member function that will be used to allocate the initial request message for unary, server-streaming rpcs
  • Loading branch information
Tradias committed Dec 8, 2024
1 parent 6ad992a commit a0bbfa0
Show file tree
Hide file tree
Showing 14 changed files with 366 additions and 83 deletions.
12 changes: 6 additions & 6 deletions src/agrpc/detail/register_callback_rpc_handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,11 @@ struct RegisterCallbackRPCHandlerOperation
using ServerRPCWithRequest = detail::ServerRPCWithRequest<ServerRPC>;
using ServerRPCPtr = agrpc::ServerRPCPtr<ServerRPC>;

struct ServerRPCAllocation : ServerRPCWithRequest
struct ServerRPCAllocation
: detail::RequestMessageFactoryServerRPCMixinT<ServerRPCWithRequest, ServerRPC, RPCHandler>
{
ServerRPCAllocation(const ServerRPCExecutor& executor, RegisterCallbackRPCHandlerOperation& self)
: ServerRPCWithRequest(executor), self_(self)
: ServerRPCAllocation::RequestMessageFactoryMixin(self.rpc_handler(), executor), self_(self)
{
}

Expand All @@ -57,11 +58,11 @@ struct RegisterCallbackRPCHandlerOperation
if (ok)
{
self_.notify_when_done_work_started();
self_.initiate_next();
AGRPC_TRY
{
auto& starter = ptr_.server_rpc_;
starter->invoke(self_.rpc_handler(), static_cast<ServerRPCPtr&&>(ptr_));
self_.initiate_next();
auto& starter = *static_cast<ServerRPCAllocation*>(ptr_.server_rpc_);
starter.invoke(self_.rpc_handler(), static_cast<ServerRPCPtr&&>(ptr_));
}
AGRPC_CATCH(...)
{
Expand Down Expand Up @@ -132,7 +133,6 @@ struct RegisterCallbackRPCHandlerOperation
: Base(executor, service, static_cast<RPCHandler&&>(rpc_handler), static_cast<Ch&&>(completion_handler),
&detail::register_rpc_handler_asio_do_complete<RegisterCallbackRPCHandlerOperation>)
{
initiate();
}

void initiate()
Expand Down
3 changes: 1 addition & 2 deletions src/agrpc/detail/register_coroutine_rpc_handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ struct RegisterCoroutineRPCHandlerOperation
: Base(executor, service, static_cast<RPCHandler&&>(rpc_handler), static_cast<Ch&&>(completion_handler),
&detail::register_rpc_handler_asio_do_complete<Type>)
{
initiate();
}

void initiate()
Expand All @@ -74,7 +73,7 @@ struct RegisterCoroutineRPCHandlerOperation
{
auto& self = static_cast<Type&>(g.get().self_);
auto rpc = detail::ServerRPCContextBaseAccess::construct<ServerRPC>(self.get_executor());
detail::ServerRPCStarterT<ServerRPC, Args...> starter;
detail::RequestMessageFactoryServerRPCStarter<ServerRPC, RPCHandler, Args...> starter{self.rpc_handler()};
if (!co_await starter.start(rpc, self.service(), self.completion_token()))
{
co_return;
Expand Down
9 changes: 5 additions & 4 deletions src/agrpc/detail/register_rpc_handler_asio_base.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ class RegisterRPCHandlerOperationAsioBase
using typename Base::Service;
using Executor = detail::AssociatedExecutorT<CompletionHandlerT, ServerRPCExecutor>;
using Allocator = detail::AssociatedAllocatorT<CompletionHandlerT>;
using Starter = detail::ServerRPCStarterT<ServerRPC>;
using RefCountGuard = detail::ScopeGuard<Decrementer>;

template <class Ch>
Expand Down Expand Up @@ -98,10 +97,12 @@ struct RegisterRPCHandlerInitiator
RPCHandler&& rpc_handler) const
{
const auto allocator = asio::get_associated_allocator(completion_handler);
detail::allocate<Operation<ServerRPC, detail::RemoveCrefT<RPCHandler>, detail::RemoveCrefT<CompletionHandler>>>(
auto op = detail::allocate<
Operation<ServerRPC, detail::RemoveCrefT<RPCHandler>, detail::RemoveCrefT<CompletionHandler>>>(
allocator, executor, service_, static_cast<RPCHandler&&>(rpc_handler),
static_cast<CompletionHandler&&>(completion_handler))
.release();
static_cast<CompletionHandler&&>(completion_handler));
(*op).initiate();
op.release();
}

detail::ServerRPCServiceT<ServerRPC>& service_;
Expand Down
35 changes: 24 additions & 11 deletions src/agrpc/detail/register_sender_rpc_handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,15 +149,15 @@ struct RPCHandlerOperationWaitForDone
};

template <class ServerRPC, class RPCHandler, class StopToken, class Allocator>
void create_and_start_rpc_handler_operation(
std::optional<std::exception_ptr> create_and_start_rpc_handler_operation(
RegisterRPCHandlerSenderOperationBase<ServerRPC, RPCHandler, StopToken>& operation, const Allocator& allocator);

template <class ServerRPC, class RPCHandler, class StopToken, class Allocator>
struct RPCHandlerOperation
{
using Service = detail::ServerRPCServiceT<ServerRPC>;
using Traits = typename ServerRPC::Traits;
using Starter = detail::ServerRPCStarterT<ServerRPC>;
using Starter = detail::RequestMessageFactoryServerRPCStarter<ServerRPC, RPCHandler>;
using RPCHandlerInvokeResult = detail::RPCHandlerInvokeResultT<Starter&, RPCHandler&, ServerRPC&>;
using RegisterRPCHandlerSenderOperationBase =
detail::RegisterRPCHandlerSenderOperationBase<ServerRPC, RPCHandler, StopToken>;
Expand All @@ -184,7 +184,12 @@ struct RPCHandlerOperation
base.set_error(static_cast<std::exception_ptr&&>(*exception_ptr));
return;
}
detail::create_and_start_rpc_handler_operation(base, op.get_allocator());
if (auto exception_ptr = detail::create_and_start_rpc_handler_operation(base, op.get_allocator()))
{
op.rpc_.cancel();
base.set_error(static_cast<std::exception_ptr&&>(*exception_ptr));
return;
}
op.start_rpc_handler_operation_state();
guard.release();
}
Expand Down Expand Up @@ -259,7 +264,7 @@ struct RPCHandlerOperation
using OperationState = std::variant<StartOperationState, FinishOperationState, WaitForDoneOperationState>;

explicit RPCHandlerOperation(RegisterRPCHandlerSenderOperationBase& operation, const Allocator& allocator)
: impl1_(operation),
: impl1_(operation, operation.rpc_handler()),
rpc_(detail::ServerRPCContextBaseAccess::construct<ServerRPC>(operation.get_executor())),
impl2_(detail::SecondThenVariadic{}, allocator, std::in_place_type<StartOperationState>,
detail::InplaceWithFunction{},
Expand Down Expand Up @@ -338,17 +343,22 @@ struct RPCHandlerOperation
};

template <class ServerRPC, class RPCHandler, class StopToken, class Allocator>
void create_and_start_rpc_handler_operation(
std::optional<std::exception_ptr> create_and_start_rpc_handler_operation(
RegisterRPCHandlerSenderOperationBase<ServerRPC, RPCHandler, StopToken>& operation, const Allocator& allocator)
{
if AGRPC_UNLIKELY (operation.is_stopped())
{
return;
return {};
}
AGRPC_TRY
{
using RPCHandlerOperation = detail::RPCHandlerOperation<ServerRPC, RPCHandler, StopToken, Allocator>;
auto rpc_handler_operation_guard = detail::allocate<RPCHandlerOperation>(allocator, operation, allocator);
(*rpc_handler_operation_guard).start();
rpc_handler_operation_guard.release();
return {};
}
using RPCHandlerOperation = detail::RPCHandlerOperation<ServerRPC, RPCHandler, StopToken, Allocator>;
auto rpc_handler_operation_guard = detail::allocate<RPCHandlerOperation>(allocator, operation, allocator);
(*rpc_handler_operation_guard).start();
rpc_handler_operation_guard.release();
AGRPC_CATCH(...) { return std::current_exception(); }
}

template <class ServerRPC, class RPCHandler, class Receiver>
Expand Down Expand Up @@ -379,7 +389,10 @@ class RPCHandlerSenderOperation
return;
}
this->stop_context_.emplace(std::move(stop_token));
detail::create_and_start_rpc_handler_operation(*this, get_allocator());
if (auto ep = detail::create_and_start_rpc_handler_operation(*this, get_allocator()))
{
exec::set_error(static_cast<Receiver&&>(receiver_), static_cast<std::exception_ptr&&>(*ep));
}
}

#ifdef AGRPC_STDEXEC
Expand Down
4 changes: 1 addition & 3 deletions src/agrpc/detail/register_yield_rpc_handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,13 @@ struct RegisterYieldRPCHandlerOperation
using typename Base::RefCountGuard;
using typename Base::ServerRPCExecutor;
using typename Base::Service;
using typename Base::Starter;

template <class Ch>
RegisterYieldRPCHandlerOperation(const ServerRPCExecutor& executor, Service& service, RPCHandler&& rpc_handler,
Ch&& completion_handler)
: Base(executor, service, static_cast<RPCHandler&&>(rpc_handler), static_cast<Ch&&>(completion_handler),
&detail::register_rpc_handler_asio_do_complete<RegisterYieldRPCHandlerOperation>)
{
initiate();
}

void initiate()
Expand All @@ -88,7 +86,7 @@ struct RegisterYieldRPCHandlerOperation
void perform_request_and_repeat(const Yield& yield)
{
auto rpc = detail::ServerRPCContextBaseAccess::construct<ServerRPC>(this->get_executor());
Starter starter;
detail::RequestMessageFactoryServerRPCStarter<ServerRPC, RPCHandler> starter{this->rpc_handler()};
if (!starter.start(rpc, this->service(), use_yield(yield)))
{
return;
Expand Down
132 changes: 122 additions & 10 deletions src/agrpc/detail/server_rpc_starter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,19 @@ struct ServerRPCStarter
using Responder = std::remove_reference_t<decltype(ServerRPCContextBaseAccess::responder(rpc))>;
return detail::async_initiate_sender_implementation(
RPCExecutorBaseAccess::grpc_context(rpc),
detail::ServerRequestSenderInitiation<RequestRPC>{service, request_},
detail::ServerRequestSenderInitiation<RequestRPC>{service, *request_},
detail::ServerRequestSenderImplementation<Responder, TraitsT::NOTIFY_WHEN_DONE>{rpc},
static_cast<CompletionToken&&>(token));
}

template <class Handler, class RPC, class... AppendedArgs>
decltype(auto) invoke(Handler&& handler, PrependedArgs&&... prepend, RPC&& rpc, AppendedArgs&&... append)
template <class RPCHandler, class RPC, class... AppendedArgs>
decltype(auto) invoke(RPCHandler&& handler, PrependedArgs&&... prepend, RPC&& rpc, AppendedArgs&&... append)
{
return static_cast<Handler&&>(handler)(static_cast<PrependedArgs&&>(prepend)..., static_cast<RPC&&>(rpc),
request_, static_cast<AppendedArgs&&>(append)...);
return static_cast<RPCHandler&&>(handler)(static_cast<PrependedArgs&&>(prepend)..., static_cast<RPC&&>(rpc),
*request_, static_cast<AppendedArgs&&>(append)...);
}

Request request_;
Request* request_;
};

template <class Request, class... PrependedArgs>
Expand All @@ -66,18 +66,130 @@ struct ServerRPCStarter<Request, false, PrependedArgs...>
static_cast<CompletionToken&&>(token));
}

template <class Handler, class RPC, class... AppendedArgs>
decltype(auto) invoke(Handler&& handler, PrependedArgs&&... prepend, RPC&& rpc, AppendedArgs&&... append)
template <class RPCHandler, class RPC, class... AppendedArgs>
decltype(auto) invoke(RPCHandler&& handler, PrependedArgs&&... prepend, RPC&& rpc, AppendedArgs&&... append)
{
return static_cast<Handler&&>(handler)(static_cast<PrependedArgs&&>(prepend)..., static_cast<RPC&&>(rpc),
static_cast<AppendedArgs&&>(append)...);
return static_cast<RPCHandler&&>(handler)(static_cast<PrependedArgs&&>(prepend)..., static_cast<RPC&&>(rpc),
static_cast<AppendedArgs&&>(append)...);
}
};

template <class ServerRPC, class... PrependedArgs>
using ServerRPCStarterT = detail::ServerRPCStarter<typename ServerRPC::Request,
detail::has_initial_request(ServerRPC::TYPE), PrependedArgs...>;

template <class Request>
struct DefaultRequestMessageFactory
{
template <class>
Request& create()
{
return request_;
}

Request request_;
};

template <class RequestT, class RPCHandler, class = void>
struct RequestMessageFactoryBuilder
{
static constexpr bool IS_DEFAULT = true;

using Request = RequestT;
using Type = DefaultRequestMessageFactory<RequestT>;

static Type build(RPCHandler&) { return Type{}; }
};

template <class RequestT, class RPCHandler>
struct RequestMessageFactoryBuilder<RequestT, RPCHandler,
decltype((void)std::declval<RPCHandler&>().request_message_factory())>
{
static constexpr bool IS_DEFAULT = false;

using Request = RequestT;
using Type = decltype(std::declval<RPCHandler&>().request_message_factory());

static Type build(RPCHandler& rpc_handler) { return rpc_handler.request_message_factory(); }
};

template <class Request, class RequestMessageFactory, class = void>
inline constexpr bool REQUEST_MESSAGE_FACTORY_HAS_DESTROY = false;

template <class Request, class RequestMessageFactory>
inline constexpr bool REQUEST_MESSAGE_FACTORY_HAS_DESTROY<Request, RequestMessageFactory,
decltype((void)std::declval<RequestMessageFactory&>().destroy(
std::declval<Request&>()))> = true;

template <class Base, class RequestMessageFactoryBuilder, bool HasInitialRequest>
struct RequestMessageFactoryMixin : Base
{
using RequestMessageFactory = typename RequestMessageFactoryBuilder::Type;
using Request = typename RequestMessageFactoryBuilder::Request;

template <class RPCHandler, class... Args>
explicit RequestMessageFactoryMixin(RPCHandler& rpc_handler, Args&&... args)
: Base{static_cast<Args&&>(args)...}, request_factory_(RequestMessageFactoryBuilder::build(rpc_handler))
{
this->request_ = &request_factory_.template create<Request>();
}

RequestMessageFactoryMixin(const RequestMessageFactoryMixin& other) = delete;
RequestMessageFactoryMixin(RequestMessageFactoryMixin&& other) = delete;

~RequestMessageFactoryMixin()
{
if constexpr (REQUEST_MESSAGE_FACTORY_HAS_DESTROY<Request, RequestMessageFactory>)
{
static_assert(noexcept(request_factory_.destroy(*this->request_)),
"Request message factory `destroy(Request&)` must be noexcept");
request_factory_.destroy(*this->request_);
}
}

RequestMessageFactoryMixin& operator=(const RequestMessageFactoryMixin& other) = delete;
RequestMessageFactoryMixin& operator=(RequestMessageFactoryMixin&& other) = delete;

template <class... Args>
decltype(auto) invoke(Args&&... args)
{
if constexpr (RequestMessageFactoryBuilder::IS_DEFAULT)
{
return Base::invoke(static_cast<Args&&>(args)...);
}
else
{
return Base::invoke(static_cast<Args&&>(args)..., request_factory_);
}
}

RequestMessageFactory request_factory_;
};

template <class Base, class RequestMessageFactoryBuilder>
struct RequestMessageFactoryMixin<Base, RequestMessageFactoryBuilder, false> : Base
{
template <class RPCHandler, class... Args>
explicit RequestMessageFactoryMixin(RPCHandler&, Args&&... args) : Base{static_cast<Args&&>(args)...}
{
}

RequestMessageFactoryMixin(const RequestMessageFactoryMixin& other) = delete;
RequestMessageFactoryMixin(RequestMessageFactoryMixin&& other) = delete;
RequestMessageFactoryMixin& operator=(const RequestMessageFactoryMixin& other) = delete;
RequestMessageFactoryMixin& operator=(RequestMessageFactoryMixin&& other) = delete;
};

template <class Base, class ServerRPC, class RPCHandler>
using RequestMessageFactoryServerRPCMixinT =
detail::RequestMessageFactoryMixin<Base,
detail::RequestMessageFactoryBuilder<typename ServerRPC::Request, RPCHandler>,
detail::has_initial_request(ServerRPC::TYPE)>;

template <class ServerRPC, class RPCHandler, class... Args>
using RequestMessageFactoryServerRPCStarter =
detail::RequestMessageFactoryServerRPCMixinT<detail::ServerRPCStarterT<ServerRPC, Args...>, ServerRPC, RPCHandler>;

template <class Starter, class Handler, class RPC, class... Args>
using RPCHandlerInvokeResultT =
decltype(std::declval<Starter>().invoke(std::declval<Handler>(), std::declval<RPC>(), std::declval<Args>()...));
Expand Down
2 changes: 1 addition & 1 deletion src/agrpc/register_awaitable_rpc_handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ auto register_awaitable_rpc_handler(const typename ServerRPC::executor_type& exe
detail::ServerRPCServiceT<ServerRPC>& service, RPCHandler rpc_handler,
CompletionToken&& token = CompletionToken{})
{
using Starter = detail::ServerRPCStarterT<ServerRPC>;
using Starter = detail::RequestMessageFactoryServerRPCStarter<ServerRPC, RPCHandler>;
using CoroutineTraits = detail::CoroutineTraits<detail::RPCHandlerInvokeResultT<Starter&, RPCHandler&, ServerRPC&>>;
static_assert(sizeof(CoroutineTraits) > 0,
"Rpc handler must return an asio::awaitable and take ServerRPC& and, for server-streaming and unary "
Expand Down
2 changes: 1 addition & 1 deletion src/agrpc/register_callback_rpc_handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ auto register_callback_rpc_handler(const typename ServerRPC::executor_type& exec
detail::ServerRPCServiceT<ServerRPC>& service, RPCHandler rpc_handler,
CompletionToken&& token = CompletionToken{})
{
using Starter = detail::ServerRPCStarterT<ServerRPC>;
using Starter = detail::RequestMessageFactoryServerRPCStarter<ServerRPC, RPCHandler>;
using CheckRPCHandlerTakesServerRPCPtrAsArg [[maybe_unused]] =
detail::RPCHandlerInvokeResultT<Starter&, RPCHandler&, typename ServerRPC::Ptr>;
return asio::async_initiate<CompletionToken, void(std::exception_ptr)>(
Expand Down
2 changes: 1 addition & 1 deletion src/agrpc/register_sender_rpc_handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ template <class ServerRPC, class RPCHandler>
[[nodiscard]] detail::RPCHandlerSender<ServerRPC, RPCHandler> register_sender_rpc_handler(
agrpc::GrpcContext& grpc_context, detail::ServerRPCServiceT<ServerRPC>& service, RPCHandler rpc_handler)
{
using Starter = detail::ServerRPCStarterT<ServerRPC>;
using Starter = detail::RequestMessageFactoryServerRPCStarter<ServerRPC, RPCHandler>;
static_assert(detail::exec::is_sender_v<detail::RPCHandlerInvokeResultT<Starter&, RPCHandler&, ServerRPC&>>,
"Rpc handler must return a sender.");
return {grpc_context, service, static_cast<RPCHandler&&>(rpc_handler)};
Expand Down
4 changes: 2 additions & 2 deletions src/agrpc/server_rpc_ptr.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,12 @@ class ServerRPCPtr
/**
* @brief Get client's initial request message
*/
decltype(auto) request() noexcept { return (server_rpc_->request_); }
decltype(auto) request() noexcept { return *server_rpc_->request_; }

/**
* @brief Get client's initial request message (const overload)
*/
decltype(auto) request() const noexcept { return (server_rpc_->request_); }
decltype(auto) request() const noexcept { return *server_rpc_->request_; }

/**
* @brief Swap the contents of two ServerRPCPtr
Expand Down
Loading

0 comments on commit a0bbfa0

Please sign in to comment.