Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement an easier way to get Send Serve and Stub #480

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions plugins/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -559,15 +559,15 @@ impl<'a> ServiceGenerator<'a> {
)| {
quote! {
#( #attrs )*
async fn #ident(self, context: ::tarpc::context::Context, #( #args ),*) -> #output;
fn #ident(self, context: ::tarpc::context::Context, #( #args ),*) -> impl ::core::future::Future<Output = #output> + ::core::marker::Send;
}
},
);

let stub_doc = format!("The stub trait for service [`{service_ident}`].");
quote! {
#( #attrs )*
#vis trait #service_ident: ::core::marker::Sized {
#vis trait #service_ident: ::core::marker::Sized + ::core::marker::Send {
#( #rpc_fns )*

/// Returns a serving function to use with
Expand All @@ -578,11 +578,11 @@ impl<'a> ServiceGenerator<'a> {
}

#[doc = #stub_doc]
#vis trait #client_stub_ident: ::tarpc::client::stub::Stub<Req = #request_ident, Resp = #response_ident> {
#vis trait #client_stub_ident: ::tarpc::client::stub::SendStub<Req = #request_ident, Resp = #response_ident> {
}

impl<S> #client_stub_ident for S
where S: ::tarpc::client::stub::Stub<Req = #request_ident, Resp = #response_ident>
where S: ::tarpc::client::stub::SendStub<Req = #request_ident, Resp = #response_ident>
{
}
}
Expand Down Expand Up @@ -616,7 +616,7 @@ impl<'a> ServiceGenerator<'a> {
} = self;

quote! {
impl<S> ::tarpc::server::Serve for #server_ident<S>
impl<S> ::tarpc::server::SendServe for #server_ident<S>
where S: #service_ident
{
type Req = #request_ident;
Expand Down Expand Up @@ -780,7 +780,7 @@ impl<'a> ServiceGenerator<'a> {

quote! {
impl<Stub> #client_ident<Stub>
where Stub: ::tarpc::client::stub::Stub<
where Stub: ::tarpc::client::stub::SendStub<
Req = #request_ident,
Resp = #response_ident>
{
Expand Down
55 changes: 51 additions & 4 deletions tarpc/src/client/stub.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
//! Provides a Stub trait, implemented by types that can call remote services.

use std::future::Future;

use crate::{
client::{Channel, RpcError},
context,
server::Serve,
server::{SendServe, Serve},
RequestName,
};

Expand All @@ -15,7 +17,6 @@ mod mock;

/// A connection to a remote service.
/// Calls the service with requests of type `Req` and receives responses of type `Resp`.
#[allow(async_fn_in_trait)]
pub trait Stub {
/// The service request type.
type Req: RequestName;
Expand All @@ -24,8 +25,28 @@ pub trait Stub {
type Resp;

/// Calls a remote service.
async fn call(&self, ctx: context::Context, request: Self::Req)
-> Result<Self::Resp, RpcError>;
fn call(
&self,
ctx: context::Context,
request: Self::Req,
) -> impl Future<Output = Result<Self::Resp, RpcError>>;
}

/// A connection to a remote service.
/// Calls the service with requests of type `Req` and receives responses of type `Resp`.
pub trait SendStub: Send {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be possible to use return-type notation to avoid duplicating the trait definition? rust-lang/rust#129629

pub trait SendStub: Stub + Send where <Self as Stub>::call(..): Send {}

impl<S: Stub + Send> SendStub for S where <S as Stub>::call(..): Send {}

Benefits of doing it this way:

  • The implementations for Channel and Serve don't need to be duplicated
  • the trait method can still be an async fn.

I think the same trick can also be applied to Serve and the proc macro-generated service trait.

Of course, the downside is this is still an experimental feature. But I think I'd rather think about solving this in a forward-thinking way. We can add a feature gate for it here, as well, so that it's only enabled if compiled with a tarpc Cargo feature that enables unstable features.

/// The service request type.
type Req: RequestName;

/// The service response type.
type Resp;

/// Calls a remote service.
fn call(
&self,
ctx: context::Context,
request: Self::Req,
) -> impl Future<Output = Result<Self::Resp, RpcError>> + Send;
}

impl<Req, Resp> Stub for Channel<Req, Resp>
Expand All @@ -40,6 +61,19 @@ where
}
}

impl<Req, Resp> SendStub for Channel<Req, Resp>
where
Req: RequestName + Send,
Resp: Send,
{
type Req = Req;
type Resp = Resp;

async fn call(&self, ctx: context::Context, request: Req) -> Result<Self::Resp, RpcError> {
Self::call(self, ctx, request).await
}
}

impl<S> Stub for S
where
S: Serve + Clone,
Expand All @@ -50,3 +84,16 @@ where
self.clone().serve(ctx, req).await.map_err(RpcError::Server)
}
}

impl<S> SendStub for S
where
S: SendServe + Clone + Sync,
S::Req: Send + Sync,
S::Resp: Send,
{
type Req = S::Req;
type Resp = S::Resp;
async fn call(&self, ctx: context::Context, req: Self::Req) -> Result<Self::Resp, RpcError> {
self.clone().serve(ctx, req).await.map_err(RpcError::Server)
}
}
18 changes: 18 additions & 0 deletions tarpc/src/client/stub/load_balance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,24 @@ mod round_robin {
}
}

impl<Stub> stub::SendStub for RoundRobin<Stub>
where
Stub: stub::SendStub + Send + Sync,
Stub::Req: Send,
{
type Req = Stub::Req;
type Resp = Stub::Resp;

async fn call(
&self,
ctx: context::Context,
request: Self::Req,
) -> Result<Stub::Resp, RpcError> {
let next = self.stubs.next();
next.call(ctx, request).await
}
}

/// A Stub that load-balances across backing stubs by round robin.
#[derive(Clone, Debug)]
pub struct RoundRobin<Stub> {
Expand Down
27 changes: 26 additions & 1 deletion tarpc/src/client/stub/mock.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use crate::{
client::{stub::Stub, RpcError},
client::{
stub::{SendStub, Stub},
RpcError,
},
context, RequestName, ServerError,
};
use std::{collections::HashMap, hash::Hash, io};
Expand Down Expand Up @@ -42,3 +45,25 @@ where
})
}
}

