From 160b56184fcec1671d56d44e507054fc23a7671b Mon Sep 17 00:00:00 2001 From: Shane Murphy <88051904+ShaneMurphy2@users.noreply.github.com> Date: Sat, 9 Mar 2024 15:41:23 -0800 Subject: [PATCH] Support custom derives on Request and Response (#438) Replace the bespoke "derive_serde" with a more flexible "derive = [, , ...]" form. Deprecate the old "derive_serde" form, and emit deprecation warnings. --- plugins/src/lib.rs | 309 ++++++++++++++---- plugins/tests/service.rs | 53 ++- tarpc/tests/compile_fail.rs | 4 + .../no_explicit_serde_without_feature.rs | 9 + .../no_explicit_serde_without_feature.stderr | 11 + .../no_implicit_serde_without_feature.rs | 9 + .../no_implicit_serde_without_feature.stderr | 12 + tarpc/tests/compile_fail/serde1/deprecated.rs | 8 + .../compile_fail/serde1/deprecated.stderr | 15 + .../tests/compile_fail/serde1/incompatible.rs | 7 + .../compile_fail/serde1/incompatible.stderr | 7 + .../compile_fail/serde1/opt_out_serde.rs | 15 + .../compile_fail/serde1/opt_out_serde.stderr | 18 + 13 files changed, 411 insertions(+), 66 deletions(-) create mode 100644 tarpc/tests/compile_fail/no_serde1/no_explicit_serde_without_feature.rs create mode 100644 tarpc/tests/compile_fail/no_serde1/no_explicit_serde_without_feature.stderr create mode 100644 tarpc/tests/compile_fail/no_serde1/no_implicit_serde_without_feature.rs create mode 100644 tarpc/tests/compile_fail/no_serde1/no_implicit_serde_without_feature.stderr create mode 100644 tarpc/tests/compile_fail/serde1/deprecated.rs create mode 100644 tarpc/tests/compile_fail/serde1/deprecated.stderr create mode 100644 tarpc/tests/compile_fail/serde1/incompatible.rs create mode 100644 tarpc/tests/compile_fail/serde1/incompatible.stderr create mode 100644 tarpc/tests/compile_fail/serde1/opt_out_serde.rs create mode 100644 tarpc/tests/compile_fail/serde1/opt_out_serde.stderr diff --git a/plugins/src/lib.rs b/plugins/src/lib.rs index 8befe78f..c423644b 100644 --- a/plugins/src/lib.rs +++ b/plugins/src/lib.rs @@ -22,7 +22,7 @@ use syn::{ parse_macro_input, parse_quote, spanned::Spanned, token::Comma, - AttrStyle, Attribute, Expr, FnArg, Ident, Lit, LitBool, MetaNameValue, Pat, PatType, + AttrStyle, Attribute, Expr, FnArg, Ident, Lit, LitBool, MetaNameValue, Pat, PatType, Path, ReturnType, Token, Type, Visibility, }; @@ -141,14 +141,37 @@ impl Parse for RpcMethod { } } -// If `derive_serde` meta item is not present, defaults to cfg!(feature = "serde1"). -// `derive_serde` can only be true when serde1 is enabled. -struct DeriveSerde(bool); +#[derive(Default)] +struct DeriveMeta { + derive: Option, + warnings: Vec, +} + +impl DeriveMeta { + fn with_derives(mut self, new: Vec) -> Self { + match self.derive.as_mut() { + Some(Derive::Explicit(old)) => old.extend(new), + _ => self.derive = Some(Derive::Explicit(new)), + } + + self + } +} -impl Parse for DeriveSerde { +enum Derive { + Explicit(Vec), + Serde(bool), +} + +impl Parse for DeriveMeta { fn parse(input: ParseStream) -> syn::Result { - let mut result = Ok(None); + let mut result = Ok(DeriveMeta::default()); + + let mut derives = Vec::new(); let mut derive_serde = Vec::new(); + let mut has_derive_serde = false; + let mut has_explicit_derives = false; + let meta_items = input.parse_terminated(MetaNameValue::parse, Comma)?; for meta in meta_items { if meta.path.segments.len() != 1 { @@ -162,47 +185,117 @@ impl Parse for DeriveSerde { continue; } let segment = meta.path.segments.first().unwrap(); - if segment.ident != "derive_serde" { - extend_errors!( - result, - syn::Error::new( - meta.span(), - "tarpc::service does not support this meta item" - ) - ); - continue; - } - let Expr::Lit(expr_lit) = &meta.value else { - extend_errors!( - result, - syn::Error::new(meta.value.span(), "expected literal") - ); - continue; - }; - match expr_lit.lit { - Lit::Bool(LitBool { value: true, .. }) if cfg!(feature = "serde1") => { - result = result.and(Ok(Some(true))) - } - Lit::Bool(LitBool { value: true, .. }) => { + if segment.ident == "derive" { + has_explicit_derives = true; + let Expr::Array(ref array) = meta.value else { extend_errors!( result, syn::Error::new( meta.span(), - "To enable serde, first enable the `serde1` feature of tarpc" + "tarpc::service does not support this meta item" ) ); + continue; + }; + + let paths = array + .elems + .iter() + .filter_map(|e| { + if let Expr::Path(path) = e { + Some(path.path.clone()) + } else { + extend_errors!( + result, + syn::Error::new(e.span(), "Expected Path or Type") + ); + None + } + }) + .collect::>(); + + result = result.map(|d| d.with_derives(paths)); + derives.push(meta); + } else if segment.ident == "derive_serde" { + has_derive_serde = true; + let Expr::Lit(expr_lit) = &meta.value else { + extend_errors!( + result, + syn::Error::new(meta.value.span(), "expected literal") + ); + continue; + }; + match expr_lit.lit { + Lit::Bool(LitBool { value: true, .. }) if cfg!(feature = "serde1") => { + result = result.map(|d| DeriveMeta { + derive: Some(Derive::Serde(true)), + ..d + }) + } + Lit::Bool(LitBool { value: true, .. }) => { + extend_errors!( + result, + syn::Error::new( + meta.span(), + "To enable serde, first enable the `serde1` feature of tarpc" + ) + ); + } + Lit::Bool(LitBool { value: false, .. }) => { + result = result.map(|d| DeriveMeta { + derive: Some(Derive::Serde(false)), + ..d + }) + } + _ => extend_errors!( + result, + syn::Error::new( + expr_lit.lit.span(), + "`derive_serde` expects a value of type `bool`" + ) + ), } - Lit::Bool(LitBool { value: false, .. }) => result = result.and(Ok(Some(false))), - _ => extend_errors!( + derive_serde.push(meta); + } else { + extend_errors!( result, syn::Error::new( - expr_lit.lit.span(), - "`derive_serde` expects a value of type `bool`" + meta.span(), + "tarpc::service does not support this meta item" ) - ), + ); + continue; } - derive_serde.push(meta); } + + if has_derive_serde { + let deprecation_hack = quote! { + const _: () = { + #[deprecated( + note = "\nThe form `tarpc::service(derive_serde = true)` is deprecated.\ + \nUse `tarpc::service(derive = [Serialize, Deserialize])`." + )] + const DEPRECATED_SYNTAX: () = (); + let _ = DEPRECATED_SYNTAX; + }; + }; + + result = result.map(|mut d| { + d.warnings.push(deprecation_hack.to_token_stream()); + d + }); + } + + if has_explicit_derives & has_derive_serde { + extend_errors!( + result, + syn::Error::new( + input.span(), + "tarpc does not support `derive_serde` and `derive` at the same time" + ) + ); + } + if derive_serde.len() > 1 { for (i, derive_serde) in derive_serde.iter().enumerate() { extend_errors!( @@ -217,8 +310,20 @@ impl Parse for DeriveSerde { ); } } - let derive_serde = result?.unwrap_or(cfg!(feature = "serde1")); - Ok(Self(derive_serde)) + + if derives.len() > 1 { + for (i, derive) in derives.iter().enumerate() { + extend_errors!( + result, + syn::Error::new( + derive.span(), + format!("`derive` appears more than once (occurrence #{})", i + 1) + ) + ); + } + } + + result } } @@ -232,6 +337,7 @@ impl Parse for DeriveSerde { /// # struct Foo; /// ``` #[proc_macro_attribute] +#[cfg(feature = "serde1")] pub fn derive_serde(_attr: TokenStream, item: TokenStream) -> TokenStream { let mut gen: proc_macro2::TokenStream = quote! { #[derive(::tarpc::serde::Serialize, ::tarpc::serde::Deserialize)] @@ -260,16 +366,54 @@ fn collect_cfg_attrs(rpcs: &[RpcMethod]) -> Vec> { .collect::>() } -/// Generates: -/// - service trait -/// - serve fn -/// - client stub struct -/// - new_stub client factory fn -/// - Request and Response enums -/// - ResponseFut Future +/// This macro generates the machinery used by both the client and server. +/// +/// Namely, it produces: +/// - a serve fn inside the trait +/// - client stub struct +/// - Request and Response enums +/// +/// # Example +/// +/// ```no_run +/// use tarpc::{client, transport, service, server::{self, Channel}, context::Context}; +/// +/// #[service] +/// pub trait Calculator { +/// async fn add(a: i32, b: i32) -> i32; +/// } +/// +/// // The request type looks like the following. +/// // Note, you don't have to interact with this type directly outside +/// // of testing, it is used by the client and server implementation +/// let req = CalculatorRequest::Add {a: 5, b: 7}; +/// +/// // This would be the associated response, again you don't ofent use this, +/// // it is only shown for educational purposes. +/// let resp = CalculatorResponse::Add(12); +/// +/// // This could be any transport. +/// let (client_side, server_side) = transport::channel::unbounded(); +/// +/// // A client can be made like so: +/// let client = CalculatorClient::new(client::Config::default(), client_side); +/// +/// // And a server like so: +/// #[derive(Clone)] +/// struct CalculatorServer; +/// impl Calculator for CalculatorServer { +/// async fn add(self, context: Context, a: i32, b: i32) -> i32 { +/// a + b +/// } +/// } +/// +/// // You would usually spawn on an async runtime. +/// let server = server::BaseChannel::with_defaults(server_side); +/// let _ = server.execute(CalculatorServer.serve()); +/// ``` #[proc_macro_attribute] pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { - let derive_serde = parse_macro_input!(attr as DeriveSerde); + let derive_meta = parse_macro_input!(attr as DeriveMeta); let unit_type: &Type = &parse_quote!(()); let Service { ref attrs, @@ -283,13 +427,41 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { .map(|rpc| snake_to_camel(&rpc.ident.unraw().to_string())) .collect(); let args: &[&[PatType]] = &rpcs.iter().map(|rpc| &*rpc.args).collect::>(); - let derive_serialize = if derive_serde.0 { - Some( - quote! {#[derive(::tarpc::serde::Serialize, ::tarpc::serde::Deserialize)] - #[serde(crate = "::tarpc::serde")]}, - ) - } else { - None + + let derives = match derive_meta.derive.as_ref() { + Some(Derive::Explicit(paths)) => { + if !paths.is_empty() { + Some(quote! { + #[derive( + #( + #paths + ),* + )] + }) + } else { + None + } + } + Some(Derive::Serde(serde)) => { + if *serde { + Some(quote! { + #[derive(::tarpc::serde::Serialize, ::tarpc::serde::Deserialize)] + #[serde(crate = "::tarpc::serde")] + }) + } else { + None + } + } + None => { + if cfg!(feature = "serde1") { + Some(quote! { + #[derive(::tarpc::serde::Serialize, ::tarpc::serde::Deserialize)] + #[serde(crate = "::tarpc::serde")] + }) + } else { + None + } + } }; let methods = rpcs.iter().map(|rpc| &rpc.ident).collect::>(); @@ -316,7 +488,7 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { return_types: &rpcs .iter() .map(|rpc| match rpc.output { - ReturnType::Type(_, ref ty) => ty, + ReturnType::Type(_, ref ty) => ty.as_ref(), ReturnType::Default => unit_type, }) .collect::>(), @@ -329,7 +501,8 @@ pub fn service(attr: TokenStream, input: TokenStream) -> TokenStream { .zip(camel_case_fn_names.iter()) .map(|(rpc, name)| Ident::new(name, rpc.ident.span())) .collect::>(), - derive_serialize: derive_serialize.as_ref(), + derives: derives.as_ref(), + warnings: &derive_meta.warnings, } .into_token_stream() .into() @@ -355,7 +528,8 @@ struct ServiceGenerator<'a> { args: &'a [&'a [PatType]], return_types: &'a [&'a Type], arg_pats: &'a [Vec<&'a Pat>], - derive_serialize: Option<&'a TokenStream2>, + derives: Option<&'a TokenStream2>, + warnings: &'a [TokenStream2], } impl<'a> ServiceGenerator<'a> { @@ -378,11 +552,11 @@ impl<'a> ServiceGenerator<'a> { .zip(return_types.iter()) .map( |( - RpcMethod { - attrs, ident, args, .. - }, - output, - )| { + RpcMethod { + attrs, ident, args, .. + }, + output, + )| { quote! { #( #attrs )* async fn #ident(self, context: ::tarpc::context::Context, #( #args ),*) -> #output; @@ -470,7 +644,7 @@ impl<'a> ServiceGenerator<'a> { fn enum_request(&self) -> TokenStream2 { let &Self { - derive_serialize, + derives, vis, request_ident, camel_case_idents, @@ -484,7 +658,7 @@ impl<'a> ServiceGenerator<'a> { /// The request sent over the wire from the client to the server. #[allow(missing_docs)] #[derive(Debug)] - #derive_serialize + #derives #vis enum #request_ident { #( #( #method_cfgs )* @@ -508,7 +682,7 @@ impl<'a> ServiceGenerator<'a> { fn enum_response(&self) -> TokenStream2 { let &Self { - derive_serialize, + derives, vis, response_ident, camel_case_idents, @@ -520,7 +694,7 @@ impl<'a> ServiceGenerator<'a> { /// The response sent over the wire from the server to the client. #[allow(missing_docs)] #[derive(Debug)] - #derive_serialize + #derives #vis enum #response_ident { #( #camel_case_idents(#return_types) ),* } @@ -628,6 +802,10 @@ impl<'a> ServiceGenerator<'a> { } } } + + fn emit_warnings(&self) -> TokenStream2 { + self.warnings.iter().map(|w| w.to_token_stream()).collect() + } } impl<'a> ToTokens for ServiceGenerator<'a> { @@ -641,7 +819,8 @@ impl<'a> ToTokens for ServiceGenerator<'a> { self.struct_client(), self.impl_client_new(), self.impl_client_rpc_methods(), - ]) + self.emit_warnings(), + ]); } } diff --git a/plugins/tests/service.rs b/plugins/tests/service.rs index 2af2b1d1..fc4c2f26 100644 --- a/plugins/tests/service.rs +++ b/plugins/tests/service.rs @@ -1,3 +1,5 @@ +use serde::{Deserialize, Serialize}; +use std::hash::Hash; use tarpc::context; #[test] @@ -72,7 +74,6 @@ fn service_with_cfg_rpc() { #[test] fn syntax() { - #[tarpc::service] trait Syntax { #[deny(warnings)] #[allow(non_snake_case)] @@ -92,3 +93,53 @@ fn syntax() { async fn one_arg_implicit_return_error(one: String); } } + +#[test] +fn custom_derives() { + #[tarpc::service(derive = [Clone, Hash])] + trait Foo { + async fn foo(); + } + + fn requires_clone(_: impl Clone) {} + fn requires_hash(_: impl Hash) {} + + let x = FooRequest::Foo {}; + requires_clone(x.clone()); + requires_hash(x); +} + +#[test] +fn implicit_serde() { + #[tarpc::service] + trait Foo { + async fn foo(); + } + + fn requires_serde(_: T) + where + for<'de> T: Serialize + Deserialize<'de>, + { + } + + let x = FooRequest::Foo {}; + requires_serde(x); +} + +#[allow(deprecated)] +#[test] +fn explicit_serde() { + #[tarpc::service(derive_serde = true)] + trait Foo { + async fn foo(); + } + + fn requires_serde(_: T) + where + for<'de> T: Serialize + Deserialize<'de>, + { + } + + let x = FooRequest::Foo {}; + requires_serde(x); +} diff --git a/tarpc/tests/compile_fail.rs b/tarpc/tests/compile_fail.rs index c28fe2fa..bce173c1 100644 --- a/tarpc/tests/compile_fail.rs +++ b/tarpc/tests/compile_fail.rs @@ -4,4 +4,8 @@ fn ui() { t.compile_fail("tests/compile_fail/*.rs"); #[cfg(all(feature = "serde-transport", feature = "tcp"))] t.compile_fail("tests/compile_fail/serde_transport/*.rs"); + #[cfg(not(feature = "serde1"))] + t.compile_fail("tests/compile_fail/no_serde1/*.rs"); + #[cfg(feature = "serde1")] + t.compile_fail("tests/compile_fail/serde1/*.rs"); } diff --git a/tarpc/tests/compile_fail/no_serde1/no_explicit_serde_without_feature.rs b/tarpc/tests/compile_fail/no_serde1/no_explicit_serde_without_feature.rs new file mode 100644 index 00000000..9b844ba4 --- /dev/null +++ b/tarpc/tests/compile_fail/no_serde1/no_explicit_serde_without_feature.rs @@ -0,0 +1,9 @@ +#[tarpc::service(derive_serde = true)] +trait Foo { + async fn foo(); +} + +fn main() { + let x = FooRequest::Foo {}; + x.serialize(); +} diff --git a/tarpc/tests/compile_fail/no_serde1/no_explicit_serde_without_feature.stderr b/tarpc/tests/compile_fail/no_serde1/no_explicit_serde_without_feature.stderr new file mode 100644 index 00000000..27744b37 --- /dev/null +++ b/tarpc/tests/compile_fail/no_serde1/no_explicit_serde_without_feature.stderr @@ -0,0 +1,11 @@ +error: To enable serde, first enable the `serde1` feature of tarpc + --> tests/compile_fail/no_serde1/no_explicit_serde_without_feature.rs:1:18 + | +1 | #[tarpc::service(derive_serde = true)] + | ^^^^^^^^^^^^ + +error[E0433]: failed to resolve: use of undeclared type `FooRequest` + --> tests/compile_fail/no_serde1/no_explicit_serde_without_feature.rs:7:13 + | +7 | let x = FooRequest::Foo {}; + | ^^^^^^^^^^ use of undeclared type `FooRequest` diff --git a/tarpc/tests/compile_fail/no_serde1/no_implicit_serde_without_feature.rs b/tarpc/tests/compile_fail/no_serde1/no_implicit_serde_without_feature.rs new file mode 100644 index 00000000..6a86115d --- /dev/null +++ b/tarpc/tests/compile_fail/no_serde1/no_implicit_serde_without_feature.rs @@ -0,0 +1,9 @@ +#[tarpc::service] +trait Foo { + async fn foo(); +} + +fn main() { + let x = FooRequest::Foo {}; + x.serialize(); +} diff --git a/tarpc/tests/compile_fail/no_serde1/no_implicit_serde_without_feature.stderr b/tarpc/tests/compile_fail/no_serde1/no_implicit_serde_without_feature.stderr new file mode 100644 index 00000000..2704e3f0 --- /dev/null +++ b/tarpc/tests/compile_fail/no_serde1/no_implicit_serde_without_feature.stderr @@ -0,0 +1,12 @@ +error[E0599]: no method named `serialize` found for enum `FooRequest` in the current scope + --> tests/compile_fail/no_serde1/no_implicit_serde_without_feature.rs:8:7 + | +1 | #[tarpc::service] + | ----------------- method `serialize` not found for this enum +... +8 | x.serialize(); + | ^^^^^^^^^ method not found in `FooRequest` + | + = help: items from traits can only be used if the trait is implemented and in scope + = note: the following trait defines an item `serialize`, perhaps you need to implement it: + candidate #1: `serde::ser::Serialize` diff --git a/tarpc/tests/compile_fail/serde1/deprecated.rs b/tarpc/tests/compile_fail/serde1/deprecated.rs new file mode 100644 index 00000000..813cbfee --- /dev/null +++ b/tarpc/tests/compile_fail/serde1/deprecated.rs @@ -0,0 +1,8 @@ +#![deny(warnings)] + +#[tarpc::service(derive_serde = true)] +trait Foo { + async fn foo(); +} + +fn main() {} diff --git a/tarpc/tests/compile_fail/serde1/deprecated.stderr b/tarpc/tests/compile_fail/serde1/deprecated.stderr new file mode 100644 index 00000000..b4f67827 --- /dev/null +++ b/tarpc/tests/compile_fail/serde1/deprecated.stderr @@ -0,0 +1,15 @@ +error: use of deprecated constant `_::DEPRECATED_SYNTAX`: + The form `tarpc::service(derive_serde = true)` is deprecated. + Use `tarpc::service(derive = [Serialize, Deserialize])`. + --> tests/compile_fail/serde1/deprecated.rs:3:1 + | +3 | #[tarpc::service(derive_serde = true)] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | +note: the lint level is defined here + --> tests/compile_fail/serde1/deprecated.rs:1:9 + | +1 | #![deny(warnings)] + | ^^^^^^^^ + = note: `#[deny(deprecated)]` implied by `#[deny(warnings)]` + = note: this error originates in the attribute macro `tarpc::service` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/tarpc/tests/compile_fail/serde1/incompatible.rs b/tarpc/tests/compile_fail/serde1/incompatible.rs new file mode 100644 index 00000000..3b1df9f2 --- /dev/null +++ b/tarpc/tests/compile_fail/serde1/incompatible.rs @@ -0,0 +1,7 @@ +#![allow(deprecated)] +#[tarpc::service(derive = [Clone], derive_serde = true)] +trait Foo { + async fn foo(); +} + +fn main() {} diff --git a/tarpc/tests/compile_fail/serde1/incompatible.stderr b/tarpc/tests/compile_fail/serde1/incompatible.stderr new file mode 100644 index 00000000..035c960b --- /dev/null +++ b/tarpc/tests/compile_fail/serde1/incompatible.stderr @@ -0,0 +1,7 @@ +error: tarpc does not support `derive_serde` and `derive` at the same time + --> tests/compile_fail/serde1/incompatible.rs:2:1 + | +2 | #[tarpc::service(derive = [Clone], derive_serde = true)] + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + | + = note: this error originates in the attribute macro `tarpc::service` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/tarpc/tests/compile_fail/serde1/opt_out_serde.rs b/tarpc/tests/compile_fail/serde1/opt_out_serde.rs new file mode 100644 index 00000000..a7690f60 --- /dev/null +++ b/tarpc/tests/compile_fail/serde1/opt_out_serde.rs @@ -0,0 +1,15 @@ +#![allow(deprecated)] + +use std::fmt::Formatter; + +#[tarpc::service(derive_serde = false)] +trait Foo { + async fn foo(); +} + +fn foo(f: &mut Formatter) { + let x = FooRequest::Foo {}; + tarpc::serde::Serialize::serialize(&x, f); +} + +fn main() {} diff --git a/tarpc/tests/compile_fail/serde1/opt_out_serde.stderr b/tarpc/tests/compile_fail/serde1/opt_out_serde.stderr new file mode 100644 index 00000000..22f1e203 --- /dev/null +++ b/tarpc/tests/compile_fail/serde1/opt_out_serde.stderr @@ -0,0 +1,18 @@ +error[E0277]: the trait bound `FooRequest: Serialize` is not satisfied + --> tests/compile_fail/serde1/opt_out_serde.rs:12:40 + | +12 | tarpc::serde::Serialize::serialize(&x, f); + | ---------------------------------- ^^ the trait `Serialize` is not implemented for `FooRequest` + | | + | required by a bound introduced by this call + | + = help: the following other types implement trait `Serialize`: + bool + char + isize + i8 + i16 + i32 + i64 + i128 + and $N others