diff --git a/rstest/src/lib.rs b/rstest/src/lib.rs index 8ec59d4..a4ef046 100644 --- a/rstest/src/lib.rs +++ b/rstest/src/lib.rs @@ -1469,3 +1469,19 @@ pub use rstest_macros::fixture; /// ``` /// pub use rstest_macros::rstest; + +pub struct Context { + pub name: &'static str, + pub description: Option<&'static str>, + pub case: Option, +} + +impl Context { + pub fn new(name: &'static str, description: Option<&'static str>, case: Option) -> Self { + Self { + name, + description, + case, + } + } +} diff --git a/rstest/tests/resources/rstest/context.rs b/rstest/tests/resources/rstest/context.rs new file mode 100644 index 0000000..fb629b5 --- /dev/null +++ b/rstest/tests/resources/rstest/context.rs @@ -0,0 +1,16 @@ +use rstest::*; + +#[rstest] +#[case::description(42)] +fn with_case(#[context] ctx: Context, #[case] _c: u32) { + assert_eq!("with_case", ctx.name); + assert_eq!(Some("description"), ctx.description); + assert_eq!(Some(0), ctx.case); +} + +#[rstest] +fn without_case(#[context] ctx: Context) { + assert_eq!("without_case", ctx.name); + assert_eq!(None, ctx.description); + assert_eq!(None, ctx.case); +} diff --git a/rstest/tests/rstest/mod.rs b/rstest/tests/rstest/mod.rs index 2ab06dc..01390cd 100644 --- a/rstest/tests/rstest/mod.rs +++ b/rstest/tests/rstest/mod.rs @@ -1216,6 +1216,16 @@ fn no_std() { .assert(output); } +#[test] +fn context() { + let (output, _) = run_test("context.rs"); + + TestResults::new() + .ok("with_case::case_1_description") + .ok("without_case") + .assert(output); +} + mod async_timeout_feature { use super::*; diff --git a/rstest_macros/src/parse/arguments.rs b/rstest_macros/src/parse/arguments.rs index ab2b94b..cb593a9 100644 --- a/rstest_macros/src/parse/arguments.rs +++ b/rstest_macros/src/parse/arguments.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use quote::format_ident; use syn::{FnArg, Ident, Pat}; @@ -98,6 +98,7 @@ pub(crate) struct ArgumentsInfo { args: Args, is_global_await: bool, once: Option, + contexts: HashSet, } impl ArgumentsInfo { @@ -235,6 +236,19 @@ impl ArgumentsInfo { fn_arg }) } + + #[allow(dead_code)] + pub(crate) fn add_context(&mut self, pat: Pat) { + self.contexts.insert(pat); + } + + pub(crate) fn set_contexts(&mut self, contexts: impl Iterator) { + contexts.for_each(|c| self.add_context(c)) + } + + pub(crate) fn contexts(&self) -> impl Iterator + '_ { + self.contexts.iter() + } } #[cfg(test)] diff --git a/rstest_macros/src/parse/context.rs b/rstest_macros/src/parse/context.rs new file mode 100644 index 0000000..04d371d --- /dev/null +++ b/rstest_macros/src/parse/context.rs @@ -0,0 +1,60 @@ +use syn::{visit_mut::VisitMut, ItemFn, Pat}; + +use crate::error::ErrorsVec; + +use super::just_once::JustOnceFnArgAttributeExtractor; + +pub(crate) fn extract_context(item_fn: &mut ItemFn) -> Result, ErrorsVec> { + let mut extractor = JustOnceFnArgAttributeExtractor::from("context"); + extractor.visit_item_fn_mut(item_fn); + extractor.take() +} + +#[cfg(test)] +mod should { + use super::*; + use crate::test::{assert_eq, *}; + use rstest_test::assert_in; + + #[rstest] + #[case("fn simple(a: u32) {}")] + #[case("fn more(a: u32, b: &str) {}")] + #[case("fn gen>(a: u32, b: S) {}")] + #[case("fn attr(#[case] a: u32, #[values(1,2)] b: i32) {}")] + fn not_change_anything_if_no_ignore_attribute_found(#[case] item_fn: &str) { + let mut item_fn: ItemFn = item_fn.ast(); + let orig = item_fn.clone(); + + let by_refs = extract_context(&mut item_fn).unwrap(); + + assert_eq!(orig, item_fn); + assert!(by_refs.is_empty()); + } + + #[rstest] + #[case::simple("fn f(#[context] a: u32) {}", "fn f(a: u32) {}", &["a"])] + #[case::more_than_one( + "fn f(#[context] a: u32, #[context] b: String, #[context] c: std::collection::HashMap) {}", + r#"fn f(a: u32, + b: String, + c: std::collection::HashMap) {}"#, + &["a", "b", "c"])] + fn extract(#[case] item_fn: &str, #[case] expected: &str, #[case] expected_refs: &[&str]) { + let mut item_fn: ItemFn = item_fn.ast(); + let expected: ItemFn = expected.ast(); + + let by_refs = extract_context(&mut item_fn).unwrap(); + + assert_eq!(expected, item_fn); + assert_eq!(by_refs, to_pats!(expected_refs)); + } + + #[test] + fn raise_error() { + let mut item_fn: ItemFn = "fn f(#[context] #[context] a: u32) {}".ast(); + + let err = extract_context(&mut item_fn).unwrap_err(); + + assert_in!(format!("{:?}", err), "more than once"); + } +} diff --git a/rstest_macros/src/parse/mod.rs b/rstest_macros/src/parse/mod.rs index baef9e6..e6eb156 100644 --- a/rstest_macros/src/parse/mod.rs +++ b/rstest_macros/src/parse/mod.rs @@ -28,6 +28,7 @@ pub(crate) mod macros; pub(crate) mod arguments; pub(crate) mod by_ref; +pub(crate) mod context; pub(crate) mod expressions; pub(crate) mod fixture; pub(crate) mod future; diff --git a/rstest_macros/src/parse/rstest.rs b/rstest_macros/src/parse/rstest.rs index 4ee2dea..9d62c29 100644 --- a/rstest_macros/src/parse/rstest.rs +++ b/rstest_macros/src/parse/rstest.rs @@ -8,8 +8,9 @@ use self::files::{extract_files, ValueListFromFiles}; use super::{ arguments::ArgumentsInfo, by_ref::extract_by_ref, - check_timeout_attrs, extract_case_args, extract_cases, extract_excluded_trace, - extract_fixtures, extract_value_list, + check_timeout_attrs, + context::extract_context, + extract_case_args, extract_cases, extract_excluded_trace, extract_fixtures, extract_value_list, future::{extract_futures, extract_global_awt}, ignore::extract_ignores, parse_vector_trailing_till_double_comma, @@ -49,20 +50,24 @@ impl Parse for RsTestInfo { impl ExtendWithFunctionAttrs for RsTestInfo { fn extend_with_function_attrs(&mut self, item_fn: &mut ItemFn) -> Result<(), ErrorsVec> { - let composed_tuple!(_inner, excluded, _timeout, futures, global_awt, by_refs, ignores) = merge_errors!( + let composed_tuple!( + _inner, excluded, _timeout, futures, global_awt, by_refs, ignores, contexts + ) = merge_errors!( self.data.extend_with_function_attrs(item_fn), extract_excluded_trace(item_fn), check_timeout_attrs(item_fn), extract_futures(item_fn), extract_global_awt(item_fn), extract_by_ref(item_fn), - extract_ignores(item_fn) + extract_ignores(item_fn), + extract_context(item_fn) )?; self.attributes.add_notraces(excluded); self.arguments.set_global_await(global_awt); self.arguments.set_futures(futures.into_iter()); self.arguments.set_by_refs(by_refs.into_iter()); self.arguments.set_ignores(ignores.into_iter()); + self.arguments.set_contexts(contexts.into_iter()); self.arguments .register_inner_destructored_idents_names(item_fn); Ok(()) @@ -379,6 +384,8 @@ mod test { } mod no_cases { + use std::collections::HashSet; + use super::{assert_eq, *}; #[test] @@ -563,6 +570,25 @@ mod test { assert!(info.arguments.is_future(&pat("a"))); assert!(!info.arguments.is_future(&pat("b"))); } + + #[rstest] + fn extract_context() { + let mut item_fn = + "fn f(#[context] c: Context, #[context] other: Context, more: u32) {}".ast(); + let expected = "fn f(c: Context, other: Context, more: u32) {}".ast(); + + let mut info = RsTestInfo::default(); + + info.extend_with_function_attrs(&mut item_fn).unwrap(); + + assert_eq!(item_fn, expected); + assert_eq!( + info.arguments.contexts().cloned().collect::>(), + vec![pat("c"), pat("other")] + .into_iter() + .collect::>() + ); + } } mod parametrize_cases { diff --git a/rstest_macros/src/render/mod.rs b/rstest_macros/src/render/mod.rs index f5cc9c0..e0e62ff 100644 --- a/rstest_macros/src/render/mod.rs +++ b/rstest_macros/src/render/mod.rs @@ -55,6 +55,7 @@ pub(crate) fn single(mut test: ItemFn, mut info: RsTestInfo) -> TokenStream { resolver, &info, &test.sig.generics, + &None, ) } @@ -65,8 +66,13 @@ pub(crate) fn parametrize(mut test: ItemFn, info: RsTestInfo) -> TokenStream { let resolver_fixtures = resolver::fixtures::get(&info.arguments, info.data.fixtures()); let rendered_cases = cases_data(&info, test.sig.ident.span()) - .map(|(name, attrs, resolver)| { - TestCaseRender::new(name, attrs, (resolver, &resolver_fixtures)) + .map(|c| { + CaseDataValues::new( + c.ident, + c.attributes, + Box::new((c.resolver, &resolver_fixtures)), + c.info, + ) }) .map(|case| case.render(&test, &info)) .collect(); @@ -81,11 +87,14 @@ impl ValueList { resolver: &dyn Resolver, attrs: &[syn::Attribute], info: &RsTestInfo, + case_info: &Option, ) -> TokenStream { let span = test.sig.ident.span(); let test_cases = self .argument_data(resolver, info) - .map(|(name, r)| TestCaseRender::new(Ident::new(&name, span), attrs, r)) + .map(|(name, r)| { + CaseDataValues::new(Ident::new(&name, span), attrs, r, case_info.clone()) + }) .map(|test_case| test_case.render(test, info)); quote! { #(#test_cases)* } @@ -118,12 +127,25 @@ impl ValueList { } } +#[derive(Clone, Debug)] +struct CaseInfo { + description: Option, + pos: usize, +} + +impl CaseInfo { + fn new(description: Option, pos: usize) -> Self { + Self { description, pos } + } +} + fn _matrix_recursive<'a>( test: &ItemFn, list_values: &'a [&'a ValueList], resolver: &dyn Resolver, attrs: &'a [syn::Attribute], info: &RsTestInfo, + case_info: &Option, ) -> TokenStream { if list_values.is_empty() { return Default::default(); @@ -136,13 +158,13 @@ fn _matrix_recursive<'a>( attrs.push(parse_quote!( #[allow(non_snake_case)] )); - vlist.render(test, resolver, &attrs, info) + vlist.render(test, resolver, &attrs, info, case_info) } else { let span = test.sig.ident.span(); let modules = vlist .argument_data(resolver, info) .map(move |(name, resolver)| { - _matrix_recursive(test, list_values, &resolver, attrs, info) + _matrix_recursive(test, list_values, &resolver, attrs, info, case_info) .wrap_by_mod(&Ident::new(&name, span)) }); @@ -162,20 +184,21 @@ pub(crate) fn matrix(mut test: ItemFn, mut info: RsTestInfo) -> TokenStream { let resolver = resolver::fixtures::get(&info.arguments, info.data.fixtures()); let rendered_cases = if cases.is_empty() { let list_values = info.data.list_values().collect::>(); - _matrix_recursive(&test, &list_values, &resolver, &[], &info) + _matrix_recursive(&test, &list_values, &resolver, &[], &info, &None) } else { cases .into_iter() - .map(|(case_name, attrs, case_resolver)| { + .map(|c| { let list_values = info.data.list_values().collect::>(); _matrix_recursive( &test, &list_values, - &(case_resolver, &resolver), - attrs, + &(&c.resolver, &resolver), + c.attributes, &info, + &c.info, ) - .wrap_by_mod(&case_name) + .wrap_by_mod(&c.ident) }) .collect() }; @@ -254,6 +277,7 @@ fn single_test_case( resolver: impl Resolver, info: &RsTestInfo, generics: &syn::Generics, + case_info: &Option, ) -> TokenStream { let (attrs, trace_me): (Vec<_>, Vec<_>) = attrs.iter().cloned().partition(|a| !attr_is(a, "trace")); @@ -273,8 +297,37 @@ fn single_test_case( Some(pat) => !info.arguments.is_ignore(pat), None => true, }); + let test_fn_name_str = testfn_name.to_string(); + let description = match case_info + .as_ref() + .and_then(|c| c.description.as_ref()) + .map(|d| d.to_string()) + { + Some(s) => quote! { Some(#s) }, + None => quote! { None }, + }; + let pos = match case_info.as_ref().map(|c| c.pos) { + Some(p) => quote! { Some(#p) }, + None => quote! { None }, + }; + let context_resolver = info + .arguments + .contexts() + .map(|p| { + (p.clone(), { + let e: Expr = parse_quote! { + Context::new(#test_fn_name_str, #description, #pos) + }; + e + }) + }) + .collect::>(); - let inject = inject::resolve_arguments(injectable_args.into_iter(), &resolver, &generics_types); + let inject = inject::resolve_arguments( + injectable_args.into_iter(), + &(context_resolver, &resolver), + &generics_types, + ); let args = args .iter() @@ -354,29 +407,15 @@ fn trace_arguments<'a>( } } -struct TestCaseRender<'a> { - name: Ident, - attrs: &'a [syn::Attribute], - resolver: Box, -} - -impl<'a> TestCaseRender<'a> { - pub fn new(name: Ident, attrs: &'a [syn::Attribute], resolver: R) -> Self { - TestCaseRender { - name, - attrs, - resolver: Box::new(resolver), - } - } - +impl<'a> CaseDataValues<'a> { fn render(self, testfn: &ItemFn, info: &RsTestInfo) -> TokenStream { let args = testfn.sig.inputs.iter().cloned().collect::>(); let mut attrs = testfn.attrs.clone(); - attrs.extend(self.attrs.iter().cloned()); + attrs.extend(self.attributes.iter().cloned()); let asyncness = testfn.sig.asyncness; single_test_case( - &self.name, + &self.ident, &testfn.sig.ident, &args, &attrs, @@ -386,6 +425,7 @@ impl<'a> TestCaseRender<'a> { self.resolver, info, &testfn.sig.generics, + &self.info, ) } } @@ -426,10 +466,30 @@ fn format_case_name(case: &TestCase, index: usize, display_len: usize) -> String format!("case_{index:0display_len$}{description}") } -fn cases_data( - info: &RsTestInfo, - name_span: Span, -) -> impl Iterator)> { +struct CaseDataValues<'a> { + ident: Ident, + attributes: &'a [syn::Attribute], + resolver: Box, + info: Option, +} + +impl<'a> CaseDataValues<'a> { + fn new( + ident: Ident, + attributes: &'a [syn::Attribute], + resolver: Box, + info: Option, + ) -> Self { + Self { + ident, + attributes, + resolver, + info, + } + } +} + +fn cases_data(info: &RsTestInfo, name_span: Span) -> impl Iterator> { let display_len = info.data.cases().count().display_len(); info.data.cases().enumerate().map({ move |(n, case)| { @@ -440,10 +500,11 @@ fn cases_data( .map(|arg| info.arguments.inner_pat(&arg).clone()) .zip(case.args.iter()) .collect::>(); - ( + CaseDataValues::new( Ident::new(&format_case_name(case, n + 1, display_len), name_span), case.attrs.as_slice(), - resolver_case, + Box::new(resolver_case), + Some(CaseInfo::new(case.description.clone(), n)), ) } }) diff --git a/rstest_macros/src/render/test.rs b/rstest_macros/src/render/test.rs index 43779ed..f582a8c 100644 --- a/rstest_macros/src/render/test.rs +++ b/rstest_macros/src/render/test.rs @@ -339,6 +339,7 @@ mod single_test_should { } } +#[derive(Debug)] struct TestsGroup { requested_test: ItemFn, module: ItemMod, @@ -1056,6 +1057,45 @@ mod cases_should { assert_in!(code, await_argument_code_string("b")); assert_not_in!(code, await_argument_code_string("c")); } + + #[test] + fn render_context() { + let (item_fn, mut info) = + TestCaseBuilder::from(r#"fn test_with_context(ctx: Context, a: u32) {}"#) + .push_case(TestCase { + args: vec![expr("1")], + attrs: Default::default(), + description: Some(ident("my_description")), + }) + .push_case(TestCase { + args: vec![expr("2")], + attrs: Default::default(), + description: Some(ident("other_description")), + }) + .take(); + info.arguments.add_context(pat("ctx")); + + let tokens = parametrize(item_fn, info); + + let tests = TestsGroup::from(tokens); + + fn code(tests: &TestsGroup, id: usize) -> String { + tests.module.get_all_tests()[id].block.display_code() + } + + assert_in!( + code(&tests, 0), + r#"let ctx = Context::new("test_with_context", Some("my_description"), Some(0usize));"# + .ast::() + .display_code() + ); + assert_in!( + code(&tests, 1), + r#"let ctx = Context::new("test_with_context", Some("other_description"), Some(1usize));"# + .ast::() + .display_code() + ); + } } mod matrix_cases_should {