diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8582900..7247df8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -57,16 +57,20 @@ jobs: toolchain: ${{ matrix.channel }} targets: ${{ matrix.target.toolchain }} - uses: swatinem/rust-cache@v2 - - name: cargo test (all features) + - name: cargo test (workspace, all features) run: cargo test --locked --workspace --all-features --bins --tests --examples - - name: cargo test (default features) + - name: cargo test (workspace, default features) run: cargo test --locked --workspace --bins --tests --examples - - name: cargo test (no default features) + - name: cargo test (workspace, no default features) run: cargo test --locked --workspace --no-default-features --bins --tests --examples - - name: cargo check (feature message_spans) - run: cargo check --no-default-features --features message_spans - - name: cargo check (feature rpc) - run: cargo check --no-default-features --features rpc + - name: cargo check (irpc, no default features) + run: cargo check --locked --no-default-features --bins --tests --examples + - name: cargo check (irpc, feature derive) + run: cargo check --locked --no-default-features --features derive --bins --tests --examples + - name: cargo check (irpc, feature spans) + run: cargo check --locked --no-default-features --features spans --bins --tests --examples + - name: cargo check (irpc, feature rpc) + run: cargo check --locked --no-default-features --features rpc --bins --tests --examples test-release: runs-on: ${{ matrix.target.os }} diff --git a/Cargo.toml b/Cargo.toml index d67f499..d781192 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -65,10 +65,26 @@ rpc = ["dep:quinn", "dep:postcard", "dep:anyhow", "dep:smallvec", "dep:tracing", # add test utilities quinn_endpoint_setup = ["rpc", "dep:rustls", "dep:rcgen", "dep:anyhow", "dep:futures-buffered", "quinn/rustls-ring"] # pick up parent span when creating channel messages -message_spans = ["dep:tracing"] +spans = ["dep:tracing"] stream = ["dep:futures-util"] derive = ["dep:irpc-derive"] -default = ["rpc", "quinn_endpoint_setup", "message_spans", "stream", "derive"] +default = ["rpc", "quinn_endpoint_setup", "spans", "stream", "derive"] + +[[example]] +name = "derive" +required-features = ["rpc", "derive", "quinn_endpoint_setup"] + +[[example]] +name = "compute" +required-features = ["rpc", "derive", "quinn_endpoint_setup"] + +[[example]] +name = "local" +required-features = ["derive"] + +[[example]] +name = "storage" +required-features = ["rpc", "quinn_endpoint_setup"] [workspace] members = ["irpc-derive", "irpc-iroh"] @@ -84,7 +100,7 @@ unexpected_cfgs = { level = "warn", check-cfg = ["cfg(quicrpc_docsrs)"] } anyhow = { version = "1.0.98" } tokio = { version = "1.44", default-features = false } postcard = { version = "1.1.1", default-features = false } -serde = { version = "1", default-features = false } +serde = { version = "1", default-features = false, features = ["derive"] } tracing = { version = "0.1.41", default-features = false } n0-future = { version = "0.1.2", default-features = false } tracing-subscriber = { version = "0.3.19" } diff --git a/examples/compute.rs b/examples/compute.rs index 84a6598..6ab5a71 100644 --- a/examples/compute.rs +++ b/examples/compute.rs @@ -1,17 +1,16 @@ use std::{ io::{self, Write}, net::{Ipv4Addr, SocketAddr, SocketAddrV4}, - sync::Arc, }; use anyhow::bail; use futures_buffered::BufferedStreamExt; use irpc::{ channel::{mpsc, oneshot}, - rpc::{listen, Handler}, + rpc::{listen, RemoteService}, rpc_requests, util::{make_client_endpoint, make_server_endpoint}, - Client, LocalSender, Request, Service, WithChannels, + Client, Request, WithChannels, }; use n0_future::{ stream::StreamExt, @@ -21,11 +20,19 @@ use serde::{Deserialize, Serialize}; use thousands::Separable; use tracing::trace; -// Define the ComputeService -#[derive(Debug, Clone, Copy)] -struct ComputeService; - -impl Service for ComputeService {} +// Define the protocol and message enums using the macro +#[rpc_requests(message = ComputeMessage)] +#[derive(Serialize, Deserialize, Debug)] +enum ComputeProtocol { + #[rpc(tx=oneshot::Sender)] + Sqr(Sqr), + #[rpc(rx=mpsc::Receiver, tx=oneshot::Sender)] + Sum(Sum), + #[rpc(tx=mpsc::Sender)] + Fibonacci(Fibonacci), + #[rpc(rx=mpsc::Receiver, tx=mpsc::Sender)] + Multiply(Multiply), +} // Define ComputeRequest sub-messages #[derive(Debug, Serialize, Deserialize)] @@ -55,20 +62,6 @@ enum ComputeRequest { Multiply(Multiply), } -// Define the protocol and message enums using the macro -#[rpc_requests(ComputeService, message = ComputeMessage)] -#[derive(Serialize, Deserialize)] -enum ComputeProtocol { - #[rpc(tx=oneshot::Sender)] - Sqr(Sqr), - #[rpc(rx=mpsc::Receiver, tx=oneshot::Sender)] - Sum(Sum), - #[rpc(tx=mpsc::Sender)] - Fibonacci(Fibonacci), - #[rpc(rx=mpsc::Receiver, tx=mpsc::Sender)] - Multiply(Multiply), -} - // The actor that processes requests struct ComputeActor { recv: tokio::sync::mpsc::Receiver, @@ -79,9 +72,8 @@ impl ComputeActor { let (tx, rx) = tokio::sync::mpsc::channel(128); let actor = Self { recv: rx }; n0_future::task::spawn(actor.run()); - let local = LocalSender::::from(tx); ComputeApi { - inner: local.into(), + inner: Client::local(tx), } } @@ -157,7 +149,7 @@ impl ComputeActor { // The API for interacting with the ComputeService #[derive(Clone)] struct ComputeApi { - inner: Client, + inner: Client, } impl ComputeApi { @@ -168,18 +160,10 @@ impl ComputeApi { } pub fn listen(&self, endpoint: quinn::Endpoint) -> anyhow::Result> { - let Some(local) = self.inner.local() else { + let Some(local) = self.inner.as_local() else { bail!("cannot listen on a remote service"); }; - let handler: Handler = Arc::new(move |msg, rx, tx| { - let local = local.clone(); - Box::pin(match msg { - ComputeProtocol::Sqr(msg) => local.send((msg, tx)), - ComputeProtocol::Sum(msg) => local.send((msg, tx, rx)), - ComputeProtocol::Fibonacci(msg) => local.send((msg, tx)), - ComputeProtocol::Multiply(msg) => local.send((msg, tx, rx)), - }) - }); + let handler = ComputeProtocol::remote_handler(local); Ok(AbortOnDropHandle::new(task::spawn(listen( endpoint, handler, )))) diff --git a/examples/derive.rs b/examples/derive.rs index 7928e10..88a579a 100644 --- a/examples/derive.rs +++ b/examples/derive.rs @@ -1,28 +1,21 @@ use std::{ collections::BTreeMap, net::{Ipv4Addr, SocketAddr, SocketAddrV4}, - sync::Arc, }; use anyhow::{Context, Result}; use irpc::{ channel::{mpsc, oneshot}, - rpc::Handler, + rpc::RemoteService, rpc_requests, util::{make_client_endpoint, make_server_endpoint}, - Client, LocalSender, Service, WithChannels, + Client, WithChannels, }; // Import the macro use n0_future::task::{self, AbortOnDropHandle}; use serde::{Deserialize, Serialize}; use tracing::info; -/// A simple storage service, just to try it out -#[derive(Debug, Clone, Copy)] -struct StorageService; - -impl Service for StorageService {} - #[derive(Debug, Serialize, Deserialize)] struct Get { key: String, @@ -48,8 +41,8 @@ struct SetMany; // Use the macro to generate both the StorageProtocol and StorageMessage enums // plus implement Channels for each type -#[rpc_requests(StorageService, message = StorageMessage)] -#[derive(Serialize, Deserialize)] +#[rpc_requests(message = StorageMessage)] +#[derive(Serialize, Deserialize, Debug)] enum StorageProtocol { #[rpc(tx=oneshot::Sender>)] Get(Get), @@ -74,9 +67,8 @@ impl StorageActor { state: BTreeMap::new(), }; n0_future::task::spawn(actor.run()); - let local = LocalSender::::from(tx); StorageApi { - inner: local.into(), + inner: Client::local(tx), } } @@ -123,7 +115,7 @@ impl StorageActor { } struct StorageApi { - inner: Client, + inner: Client, } impl StorageApi { @@ -134,17 +126,14 @@ impl StorageApi { } pub fn listen(&self, endpoint: quinn::Endpoint) -> Result> { - let local = self.inner.local().context("cannot listen on remote API")?; - let handler: Handler = Arc::new(move |msg, rx, tx| { - let local = local.clone(); - Box::pin(match msg { - StorageProtocol::Get(msg) => local.send((msg, tx)), - StorageProtocol::Set(msg) => local.send((msg, tx)), - StorageProtocol::SetMany(msg) => local.send((msg, tx, rx)), - StorageProtocol::List(msg) => local.send((msg, tx)), - }) - }); - let join_handle = task::spawn(irpc::rpc::listen(endpoint, handler)); + let local = self + .inner + .as_local() + .context("cannot listen on remote API")?; + let join_handle = task::spawn(irpc::rpc::listen( + endpoint, + StorageProtocol::remote_handler(local), + )); Ok(AbortOnDropHandle::new(join_handle)) } diff --git a/examples/local.rs b/examples/local.rs new file mode 100644 index 0000000..d3b5b68 --- /dev/null +++ b/examples/local.rs @@ -0,0 +1,105 @@ +//! This demonstrates using irpc with the derive macro but without the rpc feature +//! for local-only use. Run with: +//! ``` +//! cargo run --example local --no-default-features --features derive +//! ``` + +use std::collections::BTreeMap; + +use irpc::{channel::oneshot, rpc_requests, Client, WithChannels}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize)] +struct Get { + key: String, +} + +#[derive(Debug, Serialize, Deserialize)] +struct List; + +#[derive(Debug, Serialize, Deserialize)] +struct Set { + key: String, + value: String, +} + +impl From<(String, String)> for Set { + fn from((key, value): (String, String)) -> Self { + Self { key, value } + } +} + +#[derive(Debug, Serialize, Deserialize)] +struct SetMany; + +#[rpc_requests(message = StorageMessage, no_rpc, no_spans)] +#[derive(Serialize, Deserialize, Debug)] +enum StorageProtocol { + #[rpc(tx=oneshot::Sender>)] + Get(Get), + #[rpc(tx=oneshot::Sender<()>)] + Set(Set), +} + +struct Actor { + recv: tokio::sync::mpsc::Receiver, + state: BTreeMap, +} + +impl Actor { + async fn run(mut self) { + while let Some(msg) = self.recv.recv().await { + self.handle(msg).await; + } + } + + async fn handle(&mut self, msg: StorageMessage) { + match msg { + StorageMessage::Get(get) => { + let WithChannels { tx, inner, .. } = get; + tx.send(self.state.get(&inner.key).cloned()).await.ok(); + } + StorageMessage::Set(set) => { + let WithChannels { tx, inner, .. } = set; + self.state.insert(inner.key, inner.value); + tx.send(()).await.ok(); + } + } + } +} + +struct StorageApi { + inner: Client, +} + +impl StorageApi { + pub fn spawn() -> StorageApi { + let (tx, rx) = tokio::sync::mpsc::channel(1); + let actor = Actor { + recv: rx, + state: BTreeMap::new(), + }; + n0_future::task::spawn(actor.run()); + StorageApi { + inner: Client::local(tx), + } + } + + pub async fn get(&self, key: String) -> irpc::Result> { + self.inner.rpc(Get { key }).await + } + + pub async fn set(&self, key: String, value: String) -> irpc::Result<()> { + self.inner.rpc(Set { key, value }).await + } +} + +#[tokio::main] +async fn main() -> irpc::Result<()> { + tracing_subscriber::fmt::init(); + let api = StorageApi::spawn(); + api.set("hello".to_string(), "world".to_string()).await?; + let value = api.get("hello".to_string()).await?; + println!("get: hello = {value:?}"); + Ok(()) +} diff --git a/examples/storage.rs b/examples/storage.rs index cd721d5..100a16a 100644 --- a/examples/storage.rs +++ b/examples/storage.rs @@ -1,32 +1,29 @@ use std::{ collections::BTreeMap, net::{Ipv4Addr, SocketAddr, SocketAddrV4}, - sync::Arc, }; use anyhow::bail; use irpc::{ channel::{mpsc, none::NoReceiver, oneshot}, - rpc::{listen, Handler}, + rpc::{listen, RemoteService}, util::{make_client_endpoint, make_server_endpoint}, - Channels, Client, LocalSender, Request, Service, WithChannels, + Channels, Client, Request, Service, WithChannels, }; use n0_future::task::{self, AbortOnDropHandle}; use serde::{Deserialize, Serialize}; use tracing::info; -/// A simple storage service, just to try it out -#[derive(Debug, Clone, Copy)] -struct StorageService; - -impl Service for StorageService {} +impl Service for StorageProtocol { + type Message = StorageMessage; +} #[derive(Debug, Serialize, Deserialize)] struct Get { key: String, } -impl Channels for Get { +impl Channels for Get { type Rx = NoReceiver; type Tx = oneshot::Sender>; } @@ -34,7 +31,7 @@ impl Channels for Get { #[derive(Debug, Serialize, Deserialize)] struct List; -impl Channels for List { +impl Channels for List { type Rx = NoReceiver; type Tx = mpsc::Sender; } @@ -45,12 +42,12 @@ struct Set { value: String, } -impl Channels for Set { +impl Channels for Set { type Rx = NoReceiver; type Tx = oneshot::Sender<()>; } -#[derive(derive_more::From, Serialize, Deserialize)] +#[derive(derive_more::From, Serialize, Deserialize, Debug)] enum StorageProtocol { Get(Get), Set(Set), @@ -59,9 +56,19 @@ enum StorageProtocol { #[derive(derive_more::From)] enum StorageMessage { - Get(WithChannels), - Set(WithChannels), - List(WithChannels), + Get(WithChannels), + Set(WithChannels), + List(WithChannels), +} + +impl RemoteService for StorageProtocol { + fn with_remote_channels(self, rx: quinn::RecvStream, tx: quinn::SendStream) -> Self::Message { + match self { + StorageProtocol::Get(msg) => WithChannels::from((msg, tx, rx)).into(), + StorageProtocol::Set(msg) => WithChannels::from((msg, tx, rx)).into(), + StorageProtocol::List(msg) => WithChannels::from((msg, tx, rx)).into(), + } + } } struct StorageActor { @@ -77,9 +84,8 @@ impl StorageActor { state: BTreeMap::new(), }; n0_future::task::spawn(actor.run()); - let local = LocalSender::::from(tx); StorageApi { - inner: local.into(), + inner: Client::local(tx), } } @@ -115,7 +121,7 @@ impl StorageActor { } } struct StorageApi { - inner: Client, + inner: Client, } impl StorageApi { @@ -126,17 +132,10 @@ impl StorageApi { } pub fn listen(&self, endpoint: quinn::Endpoint) -> anyhow::Result> { - let Some(local) = self.inner.local() else { + let Some(local) = self.inner.as_local() else { bail!("cannot listen on a remote service"); }; - let handler: Handler = Arc::new(move |msg, _rx, tx| { - let local = local.clone(); - Box::pin(match msg { - StorageProtocol::Get(msg) => local.send((msg, tx)), - StorageProtocol::Set(msg) => local.send((msg, tx)), - StorageProtocol::List(msg) => local.send((msg, tx)), - }) - }); + let handler = StorageProtocol::remote_handler(local); Ok(AbortOnDropHandle::new(task::spawn(listen( endpoint, handler, )))) diff --git a/irpc-derive/src/lib.rs b/irpc-derive/src/lib.rs index eaac093..a295048 100644 --- a/irpc-derive/src/lib.rs +++ b/irpc-derive/src/lib.rs @@ -29,7 +29,7 @@ fn generate_parent_span_impl(enum_name: &Ident, variant_names: &[&Ident]) -> Tok quote! { impl #enum_name { /// Get the parent span of the message - pub fn parent_span(&self) -> tracing::Span { + pub fn parent_span(&self) -> ::tracing::Span { let span = match self { #(#enum_name::#variant_names(inner) => inner.parent_span_opt()),* }; @@ -64,8 +64,8 @@ fn generate_channels_impl( Ok(res) } -/// Generates From implementations for cases with rpc attributes -fn generate_case_from_impls( +/// Generates From implementations for protocol enum variants. +fn generate_protocol_enum_from_impls( enum_name: &Ident, variants_with_attr: &[(Ident, Type)], ) -> TokenStream2 { @@ -90,7 +90,7 @@ fn generate_case_from_impls( impls } -/// Generate From implementations for message enum variants +/// Generates From implementations for message enum variants. fn generate_message_enum_from_impls( message_enum_name: &Ident, variants_with_attr: &[(Ident, Type)], @@ -117,6 +117,37 @@ fn generate_message_enum_from_impls( impls } +/// Generate Message::from_quic_streams impl +fn generate_remote_service_impl( + message_enum_name: &Ident, + proto_enum_name: &Ident, + variants_with_attr: &[(Ident, Type)], +) -> TokenStream2 { + let variants = variants_with_attr + .iter() + .map(|(variant_name, _inner_type)| { + quote! { + #proto_enum_name::#variant_name(msg) => { + #message_enum_name::from(::irpc::WithChannels::from((msg, tx, rx))) + } + } + }); + + quote! { + impl ::irpc::rpc::RemoteService for #proto_enum_name { + fn with_remote_channels( + self, + rx: ::irpc::rpc::quinn::RecvStream, + tx: ::irpc::rpc::quinn::SendStream + ) -> Self::Message { + match self { + #(#variants),* + } + } + } + } +} + /// Generate type aliases for WithChannels fn generate_type_aliases( variants: &[(Ident, Type)], @@ -145,75 +176,27 @@ fn generate_type_aliases( aliases } -/// Processes an RPC request enum and generates channel implementations. -/// -/// This macro takes a protocol enum where each variant represents a different RPC request type -/// and generates the necessary channel implementations for each request. -/// -/// # Macro Arguments -/// -/// * First positional argument (required): The service type that will handle these requests -/// * `message` (optional): Generate an extended enum wrapping each type in `WithChannels` -/// * `alias` (optional): Generate type aliases with the given suffix for each `WithChannels` -/// -/// # Variant Attributes -/// -/// Individual enum variants can be annotated with the `#[rpc(...)]` attribute to specify channel types: -/// -/// * `#[rpc(tx=SomeType)]`: Specify the transmitter/sender channel type (required) -/// * `#[rpc(tx=SomeType, rx=OtherType)]`: Also specify a receiver channel type (optional) -/// -/// If `rx` is not specified, it defaults to `NoReceiver`. -/// -/// # Examples -/// -/// Basic usage: -/// ``` -/// #[rpc_requests(ComputeService)] -/// enum ComputeProtocol { -/// #[rpc(tx=oneshot::Sender)] -/// Sqr(Sqr), -/// #[rpc(tx=oneshot::Sender)] -/// Sum(Sum), -/// } -/// ``` -/// -/// With a message enum: -/// ``` -/// #[rpc_requests(ComputeService, message = ComputeMessage)] -/// enum ComputeProtocol { -/// #[rpc(tx=oneshot::Sender)] -/// Sqr(Sqr), -/// #[rpc(tx=oneshot::Sender)] -/// Sum(Sum), -/// } -/// ``` -/// -/// With type aliases: -/// ``` -/// #[rpc_requests(ComputeService, alias = "Msg")] -/// enum ComputeProtocol { -/// #[rpc(tx=oneshot::Sender)] -/// Sqr(Sqr), // Generates type SqrMsg = WithChannels -/// #[rpc(tx=oneshot::Sender)] -/// Sum(Sum), // Generates type SumMsg = WithChannels -/// } -/// ``` #[proc_macro_attribute] pub fn rpc_requests(attr: TokenStream, item: TokenStream) -> TokenStream { let mut input = parse_macro_input!(item as DeriveInput); let args = parse_macro_input!(attr as MacroArgs); - let service_name = args.service_name; - let message_enum_name = args.message_enum_name; - let alias_suffix = args.alias_suffix; - let enum_name = &input.ident; + let vis = &input.vis; let input_span = input.span(); + let cfg_feature_rpc = match args.rpc_feature.as_ref() { + None => quote!(), + Some(feature) => quote!(#[cfg(feature = #feature)]), + }; let data_enum = match &mut input.data { Data::Enum(data_enum) => data_enum, - _ => return error_tokens(input.span(), "RpcRequests can only be applied to enums"), + _ => { + return error_tokens( + input.span(), + "The rpc_requests macro can only be applied to enums", + ) + } }; // Collect trait implementations @@ -277,7 +260,7 @@ pub fn rpc_requests(attr: TokenStream, item: TokenStream) -> TokenStream { Err(e) => return e.to_compile_error().into(), }; - match generate_channels_impl(args, &service_name, request_type, attr.span()) { + match generate_channels_impl(args, enum_name, request_type, attr.span()) { Ok(impls) => channel_impls.push(impls), Err(e) => return e.to_compile_error().into(), } @@ -285,24 +268,25 @@ pub fn rpc_requests(attr: TokenStream, item: TokenStream) -> TokenStream { } // Generate From implementations for the original enum (only for variants with rpc attributes) - let original_from_impls = generate_case_from_impls(enum_name, &variants_with_attr); + let protocol_enum_from_impls = + generate_protocol_enum_from_impls(enum_name, &variants_with_attr); // Generate type aliases if requested - let type_aliases = if let Some(suffix) = alias_suffix { + let type_aliases = if let Some(suffix) = args.alias_suffix { // Use all variants for type aliases, not just those with rpc attributes - generate_type_aliases(&all_variants, &service_name, &suffix) + generate_type_aliases(&all_variants, enum_name, &suffix) } else { quote! {} }; // Generate the extended message enum if requested - let extended_enum_code = if let Some(message_enum_name) = message_enum_name { + let extended_enum_code = if let Some(message_enum_name) = args.message_enum_name.as_ref() { let message_variants = all_variants .iter() .map(|(variant_name, inner_type)| { quote! { #[allow(missing_docs)] - #variant_name(::irpc::WithChannels<#inner_type, #service_name>) + #variant_name(::irpc::WithChannels<#inner_type, #enum_name>) } }) .collect::>(); @@ -314,28 +298,47 @@ pub fn rpc_requests(attr: TokenStream, item: TokenStream) -> TokenStream { let message_enum = quote! { #[allow(missing_docs)] #[derive(Debug)] - pub enum #message_enum_name { + #vis enum #message_enum_name { #(#message_variants),* } }; // Generate parent_span method - let parent_span_impl = generate_parent_span_impl(&message_enum_name, &variant_names); + let parent_span_impl = if !args.no_spans { + generate_parent_span_impl(message_enum_name, &variant_names) + } else { + quote! {} + }; // Generate From implementations for the message enum (only for variants with rpc attributes) - let message_from_impls = generate_message_enum_from_impls( - &message_enum_name, - &variants_with_attr, - &service_name, - ); + let message_from_impls = + generate_message_enum_from_impls(message_enum_name, &variants_with_attr, enum_name); + + let service_impl = quote! { + impl ::irpc::Service for #enum_name { + type Message = #message_enum_name; + } + }; + + let remote_service_impl = if !args.no_rpc { + let block = + generate_remote_service_impl(message_enum_name, enum_name, &variants_with_attr); + quote! { + #cfg_feature_rpc + #block + } + } else { + quote! {} + }; quote! { #message_enum + #service_impl + #remote_service_impl #parent_span_impl #message_from_impls } } else { - // If no message_enum_name is provided, don't generate the extended enum quote! {} }; @@ -347,7 +350,7 @@ pub fn rpc_requests(attr: TokenStream, item: TokenStream) -> TokenStream { #(#channel_impls)* // From implementations for the original enum - #original_from_impls + #protocol_enum_from_impls // Type aliases for WithChannels #type_aliases @@ -361,34 +364,60 @@ pub fn rpc_requests(attr: TokenStream, item: TokenStream) -> TokenStream { // Parse arguments for the macro struct MacroArgs { - service_name: Ident, message_enum_name: Option, alias_suffix: Option, + rpc_feature: Option, + no_rpc: bool, + no_spans: bool, } impl Parse for MacroArgs { fn parse(input: ParseStream) -> syn::Result { - // First argument must be the service name (positional) - let service_name: Ident = input.parse()?; - // Initialize optional parameters let mut message_enum_name = None; let mut alias_suffix = None; + let mut rpc_feature = None; + let mut no_rpc = false; + let mut no_spans = false; - // Parse any additional named parameters - while input.peek(Token![,]) { - input.parse::()?; + // Parse names parameters. + loop { let param_name: Ident = input.parse()?; - input.parse::()?; match param_name.to_string().as_str() { "message" => { - message_enum_name = Some(input.parse()?); + input.parse::()?; + let ident: Ident = input.parse()?; + message_enum_name = Some(ident); } "alias" => { + input.parse::()?; let lit: LitStr = input.parse()?; alias_suffix = Some(lit.value()); } + "rpc_feature" => { + input.parse::()?; + if no_rpc { + return Err(syn::Error::new( + param_name.span(), + "rpc_feature is incompatible with no_rpc", + )); + } + let lit: LitStr = input.parse()?; + rpc_feature = Some(lit.value()); + } + "no_rpc" => { + if rpc_feature.is_some() { + return Err(syn::Error::new( + param_name.span(), + "rpc_feature is incompatible with no_rpc", + )); + } + no_rpc = true; + } + "no_spans" => { + no_spans = true; + } _ => { return Err(syn::Error::new( param_name.span(), @@ -396,12 +425,20 @@ impl Parse for MacroArgs { )); } } + + if input.peek(Token![,]) { + input.parse::()?; + } else { + break; + } } Ok(MacroArgs { - service_name, message_enum_name, alias_suffix, + no_rpc, + no_spans, + rpc_feature, }) } } diff --git a/irpc-iroh/examples/auth.rs b/irpc-iroh/examples/auth.rs index 22ae2ef..61245f4 100644 --- a/irpc-iroh/examples/auth.rs +++ b/irpc-iroh/examples/auth.rs @@ -68,13 +68,13 @@ mod storage { use anyhow::Result; use iroh::{ - endpoint::{Connection, RecvStream, SendStream}, + endpoint::Connection, protocol::{AcceptError, ProtocolHandler}, Endpoint, }; use irpc::{ channel::{mpsc, oneshot}, - Client, Service, WithChannels, + Client, WithChannels, }; // Import the macro use irpc_derive::rpc_requests; @@ -84,12 +84,6 @@ mod storage { const ALPN: &[u8] = b"storage-api/0"; - /// A simple storage service, just to try it out - #[derive(Debug, Clone, Copy)] - struct StorageService; - - impl Service for StorageService {} - #[derive(Debug, Serialize, Deserialize)] struct Auth { token: String, @@ -114,8 +108,8 @@ mod storage { // Use the macro to generate both the StorageProtocol and StorageMessage enums // plus implement Channels for each type - #[rpc_requests(StorageService, message = StorageMessage)] - #[derive(Serialize, Deserialize)] + #[rpc_requests(message = StorageMessage)] + #[derive(Serialize, Deserialize, Debug)] enum StorageProtocol { #[rpc(tx=oneshot::Sender>)] Auth(Auth), @@ -136,52 +130,35 @@ mod storage { } impl ProtocolHandler for StorageServer { - fn accept( - &self, - conn: Connection, - ) -> impl std::future::Future> + Send { - let this = self.clone(); - Box::pin(async move { - let mut authed = false; - while let Some((msg, rx, tx)) = read_request(&conn).await? { - let msg_with_channels = upcast_message(msg, rx, tx); - match msg_with_channels { - StorageMessage::Auth(msg) => { - let WithChannels { inner, tx, .. } = msg; - if authed { - conn.close(1u32.into(), b"invalid message"); - break; - } else if inner.token != this.auth_token { - conn.close(1u32.into(), b"permission denied"); - break; - } else { - authed = true; - tx.send(Ok(())).await.ok(); - } + async fn accept(&self, conn: Connection) -> Result<(), AcceptError> { + let mut authed = false; + while let Some(msg) = read_request::(&conn).await? { + match msg { + StorageMessage::Auth(msg) => { + let WithChannels { inner, tx, .. } = msg; + if authed { + conn.close(1u32.into(), b"invalid message"); + break; + } else if inner.token != self.auth_token { + conn.close(1u32.into(), b"permission denied"); + break; + } else { + authed = true; + tx.send(Ok(())).await.ok(); } - _ => { - if !authed { - conn.close(1u32.into(), b"permission denied"); - break; - } else { - this.handle_authenticated(msg_with_channels).await; - } + } + msg => { + if !authed { + conn.close(1u32.into(), b"permission denied"); + break; + } else { + self.handle_authenticated(msg).await; } } } - conn.closed().await; - Ok(()) - }) - } - } - - fn upcast_message(msg: StorageProtocol, rx: RecvStream, tx: SendStream) -> StorageMessage { - match msg { - StorageProtocol::Auth(msg) => WithChannels::from((msg, tx, rx)).into(), - StorageProtocol::Get(msg) => WithChannels::from((msg, tx, rx)).into(), - StorageProtocol::Set(msg) => WithChannels::from((msg, tx, rx)).into(), - StorageProtocol::SetMany(msg) => WithChannels::from((msg, tx, rx)).into(), - StorageProtocol::List(msg) => WithChannels::from((msg, tx, rx)).into(), + } + conn.closed().await; + Ok(()) } } @@ -243,7 +220,7 @@ mod storage { } pub struct StorageClient { - inner: Client, + inner: Client, } impl StorageClient { diff --git a/irpc-iroh/examples/derive.rs b/irpc-iroh/examples/derive.rs index 365d587..bc8b10c 100644 --- a/irpc-iroh/examples/derive.rs +++ b/irpc-iroh/examples/derive.rs @@ -56,24 +56,19 @@ mod storage { //! //! The only `pub` item is [`StorageApi`], everything else is private. - use std::{collections::BTreeMap, sync::Arc}; + use std::collections::BTreeMap; use anyhow::{Context, Result}; use iroh::{protocol::ProtocolHandler, Endpoint}; use irpc::{ channel::{mpsc, oneshot}, - rpc::Handler, - rpc_requests, Client, LocalSender, Service, WithChannels, + rpc::RemoteService, + rpc_requests, Client, WithChannels, }; // Import the macro use irpc_iroh::{IrohProtocol, IrohRemoteConnection}; use serde::{Deserialize, Serialize}; use tracing::info; - /// A simple storage service, just to try it out - #[derive(Debug, Clone, Copy)] - struct StorageService; - - impl Service for StorageService {} #[derive(Debug, Serialize, Deserialize)] struct Get { @@ -91,8 +86,8 @@ mod storage { // Use the macro to generate both the StorageProtocol and StorageMessage enums // plus implement Channels for each type - #[rpc_requests(StorageService, message = StorageMessage)] - #[derive(Serialize, Deserialize)] + #[rpc_requests(message = StorageMessage)] + #[derive(Serialize, Deserialize, Debug)] enum StorageProtocol { #[rpc(tx=oneshot::Sender>)] Get(Get), @@ -115,9 +110,8 @@ mod storage { state: BTreeMap::new(), }; n0_future::task::spawn(actor.run()); - let local = LocalSender::::from(tx); StorageApi { - inner: local.into(), + inner: Client::local(tx), } } @@ -154,7 +148,7 @@ mod storage { } pub struct StorageApi { - inner: Client, + inner: Client, } impl StorageApi { @@ -174,17 +168,9 @@ mod storage { pub fn expose(&self) -> Result { let local = self .inner - .local() + .as_local() .context("can not listen on remote service")?; - let handler: Handler = Arc::new(move |msg, _rx, tx| { - let local = local.clone(); - Box::pin(match msg { - StorageProtocol::Get(msg) => local.send((msg, tx)), - StorageProtocol::Set(msg) => local.send((msg, tx)), - StorageProtocol::List(msg) => local.send((msg, tx)), - }) - }); - Ok(IrohProtocol::new(handler)) + Ok(IrohProtocol::new(StorageProtocol::remote_handler(local))) } pub async fn get(&self, key: String) -> irpc::Result> { diff --git a/irpc-iroh/src/lib.rs b/irpc-iroh/src/lib.rs index cfbc07f..cdb595c 100644 --- a/irpc-iroh/src/lib.rs +++ b/irpc-iroh/src/lib.rs @@ -8,7 +8,11 @@ use iroh::{ protocol::{AcceptError, ProtocolHandler}, }; use irpc::{ - rpc::{Handler, RemoteConnection}, + channel::RecvError, + rpc::{ + Handler, RemoteConnection, RemoteService, ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED, + MAX_MESSAGE_SIZE, + }, util::AsyncReadVarintExt, RequestError, }; @@ -131,13 +135,21 @@ pub async fn handle_connection( handler: Handler, ) -> io::Result<()> { loop { - let Some((msg, rx, tx)) = read_request(&connection).await? else { + let Some((msg, rx, tx)) = read_request_raw(&connection).await? else { return Ok(()); }; handler(msg, rx, tx).await?; } } +pub async fn read_request( + connection: &Connection, +) -> std::io::Result> { + Ok(read_request_raw::(connection) + .await? + .map(|(msg, rx, tx)| S::with_remote_channels(msg, rx, tx))) +} + /// Reads a single request from the connection. /// /// This accepts a bi-directional stream from the connection and reads and parses the request. @@ -145,7 +157,7 @@ pub async fn handle_connection( /// Returns the parsed request and the stream pair if reading and parsing the request succeeded. /// Returns None if the remote closed the connection with error code `0`. /// Returns an error for all other failure cases. -pub async fn read_request( +pub async fn read_request_raw( connection: &Connection, ) -> std::io::Result> { let (send, mut recv) = match connection.accept_bi().await { @@ -163,6 +175,13 @@ pub async fn read_request( .read_varint_u64() .await? .ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "failed to read size"))?; + if size > MAX_MESSAGE_SIZE { + connection.close( + ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into(), + b"request exceeded max message size", + ); + return Err(RecvError::MaxMessageSizeExceeded.into()); + } let mut buf = vec![0; size as usize]; recv.read_exact(&mut buf) .await diff --git a/src/lib.rs b/src/lib.rs index e049553..fbb89f9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -62,7 +62,7 @@ //! - `rpc`: Enable the rpc features. Enabled by default. //! By disabling this feature, all rpc related dependencies are removed. //! The remaining dependencies are just serde, tokio and tokio-util. -//! - `message_spans`: Enable tracing spans for messages. Enabled by default. +//! - `spans`: Enable tracing spans for messages. Enabled by default. //! This is useful even without rpc, to not lose tracing context when message //! passing. This is frequently done manually. This obviously requires //! a dependency on tracing. @@ -78,9 +78,98 @@ #![cfg_attr(quicrpc_docsrs, feature(doc_cfg))] use std::{fmt::Debug, future::Future, io, marker::PhantomData, ops::Deref, result}; +use channel::{mpsc, oneshot}; +/// Processes an RPC request enum and generates trait implementations for use with `irpc`. +/// +/// This attribute macro may be applied to an enum where each variant represents +/// a different RPC request type. Each variant of the enum must contain a single unnamed field +/// of a distinct type, otherwise compilation fails. +/// +/// Basic usage example: +/// ``` +/// use serde::{Serialize, Deserialize}; +/// use irpc::{rpc_requests, channel::{oneshot, mpsc}}; +/// +/// #[rpc_requests(message = ComputeMessage)] +/// #[derive(Debug, Serialize, Deserialize)] +/// enum ComputeProtocol { +/// /// Multiply two numbers, return the result over a oneshot channel. +/// #[rpc(tx=oneshot::Sender)] +/// Multiply(Multiply), +/// /// Sum all numbers received via the `rx` stream, +/// /// reply with the updating sum over the `tx` stream. +/// #[rpc(tx=mpsc::Sender, rx=mpsc::Receiver)] +/// Sum(Sum), +/// } +/// +/// #[derive(Debug, Serialize, Deserialize)] +/// struct Multiply(i64, i64); +/// +/// #[derive(Debug, Serialize, Deserialize)] +/// struct Sum; +/// ``` +/// +/// ## Generated code +/// +/// If no further arguments are set, the macro generates: +/// +/// * A [`Channels`] implementation for each request type (i.e. the type of the variant's +/// single unnamed field). +/// The `Tx` and `Rx` types are set to the types provided via the variant's `rpc` attribute. +/// * A [`From`] implementation to convert from each request type to the protocol enum. +/// +/// When the `message` argument is set, the macro will also create a message enum and implement the +/// [`Service`] and [`RemoteService`] traits for the protocol enum. This is recommended for the +/// typical use of the macro. +/// +/// ## Macro arguments +/// +/// * `message = ` *(optional but recommended)*: +/// * Generates an extended enum wrapping each type in [`WithChannels`]. +/// The attribute value is the name of the message enum type. +/// * Generates a [`Service`] implementation for the protocol enum, with the `Message` +/// type set to the message enum. +/// * Generates a [`RemoteService`] implementation for the protocol enum. +/// * `alias = ""` *(optional)*: Generate type aliases with the given suffix for each [`WithChannels`]. +/// * `rpc_feature = ""` *(optional)*: If set, the [`RemoteService`] implementation will be feature-flagged +/// with this feature. Set this if your crate only optionally enables the `rpc` feature +/// of [`irpc`]. +/// * `no_rpc` *(optional, no value)*: If set, no implementation of [`RemoteService`] will be generated and the generated +/// code works without the `rpc` feature of `irpc`. +/// * `no_spans` *(optional, no value)*: If set, the generated code works without the `spans` feature of `irpc`. +/// +/// ## Variant attributes +/// +/// Individual enum variants are annotated with the `#[rpc(...)]` attribute to specify channel types. +/// The `rpc` attribute contains a key-value list with these arguments: +/// +/// * `tx = SomeType` *(required)*: Set the kind of channel for sending responses from the server to the client. +/// Must be a `Sender` type from the [`crate::channel`] module. +/// * `rx = OtherType` *(optional)*: Set the kind of channel for receiving updates from the client at the server. +/// Must be a `Receiver` type from the [`crate::channel`] module. If `rx` is not set, +/// it defaults to [`crate::channel::none::NoReceiver`]. +/// +/// ## Examples +/// +/// With type aliases: +/// ```no_compile +/// #[rpc_requests(message = ComputeMessage, alias = "Msg")] +/// enum ComputeProtocol { +/// #[rpc(tx=oneshot::Sender)] +/// Sqr(Sqr), // Generates type SqrMsg = WithChannels +/// #[rpc(tx=mpsc::Sender)] +/// Sum(Sum), // Generates type SumMsg = WithChannels +/// } +/// ``` +/// +/// [`irpc`]: crate +/// [`RemoteService`]: rpc::RemoteService +/// [`WithChannels`]: WithChannels +/// [`Channels`]: Channels #[cfg(feature = "derive")] #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "derive")))] pub use irpc_derive::rpc_requests; + use sealed::Sealed; use serde::{de::DeserializeOwned, Serialize}; @@ -106,14 +195,21 @@ impl RpcMessage for T where { } -/// Marker trait for a service +/// Trait for a service /// -/// This is usually implemented by a zero-sized struct. -/// It has various bounds to make derives easier. +/// This is implemented on the protocol enum. +/// It is usually auto-implemented via the [`rpc_requests] macro. /// /// A service acts as a scope for defining the tx and rx channels for each /// message type, and provides some type safety when sending messages. -pub trait Service: Send + Sync + Debug + Clone + 'static {} +pub trait Service: Serialize + DeserializeOwned + Send + Sync + Debug + 'static { + /// Message enum for this protocol. + /// + /// This is expected to be an enum with identical variant names than the + /// protocol enum, but its single unit field is the [`WithChannels`] struct + /// that contains the inner request plus the `tx` and `rx` channels. + type Message: Send + Unpin + 'static; +} mod sealed { pub trait Sealed {} @@ -126,7 +222,7 @@ pub trait Sender: Debug + Sealed {} pub trait Receiver: Debug + Sealed {} /// Trait to specify channels for a message and service -pub trait Channels { +pub trait Channels: Send + 'static { /// The sender type, can be either mpsc, oneshot or none type Tx: Sender; /// The receiver type, can be either mpsc, oneshot or none @@ -452,8 +548,8 @@ pub mod channel { /// ## Cancellation safety /// /// If the future is dropped before completion, and if this is a remote sender, - /// then the sender will be closed and further sends will return an [`io::Error`] - /// with [`io::ErrorKind::BrokenPipe`]. Therefore, make sure to always poll the + /// then the sender will be closed and further sends will return an [`SendError::Io`] + /// with [`std::io::ErrorKind::BrokenPipe`]. Therefore, make sure to always poll the /// future until completion if you want to reuse the sender or any clone afterwards. pub async fn send(&self, value: T) -> std::result::Result<(), SendError> { match self { @@ -482,8 +578,8 @@ pub mod channel { /// ## Cancellation safety /// /// If the future is dropped before completion, and if this is a remote sender, - /// then the sender will be closed and further sends will return an [`io::Error`] - /// with [`io::ErrorKind::BrokenPipe`]. Therefore, make sure to always poll the + /// then the sender will be closed and further sends will return an [`SendError::Io`] + /// with [`std::io::ErrorKind::BrokenPipe`]. Therefore, make sure to always poll the /// future until completion if you want to reuse the sender or any clone afterwards. pub async fn try_send(&mut self, value: T) -> std::result::Result { match self { @@ -596,6 +692,8 @@ pub mod channel { #[error("receiver closed")] ReceiverClosed, /// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]). + /// + /// [`MAX_MESSAGE_SIZE`]: crate::rpc::MAX_MESSAGE_SIZE #[error("maximum message size exceeded")] MaxMessageSizeExceeded, /// The underlying io error. This can occur for remote communication, @@ -625,7 +723,9 @@ pub mod channel { /// for local communication. #[error("sender closed")] SenderClosed, - /// The message exceeded the maximum allowed message size [`MAX_MESSAGE_SIZE`]. + /// The message exceeded the maximum allowed message size (see [`MAX_MESSAGE_SIZE`]). + /// + /// [`MAX_MESSAGE_SIZE`]: crate::rpc::MAX_MESSAGE_SIZE #[error("maximum message size exceeded")] MaxMessageSizeExceeded, /// An io error occurred. This can occur for remote communication, @@ -652,7 +752,7 @@ pub mod channel { /// The channel kind for rx and tx is defined by implementing the `Channels` /// trait, either manually or using a macro. /// -/// When the `message_spans` feature is enabled, this also includes a tracing +/// When the `spans` feature is enabled, this also includes a tracing /// span to carry the tracing context during message passing. pub struct WithChannels, S: Service> { /// The inner message. @@ -662,8 +762,8 @@ pub struct WithChannels, S: Service> { /// The request channel to receive the request from. Can be set to [`NoReceiver`] if not needed. pub rx: >::Rx, /// The current span where the full message was created. - #[cfg(feature = "message_spans")] - #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "message_spans")))] + #[cfg(feature = "spans")] + #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "spans")))] pub span: tracing::Span, } @@ -679,7 +779,7 @@ impl + Debug, S: Service> Debug for WithChannels { impl, S: Service> WithChannels { /// Get the parent span - #[cfg(feature = "message_spans")] + #[cfg(feature = "spans")] pub fn parent_span_opt(&self) -> Option<&tracing::Span> { Some(&self.span) } @@ -700,8 +800,8 @@ where inner, tx: tx.into(), rx: rx.into(), - #[cfg(feature = "message_spans")] - #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "message_spans")))] + #[cfg(feature = "spans")] + #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "spans")))] span: tracing::Span::current(), } } @@ -722,8 +822,8 @@ where inner, tx: tx.into(), rx: NoReceiver, - #[cfg(feature = "message_spans")] - #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "message_spans")))] + #[cfg(feature = "spans")] + #[cfg_attr(quicrpc_docsrs, doc(cfg(feature = "spans")))] span: tracing::Span::current(), } } @@ -758,27 +858,27 @@ impl, S: Service> Deref for WithChannels { /// The service type `S` provides a scope for the protocol messages. It exists /// so you can use the same message with multiple services. #[derive(Debug)] -pub struct Client(ClientInner, PhantomData<(R, S)>); +pub struct Client(ClientInner, PhantomData); -impl Clone for Client { +impl Clone for Client { fn clone(&self) -> Self { Self(self.0.clone(), PhantomData) } } -impl From> for Client { - fn from(tx: LocalSender) -> Self { +impl From> for Client { + fn from(tx: LocalSender) -> Self { Self(ClientInner::Local(tx.0), PhantomData) } } -impl From> for Client { - fn from(tx: tokio::sync::mpsc::Sender) -> Self { +impl From> for Client { + fn from(tx: tokio::sync::mpsc::Sender) -> Self { LocalSender::from(tx).into() } } -impl Client { +impl Client { /// Create a new client to a remote service using the given quinn `endpoint` /// and a socket `addr` of the remote service. #[cfg(feature = "rpc")] @@ -794,9 +894,14 @@ impl Client { Self(ClientInner::Remote(Box::new(remote)), PhantomData) } + /// Creates a new client from a `tokio::sync::mpsc::Sender`. + pub fn local(tx: tokio::sync::mpsc::Sender) -> Self { + tx.into() + } + /// Get the local sender. This is useful if you don't care about remote /// requests. - pub fn local(&self) -> Option> { + pub fn as_local(&self) -> Option> { match &self.0 { ClientInner::Local(tx) => Some(tx.clone().into()), ClientInner::Remote(..) => None, @@ -818,13 +923,8 @@ impl Client { pub fn request( &self, ) -> impl Future< - Output = result::Result, rpc::RemoteSender>, RequestError>, - > + 'static - where - S: Service, - M: Send + Sync + 'static, - R: 'static, - { + Output = result::Result, rpc::RemoteSender>, RequestError>, + > + 'static { #[cfg(feature = "rpc")] { let cloned = match &self.0 { @@ -854,17 +954,16 @@ impl Client { /// Performs a request for which the server returns a oneshot receiver. pub fn rpc(&self, msg: Req) -> impl Future> + Send + 'static where - S: Service, - M: From> + Send + Sync + Unpin + 'static, - R: From + Serialize + Send + Sync + 'static, - Req: Channels, Rx = NoReceiver> + Send + 'static, + S: From, + S::Message: From>, + Req: Channels, Rx = NoReceiver>, Res: RpcMessage, { let request = self.request(); async move { - let recv: channel::oneshot::Receiver = match request.await? { + let recv: oneshot::Receiver = match request.await? { Request::Local(request) => { - let (tx, rx) = channel::oneshot::channel(); + let (tx, rx) = oneshot::channel(); request.send((msg, tx)).await?; rx } @@ -886,19 +985,18 @@ impl Client { &self, msg: Req, local_response_cap: usize, - ) -> impl Future>> + Send + 'static + ) -> impl Future>> + Send + 'static where - S: Service, - M: From> + Send + Sync + Unpin + 'static, - R: From + Serialize + Send + Sync + 'static, - Req: Channels, Rx = NoReceiver> + Send + 'static, + S: From, + S::Message: From>, + Req: Channels, Rx = NoReceiver>, Res: RpcMessage, { let request = self.request(); async move { - let recv: channel::mpsc::Receiver = match request.await? { + let recv: mpsc::Receiver = match request.await? { Request::Local(request) => { - let (tx, rx) = channel::mpsc::channel(local_response_cap); + let (tx, rx) = mpsc::channel(local_response_cap); request.send((msg, tx)).await?; rx } @@ -919,40 +1017,32 @@ impl Client { &self, msg: Req, local_update_cap: usize, - ) -> impl Future< - Output = Result<( - channel::mpsc::Sender, - channel::oneshot::Receiver, - )>, - > + ) -> impl Future, oneshot::Receiver)>> where - S: Service, - M: From> + Send + Sync + Unpin + 'static, - R: From + Serialize + 'static, - Req: Channels, Rx = channel::mpsc::Receiver>, + S: From, + S::Message: From>, + Req: Channels, Rx = mpsc::Receiver>, Update: RpcMessage, Res: RpcMessage, { let request = self.request(); async move { - let (update_tx, res_rx): ( - channel::mpsc::Sender, - channel::oneshot::Receiver, - ) = match request.await? { - Request::Local(request) => { - let (req_tx, req_rx) = channel::mpsc::channel(local_update_cap); - let (res_tx, res_rx) = channel::oneshot::channel(); - request.send((msg, res_tx, req_rx)).await?; - (req_tx, res_rx) - } - #[cfg(not(feature = "rpc"))] - Request::Remote(_request) => unreachable!(), - #[cfg(feature = "rpc")] - Request::Remote(request) => { - let (tx, rx) = request.write(msg).await?; - (tx.into(), rx.into()) - } - }; + let (update_tx, res_rx): (mpsc::Sender, oneshot::Receiver) = + match request.await? { + Request::Local(request) => { + let (req_tx, req_rx) = mpsc::channel(local_update_cap); + let (res_tx, res_rx) = oneshot::channel(); + request.send((msg, res_tx, req_rx)).await?; + (req_tx, res_rx) + } + #[cfg(not(feature = "rpc"))] + Request::Remote(_request) => unreachable!(), + #[cfg(feature = "rpc")] + Request::Remote(request) => { + let (tx, rx) = request.write(msg).await?; + (tx.into(), rx.into()) + } + }; Ok((update_tx, res_rx)) } } @@ -963,26 +1053,21 @@ impl Client { msg: Req, local_update_cap: usize, local_response_cap: usize, - ) -> impl Future, channel::mpsc::Receiver)>> - + Send - + 'static + ) -> impl Future, mpsc::Receiver)>> + Send + 'static where - S: Service, - M: From> + Send + Sync + Unpin + 'static, - R: From + Serialize + Send + 'static, - Req: Channels, Rx = channel::mpsc::Receiver> - + Send - + 'static, + S: From, + S::Message: From>, + Req: Channels, Rx = mpsc::Receiver>, Update: RpcMessage, Res: RpcMessage, { let request = self.request(); async move { - let (update_tx, res_rx): (channel::mpsc::Sender, channel::mpsc::Receiver) = + let (update_tx, res_rx): (mpsc::Sender, mpsc::Receiver) = match request.await? { Request::Local(request) => { - let (update_tx, update_rx) = channel::mpsc::channel(local_update_cap); - let (res_tx, res_rx) = channel::mpsc::channel(local_response_cap); + let (update_tx, update_rx) = mpsc::channel(local_update_cap); + let (res_tx, res_rx) = mpsc::channel(local_response_cap); request.send((msg, res_tx, update_rx)).await?; (update_tx, res_rx) } @@ -1093,23 +1178,23 @@ impl From for io::Error { /// [`WithChannels`]. #[derive(Debug)] #[repr(transparent)] -pub struct LocalSender(tokio::sync::mpsc::Sender, std::marker::PhantomData); +pub struct LocalSender(tokio::sync::mpsc::Sender); -impl Clone for LocalSender { +impl Clone for LocalSender { fn clone(&self) -> Self { - Self(self.0.clone(), PhantomData) + Self(self.0.clone()) } } -impl From> for LocalSender { - fn from(tx: tokio::sync::mpsc::Sender) -> Self { - Self(tx, PhantomData) +impl From> for LocalSender { + fn from(tx: tokio::sync::mpsc::Sender) -> Self { + Self(tx) } } #[cfg(not(feature = "rpc"))] pub mod rpc { - pub struct RemoteSender(std::marker::PhantomData<(R, S)>); + pub struct RemoteSender(std::marker::PhantomData); } #[cfg(feature = "rpc")] @@ -1122,7 +1207,7 @@ pub mod rpc { use n0_future::{future::Boxed as BoxFuture, task::JoinSet}; use quinn::ConnectionError; - use serde::{de::DeserializeOwned, Serialize}; + use serde::de::DeserializeOwned; use smallvec::SmallVec; use tracing::{trace, trace_span, warn, Instrument}; @@ -1133,17 +1218,23 @@ pub mod rpc { oneshot, RecvError, SendError, }, util::{now_or_never, AsyncReadVarintExt, WriteVarintExt}, - RequestError, RpcMessage, + LocalSender, RequestError, RpcMessage, Service, }; + /// This is used by irpc-derive to refer to quinn types (SendStream and RecvStream) + /// to make generated code work for users without having to depend on quinn directly + /// (i.e. when using iroh). + #[doc(hidden)] + pub use quinn; + /// Default max message size (16 MiB). - const MAX_MESSAGE_SIZE: u64 = 1024 * 1024 * 16; + pub const MAX_MESSAGE_SIZE: u64 = 1024 * 1024 * 16; /// Error code on streams if the max message size was exceeded. - const ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED: u32 = 1; + pub const ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED: u32 = 1; /// Error code on streams if the sender tried to send an message that could not be postcard serialized. - const ERROR_CODE_INVALID_POSTCARD: u32 = 2; + pub const ERROR_CODE_INVALID_POSTCARD: u32 = 2; /// Error that can occur when writing the initial message when doing a /// cross-process RPC. @@ -1284,24 +1375,21 @@ pub mod rpc { /// A connection to a remote service that can be used to send the initial message. #[derive(Debug)] - pub struct RemoteSender( + pub struct RemoteSender( quinn::SendStream, quinn::RecvStream, - std::marker::PhantomData<(R, S)>, + std::marker::PhantomData, ); - impl RemoteSender { + impl RemoteSender { pub fn new(send: quinn::SendStream, recv: quinn::RecvStream) -> Self { Self(send, recv, PhantomData) } pub async fn write( self, - msg: impl Into, - ) -> std::result::Result<(quinn::SendStream, quinn::RecvStream), WriteError> - where - R: Serialize, - { + msg: impl Into, + ) -> std::result::Result<(quinn::SendStream, quinn::RecvStream), WriteError> { let RemoteSender(mut send, recv, _) = self; let msg = msg.into(); if postcard::experimental::serialized_size(&msg)? as u64 > MAX_MESSAGE_SIZE { @@ -1604,6 +1692,28 @@ pub mod rpc { + 'static, >; + /// Extension trait to [`Service`] to create a [`Service::Message`] from a [`Service`] + /// and a pair of QUIC streams. + /// + /// This trait is auto-implemented when using the [`crate::rpc_requests`] macro. + pub trait RemoteService: Service + Sized { + /// Returns the message enum for this request by combining `self` (the protocol enum) + /// with a pair of QUIC streams for `tx` and `rx` channels. + fn with_remote_channels( + self, + rx: quinn::RecvStream, + tx: quinn::SendStream, + ) -> Self::Message; + + /// Creates a [`Handler`] that forwards all messages to a [`LocalSender`]. + fn remote_handler(local_sender: LocalSender) -> Handler { + Arc::new(move |msg, rx, tx| { + let msg = Self::with_remote_channels(msg, rx, tx); + Box::pin(local_sender.send_raw(msg)) + }) + } + } + /// Utility function to listen for incoming connections and handle them with the provided handler pub async fn listen( endpoint: quinn::Endpoint, @@ -1621,39 +1731,79 @@ pub mod rpc { return io::Result::Ok(()); } }; - loop { - let (send, mut recv) = match connection.accept_bi().await { - Ok((s, r)) => (s, r), - Err(ConnectionError::ApplicationClosed(cause)) - if cause.error_code.into_inner() == 0 => - { - trace!("remote side closed connection {cause:?}"); - return Ok(()); - } - Err(cause) => { - warn!("failed to accept bi stream {cause:?}"); - return Err(cause.into()); - } - }; - let size = recv.read_varint_u64().await?.ok_or_else(|| { - io::Error::new(io::ErrorKind::UnexpectedEof, "failed to read size") - })?; - let mut buf = vec![0; size as usize]; - recv.read_exact(&mut buf) - .await - .map_err(|e| io::Error::new(io::ErrorKind::UnexpectedEof, e))?; - let msg: R = postcard::from_bytes(&buf) - .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; - let rx = recv; - let tx = send; - handler(msg, rx, tx).await?; - } + handle_connection(connection, handler).await }; let span = trace_span!("rpc", id = request_id); tasks.spawn(fut.instrument(span)); request_id += 1; } } + + /// Handles a quic connection with the provided `handler`. + pub async fn handle_connection( + connection: quinn::Connection, + handler: Handler, + ) -> io::Result<()> { + loop { + let Some((msg, rx, tx)) = read_request_raw(&connection).await? else { + return Ok(()); + }; + handler(msg, rx, tx).await?; + } + } + + pub async fn read_request( + connection: &quinn::Connection, + ) -> std::io::Result> { + Ok(read_request_raw::(connection) + .await? + .map(|(msg, rx, tx)| S::with_remote_channels(msg, rx, tx))) + } + + /// Reads a single request from the connection. + /// + /// This accepts a bi-directional stream from the connection and reads and parses the request. + /// + /// Returns the parsed request and the stream pair if reading and parsing the request succeeded. + /// Returns None if the remote closed the connection with error code `0`. + /// Returns an error for all other failure cases. + pub async fn read_request_raw( + connection: &quinn::Connection, + ) -> std::io::Result> { + let (send, mut recv) = match connection.accept_bi().await { + Ok((s, r)) => (s, r), + Err(ConnectionError::ApplicationClosed(cause)) + if cause.error_code.into_inner() == 0 => + { + trace!("remote side closed connection {cause:?}"); + return Ok(None); + } + Err(cause) => { + warn!("failed to accept bi stream {cause:?}"); + return Err(cause.into()); + } + }; + let size = recv + .read_varint_u64() + .await? + .ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "failed to read size"))?; + if size > MAX_MESSAGE_SIZE { + connection.close( + ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED.into(), + b"request exceeded max message size", + ); + return Err(RecvError::MaxMessageSizeExceeded.into()); + } + let mut buf = vec![0; size as usize]; + recv.read_exact(&mut buf) + .await + .map_err(|e| io::Error::new(io::ErrorKind::UnexpectedEof, e))?; + let msg: R = postcard::from_bytes(&buf) + .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?; + let rx = recv; + let tx = send; + Ok(Some((msg, rx, tx))) + } } /// A request to a service. This can be either local or remote. @@ -1665,19 +1815,19 @@ pub enum Request { Remote(R), } -impl LocalSender { +impl LocalSender { /// Send a message to the service - pub fn send(&self, value: impl Into>) -> SendFut + pub fn send(&self, value: impl Into>) -> SendFut where T: Channels, - M: From>, + S::Message: From>, { - let value: M = value.into().into(); + let value: S::Message = value.into().into(); SendFut::new(self.0.clone(), value) } /// Send a message to the service without the type conversion magic - pub fn send_raw(&self, value: M) -> SendFut { + pub fn send_raw(&self, value: S::Message) -> SendFut { SendFut::new(self.0.clone(), value) } } diff --git a/tests/common.rs b/tests/common.rs index cf89a74..bdc3021 100644 --- a/tests/common.rs +++ b/tests/common.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "quinn_endpoint_setup")] + use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; use irpc::util::{make_client_endpoint, make_server_endpoint}; diff --git a/tests/derive.rs b/tests/derive.rs index 3e0122f..7173b15 100644 --- a/tests/derive.rs +++ b/tests/derive.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "derive")] + use irpc::{ channel::{none::NoSender, oneshot}, rpc_requests, @@ -36,7 +38,7 @@ fn derive_simple() { #[derive(Debug, Serialize, Deserialize)] struct Response4; - #[rpc_requests(Service, message = RequestWithChannels)] + #[rpc_requests(message = RequestWithChannels, no_rpc, no_spans)] #[derive(Debug, Serialize, Deserialize)] enum Request { #[rpc(tx=oneshot::Sender<()>)] @@ -48,11 +50,6 @@ fn derive_simple() { #[rpc(tx=NoSender)] ClientStreaming(ClientStreamingRequest), } - - #[derive(Debug, Clone)] - struct Service; - - impl irpc::Service for Service {} } /// Use diff --git a/tests/mpsc_channel.rs b/tests/mpsc_channel.rs index e397b37..d3982b6 100644 --- a/tests/mpsc_channel.rs +++ b/tests/mpsc_channel.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "quinn_endpoint_setup")] + use std::{ io::{self, ErrorKind}, time::Duration, diff --git a/tests/oneshot_channel.rs b/tests/oneshot_channel.rs index 2caaa51..922edbc 100644 --- a/tests/oneshot_channel.rs +++ b/tests/oneshot_channel.rs @@ -1,3 +1,5 @@ +#![cfg(feature = "quinn_endpoint_setup")] + use std::io::{self, ErrorKind}; use irpc::{