Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: function-like macro for generating tests #10233

Merged
merged 3 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 11 additions & 12 deletions crates/storage/codecs/derive/src/arbitrary.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use syn::DeriveInput;
use proc_macro2::{Ident, TokenStream as TokenStream2};
use quote::{quote, ToTokens};

/// If `compact` or `rlp` is passed to `derive_arbitrary`, this function will generate the
/// corresponding proptest roundtrip tests.
///
/// It accepts an optional integer number for the number of proptest cases. Otherwise, it will set
/// it at 1000.
pub fn maybe_generate_tests(args: TokenStream, ast: &DeriveInput) -> TokenStream2 {
let type_ident = ast.ident.clone();

pub fn maybe_generate_tests(
args: TokenStream,
type_ident: &impl ToTokens,
mod_tests: &Ident,
) -> TokenStream2 {
// Same as proptest
let mut default_cases = 256;

Expand All @@ -25,7 +26,7 @@ pub fn maybe_generate_tests(args: TokenStream, ast: &DeriveInput) -> TokenStream
{
let mut buf = vec![];
let len = field.clone().to_compact(&mut buf);
let (decoded, _) = super::#type_ident::from_compact(&buf, len);
let (decoded, _): (super::#type_ident, _) = Compact::from_compact(&buf, len);
assert!(field == decoded, "maybe_generate_tests::compact");
}
});
Expand All @@ -36,7 +37,7 @@ pub fn maybe_generate_tests(args: TokenStream, ast: &DeriveInput) -> TokenStream
let mut buf = vec![];
let len = field.encode(&mut buf);
let mut b = &mut buf.as_slice();
let decoded = super::#type_ident::decode(b).unwrap();
let decoded: super::#type_ident = Decodable::decode(b).unwrap();
assert_eq!(field, decoded, "maybe_generate_tests::rlp");
// ensure buffer is fully consumed by decode
assert!(b.is_empty(), "buffer was not consumed entirely");
Expand All @@ -53,7 +54,7 @@ pub fn maybe_generate_tests(args: TokenStream, ast: &DeriveInput) -> TokenStream
let mut raw = [0u8; 1024];
rand::thread_rng().fill_bytes(&mut raw);
let mut unstructured = arbitrary::Unstructured::new(&raw[..]);
let val = <super::#type_ident as arbitrary::Arbitrary>::arbitrary(&mut unstructured);
let val: Result<super::#type_ident, _> = arbitrary::Arbitrary::arbitrary(&mut unstructured);
if val.is_err() {
// this can be flaky sometimes due to not enough data for iterator based types like Vec
return
Expand All @@ -69,7 +70,7 @@ pub fn maybe_generate_tests(args: TokenStream, ast: &DeriveInput) -> TokenStream
let mut b = Vec::with_capacity(decode_buf.len());
header.encode(&mut b);
b.extend_from_slice(decode_buf);
let res = super::#type_ident::decode(&mut b.as_ref());
let res: Result<super::#type_ident, _> = Decodable::decode(&mut b.as_ref());
assert!(res.is_err(), "malformed header was decoded");
}
});
Expand All @@ -80,8 +81,6 @@ pub fn maybe_generate_tests(args: TokenStream, ast: &DeriveInput) -> TokenStream

let mut tests = TokenStream2::default();
if !roundtrips.is_empty() {
let mod_tests = format_ident!("{}Tests", ast.ident);

tests = quote! {
#[allow(non_snake_case)]
#[cfg(test)]
Expand Down
52 changes: 49 additions & 3 deletions crates/storage/codecs/derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@

use proc_macro::{TokenStream, TokenTree};
use quote::{format_ident, quote};
use syn::{parse_macro_input, DeriveInput};
use syn::{
bracketed,
parse::{Parse, ParseStream},
parse_macro_input, DeriveInput, Result, Token,
};

mod arbitrary;
mod compact;
Expand Down Expand Up @@ -85,7 +89,8 @@ pub fn reth_codec(args: TokenStream, input: TokenStream) -> TokenStream {
pub fn derive_arbitrary(args: TokenStream, input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as DeriveInput);

let tests = arbitrary::maybe_generate_tests(args, &ast);
let tests =
arbitrary::maybe_generate_tests(args, &ast.ident, &format_ident!("{}Tests", ast.ident));

// Avoid duplicate names
let arb_import = format_ident!("{}Arbitrary", ast.ident);
Expand All @@ -106,10 +111,51 @@ pub fn derive_arbitrary(args: TokenStream, input: TokenStream) -> TokenStream {
#[proc_macro_attribute]
pub fn add_arbitrary_tests(args: TokenStream, input: TokenStream) -> TokenStream {
let ast = parse_macro_input!(input as DeriveInput);
let tests = arbitrary::maybe_generate_tests(args, &ast);

let tests =
arbitrary::maybe_generate_tests(args, &ast.ident, &format_ident!("{}Tests", ast.ident));
quote! {
#ast
#tests
}
.into()
}

struct GenerateTestsInput {
args: TokenStream,
ty: syn::Type,
mod_name: syn::Ident,
}

impl Parse for GenerateTestsInput {
fn parse(input: ParseStream<'_>) -> Result<Self> {
input.parse::<Token![#]>()?;

let args;
bracketed!(args in input);

let args = args.parse::<proc_macro2::TokenStream>()?;
let ty = input.parse()?;

input.parse::<Token![,]>()?;
let mod_name = input.parse()?;

Ok(Self { args: args.into(), ty, mod_name })
}
}

/// Generates tests for given type based on passed parameters.
///
/// See `arbitrary::maybe_generate_tests` for more information.
///
/// Examples:
/// * `generate_tests!(#[rlp] MyType, MyTypeTests)`: will generate rlp roundtrip tests for `MyType`
/// in a module named `MyTypeTests`.
/// * `generate_tests!(#[compact, 10] MyType, MyTypeTests)`: will generate compact roundtrip tests
/// for `MyType` limited to 10 cases.
#[proc_macro]
pub fn generate_tests(input: TokenStream) -> TokenStream {
mattsse marked this conversation as resolved.
Show resolved Hide resolved
let input = parse_macro_input!(input as GenerateTestsInput);

arbitrary::maybe_generate_tests(input.args, &input.ty, &input.mod_name).into()
}
Loading