Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sol-macro): add opt-in attributes for extra methods and derives #250

Merged
merged 3 commits into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 67 additions & 20 deletions crates/sol-macro/src/attr.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
use heck::{ToKebabCase, ToLowerCamelCase, ToShoutySnakeCase, ToSnakeCase, ToUpperCamelCase};
use syn::{Attribute, Error, LitStr, Result};
use syn::{Attribute, Error, LitBool, LitStr, Result};

const DUPLICATE_ERROR: &str = "duplicate attribute";
const UNKNOWN_ERROR: &str = "unknown `sol` attribute";

pub fn docs(attrs: &[Attribute]) -> impl Iterator<Item = &Attribute> {
attrs.iter().filter(|attr| attr.path().is_ident("doc"))
Expand All @@ -9,17 +12,20 @@ pub fn derives(attrs: &[Attribute]) -> impl Iterator<Item = &Attribute> {
attrs.iter().filter(|attr| attr.path().is_ident("derive"))
}

/// `#[sol(...)]` attributes.
///
/// When adding a new attribute:
/// 1. add a field to this struct,
/// 2. add a match arm in the `parse` function below,
/// 3. add test cases in the `tests` module at the bottom of this file,
/// 4. implement the attribute in the `expand` module,
/// 5. document the attribute in the [`crate::sol!`] macro docs.
// When adding a new attribute:
// 1. add a field to this struct,
// 2. add a match arm in the `parse` function below,
// 3. add test cases in the `tests` module at the bottom of this file,
// 4. implement the attribute in the `expand` module,
// 5. document the attribute in the [`crate::sol!`] macro docs.

/// `#[sol(...)]` attributes. See [`crate::sol!`] for a list of all possible
/// attributes.
#[derive(Debug, Default, PartialEq, Eq)]
pub struct SolAttrs {
pub all_derives: Option<()>,
pub all_derives: Option<bool>,
pub extra_methods: Option<bool>,

// TODO: Implement
pub rename: Option<LitStr>,
// TODO: Implement
Expand Down Expand Up @@ -51,17 +57,29 @@ impl SolAttrs {
match s.as_str() {
$(
stringify!($l) => if this.$l.is_some() {
return Err(meta.error("duplicate attribute"))
return Err(meta.error(DUPLICATE_ERROR))
} else {
this.$l = Some($e);
},
)*
_ => return Err(meta.error("unknown `sol` attribute")),
_ => return Err(meta.error(UNKNOWN_ERROR)),
}
};
}

// `path` => true, `path = <bool>` => <bool>
let bool = || {
if let Ok(input) = meta.value() {
input.parse::<LitBool>().map(|lit| lit.value)
} else {
Ok(true)
}
};

// `path = "<str>"`
let lit = || meta.value()?.parse::<LitStr>();

// `path = "0x<hex>"`
let bytes = || {
let lit = lit()?;
let v = lit.value();
Expand All @@ -76,7 +94,9 @@ impl SolAttrs {
};

match_! {
all_derives => (),
all_derives => bool()?,
extra_methods => bool()?,

rename => lit()?,
rename_all => CasingStyle::from_lit(&lit()?)?,

Expand Down Expand Up @@ -154,6 +174,13 @@ mod tests {
use syn::parse_quote;

macro_rules! test_sol_attrs {
($($group:ident { $($t:tt)* })+) => {$(
#[test]
fn $group() {
test_sol_attrs! { $($t)* }
}
)+};

($( $(#[$attr:meta])* => $expected:expr ),+ $(,)?) => {$(
run_test(
&[$(stringify!(#[$attr])),*],
Expand Down Expand Up @@ -189,15 +216,20 @@ mod tests {
.collect();
match (SolAttrs::parse(&attrs), expected) {
(Ok((actual, _)), Ok(expected)) => assert_eq!(actual, expected, "{attrs_s:?}"),
(Err(actual), Err(expected)) => assert_eq!(actual.to_string(), expected, "{attrs_s:?}"),
(Err(actual), Err(expected)) => {
if !expected.is_empty() {
assert_eq!(actual.to_string(), expected, "{attrs_s:?}")
}
}
(a, b) => panic!("assertion failed: `{a:?} != {b:?}`: {attrs_s:?}"),
}
}

#[test]
fn sol_attrs() {
test_sol_attrs! {
test_sol_attrs! {
top_level {
#[cfg] => Ok(SolAttrs::default()),
#[cfg()] => Ok(SolAttrs::default()),
#[cfg = ""] => Ok(SolAttrs::default()),
#[derive()] #[sol()] => Ok(SolAttrs::default()),
#[sol()] => Ok(SolAttrs::default()),
#[sol()] #[sol()] => Ok(SolAttrs::default()),
Expand All @@ -206,17 +238,32 @@ mod tests {

#[sol(() = "")] => Err("unexpected token in nested attribute, expected ident"),
#[sol(? = "")] => Err("unexpected token in nested attribute, expected ident"),
#[sol(::a)] => Err("expected ident"),
#[sol(::a = "")] => Err("expected ident"),
#[sol(a::b = "")] => Err("expected ident"),
}

extra {
#[sol(all_derives)] => Ok(sol_attrs! { all_derives: true }),
#[sol(all_derives = true)] => Ok(sol_attrs! { all_derives: true }),
#[sol(all_derives = false)] => Ok(sol_attrs! { all_derives: false }),
#[sol(all_derives = "false")] => Err("expected boolean literal"),
#[sol(all_derives)] #[sol(all_derives)] => Err(DUPLICATE_ERROR),

#[sol(all_derives)] => Ok(sol_attrs! { all_derives: () }),
#[sol(all_derives)] #[sol(all_derives)] => Err("duplicate attribute"),
#[sol(extra_methods)] => Ok(sol_attrs! { extra_methods: true }),
#[sol(extra_methods = true)] => Ok(sol_attrs! { extra_methods: true }),
#[sol(extra_methods = false)] => Ok(sol_attrs! { extra_methods: false }),
}

rename {
#[sol(rename = "foo")] => Ok(sol_attrs! { rename: parse_quote!("foo") }),

#[sol(rename_all = "foo")] => Err("unsupported casing: foo"),
#[sol(rename_all = "camelcase")] => Ok(sol_attrs! { rename_all: CasingStyle::Camel }),
#[sol(rename_all = "camelCase")] #[sol(rename_all = "PascalCase")] => Err("duplicate attribute"),
#[sol(rename_all = "camelCase")] #[sol(rename_all = "PascalCase")] => Err(DUPLICATE_ERROR),
}

bytecode {
#[sol(deployed_bytecode = "0x1234")] => Ok(sol_attrs! { deployed_bytecode: parse_quote!("1234") }),
#[sol(bytecode = "0x1234")] => Ok(sol_attrs! { bytecode: parse_quote!("1234") }),
#[sol(bytecode = "1234")] => Ok(sol_attrs! { bytecode: parse_quote!("1234") }),
Expand Down
74 changes: 52 additions & 22 deletions crates/sol-macro/src/expand/contract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ pub(super) fn expand(cx: &ExpCtxt<'_>, contract: &ItemContract) -> Result<TokenS
} = contract;

let (sol_attrs, attrs) = crate::attr::SolAttrs::parse(attrs)?;
let extra_methods = sol_attrs
.extra_methods
.or(cx.attrs.extra_methods)
.unwrap_or(false);

let bytecode = sol_attrs.bytecode.map(|lit| {
let name = Ident::new("BYTECODE", lit.span());
Expand Down Expand Up @@ -66,21 +70,21 @@ pub(super) fn expand(cx: &ExpCtxt<'_>, contract: &ItemContract) -> Result<TokenS
let mut attrs = d_attrs.clone();
let doc_str = format!("Container for all the `{name}` function calls.");
attrs.push(parse_quote!(#[doc = #doc_str]));
CallLikeExpander::from_functions(cx, name, functions).expand(attrs)
CallLikeExpander::from_functions(cx, name, functions).expand(attrs, extra_methods)
});

let errors_enum = (errors.len() > 1).then(|| {
let mut attrs = d_attrs.clone();
let doc_str = format!("Container for all the `{name}` custom errors.");
attrs.push(parse_quote!(#[doc = #doc_str]));
CallLikeExpander::from_errors(cx, name, errors).expand(attrs)
CallLikeExpander::from_errors(cx, name, errors).expand(attrs, extra_methods)
});

let events_enum = (events.len() > 1).then(|| {
let mut attrs = d_attrs;
let doc_str = format!("Container for all the `{name}` events.");
attrs.push(parse_quote!(#[doc = #doc_str]));
CallLikeExpander::from_events(cx, name, events).expand_event(attrs)
CallLikeExpander::from_events(cx, name, events).expand_event(attrs, extra_methods)
});

let mod_attrs = attr::docs(&attrs);
Expand All @@ -92,6 +96,7 @@ pub(super) fn expand(cx: &ExpCtxt<'_>, contract: &ItemContract) -> Result<TokenS
#deployed_bytecode

#item_tokens

#functions_enum
#errors_enum
#events_enum
Expand All @@ -117,10 +122,21 @@ pub(super) fn expand(cx: &ExpCtxt<'_>, contract: &ItemContract) -> Result<TokenS
/// }
///
/// impl #name {
/// pub const SELECTORS: &'static [[u8; _]] = &[...];
/// }
///
/// #if extra_methods
/// #(
/// impl From<#types> for #name { ... }
/// impl TryFrom<#name> for #types { ... }
/// )*
///
/// impl #name {
/// #(
/// pub fn #is_variant,#as_variant,#as_variant_mut(...) -> ... { ... }
/// )*
/// }
/// #endif
/// ```
struct CallLikeExpander<'a> {
cx: &'a ExpCtxt<'a>,
Expand Down Expand Up @@ -219,7 +235,7 @@ impl<'a> CallLikeExpander<'a> {
}
}

fn expand(self, attrs: Vec<Attribute>) -> TokenStream {
fn expand(self, attrs: Vec<Attribute>, extra_methods: bool) -> TokenStream {
let Self {
name,
variants,
Expand All @@ -232,7 +248,7 @@ impl<'a> CallLikeExpander<'a> {
assert_eq!(variants.len(), types.len());
let name_s = name.to_string();
let count = variants.len();
let def = self.generate_enum(attrs);
let def = self.generate_enum(attrs, extra_methods);
quote! {
#def

Expand Down Expand Up @@ -302,12 +318,12 @@ impl<'a> CallLikeExpander<'a> {
}
}

fn expand_event(self, attrs: Vec<Attribute>) -> TokenStream {
fn expand_event(self, attrs: Vec<Attribute>, extra_methods: bool) -> TokenStream {
// TODO: SolInterface for events
self.generate_enum(attrs)
self.generate_enum(attrs, extra_methods)
}

fn generate_enum(&self, mut attrs: Vec<Attribute>) -> TokenStream {
fn generate_enum(&self, mut attrs: Vec<Attribute>, extra_methods: bool) -> TokenStream {
let Self {
name,
variants,
Expand All @@ -330,31 +346,40 @@ impl<'a> CallLikeExpander<'a> {
types.iter().cloned().map(ast::Type::custom),
false,
);

let conversions = variants
.iter()
.zip(types)
.map(|(v, t)| generate_variant_conversions(name, v, t));
let methods = variants.iter().zip(types).map(generate_variant_methods);

quote! {
let tokens = quote! {
#(#attrs)*
pub enum #name {
#(#variants(#types),)*
}

#(#conversions)*

#[automatically_derived]
impl #name {
/// All the selectors of this enum.
///
/// Note that the selectors might not be in the same order as the
/// variants, as they are sorted instead of ordered by definition.
pub const SELECTORS: &'static [#selector_type] = &[#selectors];
}
};

#(#methods)*
if extra_methods {
let conversions = variants
.iter()
.zip(types)
.map(|(v, t)| generate_variant_conversions(name, v, t));
let methods = variants.iter().zip(types).map(generate_variant_methods);
quote! {
#tokens

#(#conversions)*

#[automatically_derived]
impl #name {
#(#methods)*
}
}
} else {
tokens
}
}
}
Expand Down Expand Up @@ -430,12 +455,17 @@ fn generate_variant_methods((variant, ty): (&Ident, &Ident)) -> TokenStream {

/// `heck` doesn't treat numbers as new words, and discards leading underscores.
fn snakify(s: &str) -> String {
let leading = s.chars().take_while(|c| *c == '_');
let mut output: Vec<char> = leading.chain(s.to_snake_case().chars()).collect();
let leading_n = s.chars().take_while(|c| *c == '_').count();
let (leading, s) = s.split_at(leading_n);
let mut output: Vec<char> = leading.chars().chain(s.to_snake_case().chars()).collect();

let mut num_starts = vec![];
for (pos, c) in output.iter().enumerate() {
if pos != 0 && c.is_ascii_digit() && !output[pos - 1].is_ascii_digit() {
if pos != 0
&& c.is_ascii_digit()
&& !output[pos - 1].is_ascii_digit()
&& !output[pos - 1].is_ascii_punctuation()
{
num_starts.push(pos);
}
}
Expand Down
29 changes: 9 additions & 20 deletions crates/sol-macro/src/expand/function.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
//! [`ItemFunction`] expansion.

use super::{
expand_fields, expand_from_into_tuples, expand_from_into_unit, expand_tuple_types,
ty::expand_tokenize_func, ExpCtxt,
expand_fields, expand_from_into_tuples, expand_tuple_types, ty::expand_tokenize_func, ExpCtxt,
};
use ast::ItemFunction;
use proc_macro2::TokenStream;
Expand Down Expand Up @@ -37,41 +36,31 @@ pub(super) fn expand(cx: &ExpCtxt<'_>, function: &ItemFunction) -> Result<TokenS
// ignore functions without names (constructors, modifiers...)
return Ok(quote!())
};
let returns = returns.as_ref().map(|r| &r.returns).unwrap_or_default();

cx.assert_resolved(arguments)?;
if let Some(returns) = returns {
cx.assert_resolved(&returns.returns)?;
if !returns.is_empty() {
cx.assert_resolved(returns)?;
}

let (_sol_attrs, mut call_attrs) = crate::attr::SolAttrs::parse(attrs)?;
let mut return_attrs = call_attrs.clone();
cx.derives(&mut call_attrs, arguments, true);
if let Some(returns) = returns {
cx.derives(&mut return_attrs, &returns.returns, true);
if !returns.is_empty() {
cx.derives(&mut return_attrs, returns, true);
}

let call_name = cx.call_name(function);
let return_name = cx.return_name(function);

let call_fields = expand_fields(arguments);
let return_fields = if let Some(returns) = returns {
expand_fields(&returns.returns).collect::<Vec<_>>()
} else {
vec![]
};
let return_fields = expand_fields(returns);

let call_tuple = expand_tuple_types(arguments.types()).0;
let return_tuple = if let Some(returns) = returns {
expand_tuple_types(returns.returns.types()).0
} else {
quote! { () }
};
let return_tuple = expand_tuple_types(returns.types()).0;

let converts = expand_from_into_tuples(&call_name, arguments);
let return_converts = returns
.as_ref()
.map(|returns| expand_from_into_tuples(&return_name, &returns.returns))
.unwrap_or_else(|| expand_from_into_unit(&return_name));
let return_converts = expand_from_into_tuples(&return_name, returns);

let signature = cx.function_signature(function);
let selector = crate::utils::selector(&signature);
Expand Down
Loading