impl<Req, Resp> SendStub for Mock<Req, Resp>
where
Req: Eq + Hash + RequestName + Send + Sync,
Resp: Clone + Send + Sync,
{
type Req = Req;
type Resp = Resp;

async fn call(&self, _: context::Context, request: Self::Req) -> Result<Resp, RpcError> {
self.responses
.get(&request)
.cloned()
.map(Ok)
.unwrap_or_else(|| {
Err(RpcError::Server(ServerError {
kind: io::ErrorKind::NotFound,
detail: "mock (request, response) entry not found".into(),
}))
})
}
}
27 changes: 27 additions & 0 deletions tarpc/src/client/stub/retry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,33 @@ where
}
}

impl<Stub, Req, F> stub::SendStub for Retry<F, Stub>
where
Req: RequestName + Send + Sync,
Stub: stub::SendStub<Req = Arc<Req>> + Send + Sync,
F: Fn(&Result<Stub::Resp, RpcError>, u32) -> bool + Send + Sync,
{
type Req = Req;
type Resp = Stub::Resp;

async fn call(
&self,
ctx: context::Context,
request: Self::Req,
) -> Result<Stub::Resp, RpcError> {
let request = Arc::new(request);
for i in 1.. {
let result = self.stub.call(ctx, Arc::clone(&request)).await;
if (self.should_retry)(&result, i) {
tracing::trace!("Retrying on attempt {i}");
continue;
}
return result;
}
unreachable!("Wow, that was a lot of attempts!");
}
}

/// A Stub that retries requests based on response contents.
/// Note: to use this stub with Serde serialization, the "rc" feature of Serde needs to be enabled.
#[derive(Clone, Debug)]
Expand Down
37 changes: 31 additions & 6 deletions tarpc/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ impl Config {
}

/// Equivalent to a `FnOnce(Req) -> impl Future<Output = Resp>`.
#[allow(async_fn_in_trait)]
pub trait Serve {
/// Type of request.
type Req: RequestName;
Expand All @@ -76,7 +75,33 @@ pub trait Serve {
type Resp;

/// Responds to a single request.
async fn serve(self, ctx: context::Context, req: Self::Req) -> Result<Self::Resp, ServerError>;
fn serve(
self,
ctx: context::Context,
req: Self::Req,
) -> impl Future<Output = Result<Self::Resp, ServerError>>;
}

/// Equivalent to a `FnOnce(Req) -> impl Future<Output = Resp>`.
pub trait SendServe: Send {
/// Type of request.
type Req: RequestName;
/// Type of response.
type Resp;
/// Responds to a single request.
fn serve(
self,
ctx: context::Context,
req: Self::Req,
) -> impl Future<Output = Result<Self::Resp, ServerError>> + Send;
}

impl<S: SendServe> Serve for S {
type Req = <Self as SendServe>::Req;
type Resp = <Self as SendServe>::Resp;
async fn serve(self, ctx: context::Context, req: Self::Req) -> Result<Self::Resp, ServerError> {
<Self as SendServe>::serve(self, ctx, req).await
}
}

/// A Serve wrapper around a Fn.
Expand Down Expand Up @@ -113,11 +138,11 @@ where
}
}

impl<Req, Resp, Fut, F> Serve for ServeFn<Req, Resp, F>
impl<Req, Resp, Fut, F> SendServe for ServeFn<Req, Resp, F>
where
Req: RequestName,
F: FnOnce(context::Context, Req) -> Fut,
Fut: Future<Output = Result<Resp, ServerError>>,
Req: RequestName + Send,
F: FnOnce(context::Context, Req) -> Fut + Send,
Fut: Future<Output = Result<Resp, ServerError>> + Send,
{
type Req = Req;
type Resp = Resp;
Expand Down
Loading