diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index c423644b..ef936dd6 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -559,7 +559,7 @@ impl<'a> ServiceGenerator<'a> { )| { quote! { #( #attrs )* - async fn #ident(self, context: ::tarpc::context::Context, #( #args ),*) -> #output; + fn #ident(self, context: ::tarpc::context::Context, #( #args ),*) -> impl ::core::future::Future + ::core::marker::Send; } }, ); @@ -567,7 +567,7 @@ impl<'a> ServiceGenerator<'a> { let stub_doc = format!("The stub trait for service [`{service_ident}`]."); quote! { #( #attrs )* - #vis trait #service_ident: ::core::marker::Sized { + #vis trait #service_ident: ::core::marker::Sized + ::core::marker::Send { #( #rpc_fns )* /// Returns a serving function to use with @@ -578,11 +578,11 @@ impl<'a> ServiceGenerator<'a> { } #[doc = #stub_doc] - #vis trait #client_stub_ident: ::tarpc::client::stub::Stub { + #vis trait #client_stub_ident: ::tarpc::client::stub::SendStub { } impl #client_stub_ident for S - where S: ::tarpc::client::stub::Stub + where S: ::tarpc::client::stub::SendStub { } } @@ -616,7 +616,7 @@ impl<'a> ServiceGenerator<'a> { } = self; quote! { - impl ::tarpc::server::Serve for #server_ident + impl ::tarpc::server::SendServe for #server_ident where S: #service_ident { type Req = #request_ident; @@ -780,7 +780,7 @@ impl<'a> ServiceGenerator<'a> { quote! { impl #client_ident - where Stub: ::tarpc::client::stub::Stub< + where Stub: ::tarpc::client::stub::SendStub< Req = #request_ident, Resp = #response_ident> { diff --git a/tarpc/src/client/stub.rs b/tarpc/src/client/stub.rs index e7c11aa0..4f7b9486 100644 --- a/tarpc/src/client/stub.rs +++ b/tarpc/src/client/stub.rs @@ -1,9 +1,11 @@ //! Provides a Stub trait, implemented by types that can call remote services. +use std::future::Future; + use crate::{ client::{Channel, RpcError}, context, - server::Serve, + server::{SendServe, Serve}, RequestName, }; @@ -15,7 +17,6 @@ mod mock; /// A connection to a remote service. /// Calls the service with requests of type `Req` and receives responses of type `Resp`. -#[allow(async_fn_in_trait)] pub trait Stub { /// The service request type. type Req: RequestName; @@ -24,8 +25,28 @@ pub trait Stub { type Resp; /// Calls a remote service. - async fn call(&self, ctx: context::Context, request: Self::Req) - -> Result; + fn call( + &self, + ctx: context::Context, + request: Self::Req, + ) -> impl Future>; +} + +/// A connection to a remote service. +/// Calls the service with requests of type `Req` and receives responses of type `Resp`. +pub trait SendStub: Send { + /// The service request type. + type Req: RequestName; + + /// The service response type. + type Resp; + + /// Calls a remote service. + fn call( + &self, + ctx: context::Context, + request: Self::Req, + ) -> impl Future> + Send; } impl Stub for Channel @@ -40,6 +61,19 @@ where } } +impl SendStub for Channel +where + Req: RequestName + Send, + Resp: Send, +{ + type Req = Req; + type Resp = Resp; + + async fn call(&self, ctx: context::Context, request: Req) -> Result { + Self::call(self, ctx, request).await + } +} + impl Stub for S where S: Serve + Clone, @@ -50,3 +84,16 @@ where self.clone().serve(ctx, req).await.map_err(RpcError::Server) } } + +impl SendStub for S +where + S: SendServe + Clone + Sync, + S::Req: Send + Sync, + S::Resp: Send, +{ + type Req = S::Req; + type Resp = S::Resp; + async fn call(&self, ctx: context::Context, req: Self::Req) -> Result { + self.clone().serve(ctx, req).await.map_err(RpcError::Server) + } +} diff --git a/tarpc/src/client/stub/load_balance.rs b/tarpc/src/client/stub/load_balance.rs index e586b793..9315d2ec 100644 --- a/tarpc/src/client/stub/load_balance.rs +++ b/tarpc/src/client/stub/load_balance.rs @@ -28,6 +28,24 @@ mod round_robin { } } + impl stub::SendStub for RoundRobin + where + Stub: stub::SendStub + Send + Sync, + Stub::Req: Send, + { + type Req = Stub::Req; + type Resp = Stub::Resp; + + async fn call( + &self, + ctx: context::Context, + request: Self::Req, + ) -> Result { + let next = self.stubs.next(); + next.call(ctx, request).await + } + } + /// A Stub that load-balances across backing stubs by round robin. #[derive(Clone, Debug)] pub struct RoundRobin { diff --git a/tarpc/src/client/stub/mock.rs b/tarpc/src/client/stub/mock.rs index ae9ae9b2..44c0ecc0 100644 --- a/tarpc/src/client/stub/mock.rs +++ b/tarpc/src/client/stub/mock.rs @@ -1,5 +1,8 @@ use crate::{ - client::{stub::Stub, RpcError}, + client::{ + stub::{SendStub, Stub}, + RpcError, + }, context, RequestName, ServerError, }; use std::{collections::HashMap, hash::Hash, io}; @@ -42,3 +45,25 @@ where }) } } + +impl SendStub for Mock +where + Req: Eq + Hash + RequestName + Send + Sync, + Resp: Clone + Send + Sync, +{ + type Req = Req; + type Resp = Resp; + + async fn call(&self, _: context::Context, request: Self::Req) -> Result { + self.responses + .get(&request) + .cloned() + .map(Ok) + .unwrap_or_else(|| { + Err(RpcError::Server(ServerError { + kind: io::ErrorKind::NotFound, + detail: "mock (request, response) entry not found".into(), + })) + }) + } +} diff --git a/tarpc/src/client/stub/retry.rs b/tarpc/src/client/stub/retry.rs index 89b033bc..0779de17 100644 --- a/tarpc/src/client/stub/retry.rs +++ b/tarpc/src/client/stub/retry.rs @@ -33,6 +33,33 @@ where } } +impl stub::SendStub for Retry +where + Req: RequestName + Send + Sync, + Stub: stub::SendStub> + Send + Sync, + F: Fn(&Result, u32) -> bool + Send + Sync, +{ + type Req = Req; + type Resp = Stub::Resp; + + async fn call( + &self, + ctx: context::Context, + request: Self::Req, + ) -> Result { + let request = Arc::new(request); + for i in 1.. { + let result = self.stub.call(ctx, Arc::clone(&request)).await; + if (self.should_retry)(&result, i) { + tracing::trace!("Retrying on attempt {i}"); + continue; + } + return result; + } + unreachable!("Wow, that was a lot of attempts!"); + } +} + /// A Stub that retries requests based on response contents. /// Note: to use this stub with Serde serialization, the "rc" feature of Serde needs to be enabled. #[derive(Clone, Debug)] diff --git a/tarpc/src/server.rs b/tarpc/src/server.rs index d79d45c2..196fb596 100644 --- a/tarpc/src/server.rs +++ b/tarpc/src/server.rs @@ -67,7 +67,6 @@ impl Config { } /// Equivalent to a `FnOnce(Req) -> impl Future`. -#[allow(async_fn_in_trait)] pub trait Serve { /// Type of request. type Req: RequestName; @@ -76,7 +75,33 @@ pub trait Serve { type Resp; /// Responds to a single request. - async fn serve(self, ctx: context::Context, req: Self::Req) -> Result; + fn serve( + self, + ctx: context::Context, + req: Self::Req, + ) -> impl Future>; +} + +/// Equivalent to a `FnOnce(Req) -> impl Future`. +pub trait SendServe: Send { + /// Type of request. + type Req: RequestName; + /// Type of response. + type Resp; + /// Responds to a single request. + fn serve( + self, + ctx: context::Context, + req: Self::Req, + ) -> impl Future> + Send; +} + +impl Serve for S { + type Req = ::Req; + type Resp = ::Resp; + async fn serve(self, ctx: context::Context, req: Self::Req) -> Result { + ::serve(self, ctx, req).await + } } /// A Serve wrapper around a Fn. @@ -113,11 +138,11 @@ where } } -impl Serve for ServeFn +impl SendServe for ServeFn where - Req: RequestName, - F: FnOnce(context::Context, Req) -> Fut, - Fut: Future>, + Req: RequestName + Send, + F: FnOnce(context::Context, Req) -> Fut + Send, + Fut: Future> + Send, { type Req = Req; type Resp = Resp;