diff --git a/src/proptest_fn.rs b/src/proptest_fn.rs index ad285b6..639a6ee 100644 --- a/src/proptest_fn.rs +++ b/src/proptest_fn.rs @@ -2,8 +2,8 @@ use crate::syn_utils::{Arg, Args}; use proc_macro2::TokenStream; use quote::{quote, ToTokens}; use syn::{ - parse2, parse_quote, parse_str, spanned::Spanned, token, Block, Field, FieldMutability, FnArg, - Ident, ItemFn, LitStr, Pat, Result, Visibility, + parse2, parse_quote, parse_str, spanned::Spanned, token, Block, Expr, Field, FieldMutability, + FnArg, Ident, ItemFn, LitStr, Pat, Result, Visibility, }; pub fn build_proptest(attr: TokenStream, mut item_fn: ItemFn) -> Result { @@ -130,16 +130,17 @@ impl TestFnArg { } } -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone)] enum Async { Tokio, + Expr(Expr), } impl Async { fn apply(&self, block: &Block) -> TokenStream { match self { Async::Tokio => { quote! { - let ret: ::core::result::Result<_, proptest::test_runner::TestCaseError> = + let ret: ::core::result::Result<_, ::proptest::test_runner::TestCaseError> = tokio::runtime::Runtime::new() .unwrap() .block_on(async move { @@ -149,15 +150,29 @@ impl Async { ret?; } } + Async::Expr(expr) => { + quote! { + let ret: ::core::result::Result<(), ::proptest::test_runner::TestCaseError> = + (#expr)(async move { + #block + Ok(()) + }); + ret?; + } + } } } } impl syn::parse::Parse for Async { fn parse(input: syn::parse::ParseStream) -> Result { - let s: LitStr = input.parse()?; - match s.value().as_str() { - "tokio" => Ok(Async::Tokio), - _ => bail!(s.span(), "expected `tokio`."), + if input.peek(LitStr) { + let s: LitStr = input.parse()?; + match s.value().as_str() { + "tokio" => Ok(Async::Tokio), + _ => bail!(s.span(), "expected `tokio`."), + } + } else { + Ok(Async::Expr(input.parse()?)) } } } diff --git a/tests/proptest_fn.rs b/tests/proptest_fn.rs index 81b08cb..e796ca6 100644 --- a/tests/proptest_fn.rs +++ b/tests/proptest_fn.rs @@ -1,5 +1,10 @@ +use ::std::result::Result; +use std::{future::Future, rc::Rc}; + +use ::proptest::test_runner::TestCaseError; use proptest::{prelude::ProptestConfig, prop_assert}; use test_strategy::proptest; +use tokio::task::yield_now; #[proptest] fn example(_x: u32, #[strategy(1..10u32)] y: u32, #[strategy(0..#y)] z: u32) { @@ -72,3 +77,42 @@ async fn tokio_test_no_copy_arg(#[strategy("a+")] s: String) { async fn tokio_test_prop_assert() { prop_assert!(true); } + +#[should_panic] +#[proptest(async = "tokio")] +async fn tokio_test_prop_assert_false() { + prop_assert!(false); +} + +fn tokio_ct(future: impl Future>) -> Result<(), TestCaseError> { + tokio::runtime::Builder::new_current_thread() + .build() + .unwrap() + .block_on(future) +} + +#[proptest(async = tokio_ct)] +async fn async_expr() {} + +#[proptest(async = tokio_ct)] +async fn async_expr_non_send() { + let x = Rc::new(0); + yield_now().await; + drop(x); +} + +#[proptest(async = tokio_ct)] +async fn async_expr_no_copy_arg(#[strategy("a+")] s: String) { + prop_assert!(s.contains('a')); +} + +#[proptest(async = tokio_ct)] +async fn async_expr_test_prop_assert() { + prop_assert!(true); +} + +#[should_panic] +#[proptest(async = tokio_ct)] +async fn async_expr_test_prop_assert_false() { + prop_assert!(false); +}