Skip to content

Commit

Permalink
Modify attributes in place and handle #[cfg]
Browse files Browse the repository at this point in the history
  • Loading branch information
jonas-schievink committed Feb 10, 2021
1 parent 0cacd37 commit 489b202
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 53 deletions.
103 changes: 50 additions & 53 deletions firmware/defmt-test/macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ extern crate proc_macro;
use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::{format_ident, quote};
use syn::{parse, spanned::Spanned, Block, FnArg, Ident, Item, ItemMod, Path, ReturnType, Type};
use syn::{parse, spanned::Spanned, Attribute, Item, ItemFn, ItemMod, ReturnType, Type};

#[proc_macro_attribute]
pub fn tests(args: TokenStream, input: TokenStream) -> TokenStream {
Expand Down Expand Up @@ -37,24 +37,24 @@ fn tests_impl(args: TokenStream, input: TokenStream) -> parse::Result<TokenStrea
let mut imports = vec![];
for item in items {
match item {
Item::Fn(f) => {
Item::Fn(mut f) => {
let mut test_kind = None;
let mut should_error = false;

for attr in &f.attrs {
if path_is_ident(&attr.path, "init") {
f.attrs.retain(|attr| {
if attr.path.is_ident("init") {
test_kind = Some(Attr::Init);
} else if path_is_ident(&attr.path, "test") {
false
} else if attr.path.is_ident("test") {
test_kind = Some(Attr::Test);
} else if path_is_ident(&attr.path, "should_error") {
false
} else if attr.path.is_ident("should_error") {
should_error = true;
false
} else {
return Err(parse::Error::new(
attr.span(),
"only attributes `#[test]`, `#[init]` and `#[should_error]` are accepted",
));
true
}
}
});

let attr = match test_kind {
Some(it) => it,
Expand Down Expand Up @@ -89,16 +89,12 @@ fn tests_impl(args: TokenStream, input: TokenStream) -> parse::Result<TokenStrea
));
}

let state = match f.sig.output {
let state = match &f.sig.output {
ReturnType::Default => None,
ReturnType::Type(.., ty) => Some(ty),
ReturnType::Type(.., ty) => Some(ty.clone()),
};

init = Some(Init {
block: f.block,
ident: f.sig.ident,
state,
});
init = Some(Init { func: f, state });
}

Attr::Test => {
Expand All @@ -115,10 +111,7 @@ fn tests_impl(args: TokenStream, input: TokenStream) -> parse::Result<TokenStrea
// NOTE we cannot check the argument type matches `init.state` at this
// point
if let Some(ty) = get_mutable_reference_type(arg).cloned() {
Some(Input {
arg: arg.clone(),
ty,
})
Some(Input { ty })
} else {
// was not `&mut T`
return Err(parse::Error::new(
Expand All @@ -130,16 +123,10 @@ fn tests_impl(args: TokenStream, input: TokenStream) -> parse::Result<TokenStrea
None
};

let ret_ty = match f.sig.output {
ReturnType::Default => syn::parse_str("()").unwrap(),
ReturnType::Type(_, ty) => (*ty).clone(),
};

tests.push(Test {
block: f.block,
ident: f.sig.ident,
cfgs: extract_cfgs(&f.attrs),
func: f,
input,
ret_ty,
should_error,
})
}
Expand All @@ -163,12 +150,12 @@ fn tests_impl(args: TokenStream, input: TokenStream) -> parse::Result<TokenStrea
let ident = module.ident;
let mut state_ty = None;
let (init_fn, init_expr) = if let Some(init) = init {
let init_ident = init.ident;
let init_block = init.block;
let init_func = &init.func;
let init_ident = &init.func.sig.ident;
state_ty = init.state;

(
Some(quote!(fn #init_ident() -> #state_ty #init_block)),
Some(quote!(#init_func)),
Some(quote!(#[allow(dead_code)] let mut state = #init_ident();)),
)
} else {
Expand All @@ -178,8 +165,9 @@ fn tests_impl(args: TokenStream, input: TokenStream) -> parse::Result<TokenStrea
let mut unit_test_calls = vec![];
for test in &tests {
let should_error = test.should_error;
let ident = &test.ident;
let span = ident.span();
let ident = &test.func.sig.ident;
let span = test.func.sig.ident.span();
let cfgs = &test.cfgs;
let call = if let Some(input) = test.input.as_ref() {
if let Some(state) = &state_ty {
if input.ty != **state {
Expand All @@ -200,19 +188,23 @@ fn tests_impl(args: TokenStream, input: TokenStream) -> parse::Result<TokenStrea
quote!(#ident())
};
unit_test_calls.push(quote!(
#( #cfgs )*
#krate::export::check_outcome(#call, #should_error);
));
}
let unit_test_names = tests.iter().map(|test| &test.ident);
let unit_test_inputs = tests
.iter()
.map(|test| test.input.as_ref().map(|input| &input.arg));
let unit_test_outputs = tests.iter().map(|test| &test.ret_ty);
let unit_test_blocks = tests.iter().map(|test| &test.block);

let test_functions = tests.iter().map(|test| &test.func);
let unit_test_running = tests
.iter()
.enumerate()
.map(|(i, test)| format!("({}/{}) running `{}`...", i + 1, tests.len(), test.ident))
.map(|(i, test)| {
format!(
"({}/{}) running `{}`...",
i + 1,
tests.len(),
test.func.sig.ident
)
})
.collect::<Vec<_>>();
Ok(quote!(mod #ident {
#(#imports)*
Expand All @@ -232,7 +224,7 @@ fn tests_impl(args: TokenStream, input: TokenStream) -> parse::Result<TokenStrea
#init_fn

#(
fn #unit_test_names(#unit_test_inputs) -> #unit_test_outputs #unit_test_blocks
#test_functions
)*
})
.into())
Expand All @@ -245,28 +237,21 @@ enum Attr {
}

struct Init {
block: Box<Block>,
ident: Ident,
func: ItemFn,
state: Option<Box<Type>>,
}

struct Test {
block: Box<Block>,
ident: Ident,
func: ItemFn,
cfgs: Vec<Attribute>,
input: Option<Input>,
ret_ty: Type,
should_error: bool,
}

struct Input {
arg: FnArg,
ty: Type,
}

fn path_is_ident(path: &Path, s: &str) -> bool {
path.get_ident().map(|ident| ident == s).unwrap_or(false)
}

// NOTE doesn't check the parameters or the return type
fn check_fn_sig(sig: &syn::Signature) -> Result<(), ()> {
if sig.constness.is_none()
Expand Down Expand Up @@ -298,3 +283,15 @@ fn get_mutable_reference_type(arg: &syn::FnArg) -> Option<&Type> {
None
}
}

fn extract_cfgs(attrs: &[Attribute]) -> Vec<Attribute> {
let mut cfgs = vec![];

for attr in attrs {
if attr.path.is_ident("cfg") {
cfgs.push(attr.clone());
}
}

cfgs
}
10 changes: 10 additions & 0 deletions firmware/qemu/src/bin/defmt-test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ mod tests {
use core::u8::MAX;
use defmt::{assert, assert_eq};

#[init]
fn init() {}

#[test]
fn assert_true() -> () {
assert!(true);
Expand All @@ -18,11 +21,18 @@ mod tests {
assert_eq!(255, MAX);
}

#[cfg(not(never))]
#[test]
fn result() -> Result<(), ()> {
Ok(())
}

#[cfg(never)]
#[test]
fn doesnt_compile() {
because::this::doesnt::exist();
}

#[test]
#[should_error]
fn should_error() -> Result<(), ()> {
Expand Down

0 comments on commit 489b202

Please sign in to comment.