|
3 | 3 | use proc_macro::TokenStream; |
4 | 4 | use proc_macro2::{Span, TokenStream as TokenStream2}; |
5 | 5 | use quote::quote; |
6 | | -use syn::{parse_macro_input, Data, DeriveInput, Fields, Generics, Ident, Index, TypeParamBound}; |
| 6 | +use syn::{ |
| 7 | + parse_macro_input, Data, DeriveInput, Fields, Generics, Ident, Index, ItemFn, TypeParamBound, |
| 8 | +}; |
7 | 9 |
|
8 | 10 | #[proc_macro_derive(CheapClone)] |
9 | 11 | pub fn derive_cheap_clone(input: TokenStream) -> TokenStream { |
@@ -235,6 +237,68 @@ pub fn derive_cache_weight(input: TokenStream) -> TokenStream { |
235 | 237 | TokenStream::from(expanded) |
236 | 238 | } |
237 | 239 |
|
| 240 | +/// A proc macro attribute similar to `tokio::test` but uses the |
| 241 | +/// `TEST_RUNTIME` instead of creating a new runtime for each test. |
| 242 | +/// |
| 243 | +/// # Example |
| 244 | +/// |
| 245 | +/// ```ignore |
| 246 | +/// use graph::prelude::*; |
| 247 | +/// |
| 248 | +/// #[graph::test] |
| 249 | +/// async fn my_test() { |
| 250 | +/// // Test code here |
| 251 | +/// } |
| 252 | +/// ``` |
| 253 | +/// |
| 254 | +/// The macro transforms the async test function to use |
| 255 | +/// `TEST_RUNTIME.block_on()`. |
| 256 | +/// |
| 257 | +/// Note that for tests in the `graph` crate itself, the macro must be used |
| 258 | +/// as `#[crate::test]` |
| 259 | +#[proc_macro_attribute] |
| 260 | +pub fn test(args: TokenStream, item: TokenStream) -> TokenStream { |
| 261 | + let input = parse_macro_input!(item as ItemFn); |
| 262 | + |
| 263 | + if !args.is_empty() { |
| 264 | + let msg = "the `#[graph::test]` attribute does not take any arguments"; |
| 265 | + return syn::Error::new(Span::call_site(), msg) |
| 266 | + .to_compile_error() |
| 267 | + .into(); |
| 268 | + } |
| 269 | + |
| 270 | + let ret = &input.sig.output; |
| 271 | + let name = &input.sig.ident; |
| 272 | + let body = &input.block; |
| 273 | + let attrs = &input.attrs; |
| 274 | + let vis = &input.vis; |
| 275 | + |
| 276 | + if input.sig.asyncness.is_none() { |
| 277 | + let msg = "the `async` keyword is missing from the function declaration"; |
| 278 | + return syn::Error::new_spanned(&input.sig.fn_token, msg) |
| 279 | + .to_compile_error() |
| 280 | + .into(); |
| 281 | + } |
| 282 | + |
| 283 | + let crate_name = std::env::var("CARGO_CRATE_NAME").unwrap(); |
| 284 | + let pkg_name = std::env::var("CARGO_PKG_NAME").unwrap(); |
| 285 | + let runtime = if crate_name == "graph" && pkg_name == "graph" { |
| 286 | + quote! { crate::tokio::TEST_RUNTIME } |
| 287 | + } else { |
| 288 | + quote! { graph::TEST_RUNTIME } |
| 289 | + }; |
| 290 | + |
| 291 | + let expanded = quote! { |
| 292 | + #[::core::prelude::v1::test] |
| 293 | + #(#attrs)* |
| 294 | + #vis fn #name() #ret { |
| 295 | + #runtime.block_on(async #body) |
| 296 | + } |
| 297 | + }; |
| 298 | + |
| 299 | + TokenStream::from(expanded) |
| 300 | +} |
| 301 | + |
238 | 302 | #[cfg(test)] |
239 | 303 | mod tests { |
240 | 304 | use proc_macro_utils::assert_expansion; |
|
0 commit comments