Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(sol-macro): add support for overloaded events #318

Merged
merged 1 commit into from
Sep 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions crates/sol-macro/src/expand/contract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ impl<'a> CallLikeExpander<'a> {
) -> Self {
let variants: Vec<_> = functions
.iter()
.map(|f| cx.function_name_ident(f).0)
.map(|&f| cx.overloaded_name(f.into()).0)
.collect();

let types: Vec<_> = variants.iter().map(|name| cx.raw_call_name(name)).collect();
Expand Down Expand Up @@ -211,13 +211,18 @@ impl<'a> CallLikeExpander<'a> {
}

fn from_events(cx: &'a ExpCtxt<'a>, contract_name: &SolIdent, events: Vec<&ItemEvent>) -> Self {
let variants: Vec<_> = events
.iter()
.map(|&event| cx.overloaded_name(event.into()).0)
.collect();

let mut selectors: Vec<_> = events.iter().map(|e| cx.event_selector(e)).collect();
selectors.sort_unstable_by_key(|a| a.array);

Self {
cx,
name: format_ident!("{contract_name}Events"),
variants: events.iter().map(|event| event.name.0.clone()).collect(),
variants,
min_data_len: events
.iter()
.map(|event| ty::params_base_data_size(cx, &event.params()))
Expand All @@ -228,8 +233,8 @@ impl<'a> CallLikeExpander<'a> {
}
}

/// Type name overrides. Currently only functions support this through
/// overloading.
/// Type name overrides. Currently only functions support because of the
/// `Call` suffix.
fn types(&self) -> &[Ident] {
match &self.data {
CallLikeExpanderData::Function { types, .. } => types,
Expand Down
3 changes: 2 additions & 1 deletion crates/sol-macro/src/expand/event.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use syn::Result;
/// }
/// ```
pub(super) fn expand(cx: &ExpCtxt<'_>, event: &ItemEvent) -> Result<TokenStream> {
let ItemEvent { name, attrs, .. } = event;
let ItemEvent { attrs, .. } = event;
let params = event.params();

let (_sol_attrs, mut attrs) = crate::attr::SolAttrs::parse(attrs)?;
Expand All @@ -27,6 +27,7 @@ pub(super) fn expand(cx: &ExpCtxt<'_>, event: &ItemEvent) -> Result<TokenStream>
cx.assert_resolved(&params)?;
event.assert_valid()?;

let name = cx.overloaded_name(event.into());
let signature = cx.signature(name.as_string(), &params);
let selector = crate::utils::event_selector(&signature);
let anonymous = event.is_anonymous();
Expand Down
189 changes: 131 additions & 58 deletions crates/sol-macro/src/expand/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ struct ExpCtxt<'ast> {
all_items: Vec<&'ast Item>,
custom_types: HashMap<SolIdent, Type>,

/// `name => functions`
functions: HashMap<String, Vec<&'ast ItemFunction>>,
/// `function_signature => new_name`
function_overloads: HashMap<String, String>,
/// `name => item`
overloaded_items: HashMap<String, Vec<OverloadedItem<'ast>>>,
/// `signature => new_name`
overloads: HashMap<String, String>,

attrs: SolAttrs,
ast: &'ast File,
Expand All @@ -53,8 +53,8 @@ impl<'ast> ExpCtxt<'ast> {
Self {
all_items: Vec::new(),
custom_types: HashMap::new(),
functions: HashMap::new(),
function_overloads: HashMap::new(),
overloaded_items: HashMap::new(),
overloads: HashMap::new(),
attrs: SolAttrs::default(),
ast,
}
Expand Down Expand Up @@ -102,7 +102,7 @@ impl<'ast> ExpCtxt<'ast> {
}

// resolve
impl ExpCtxt<'_> {
impl<'ast> ExpCtxt<'ast> {
fn parse_file_attributes(&mut self) -> Result<()> {
let (attrs, others) = attr::SolAttrs::parse(&self.ast.attrs)?;
self.attrs = attrs;
Expand Down Expand Up @@ -165,23 +165,26 @@ impl ExpCtxt<'_> {
}

fn mk_overloads_map(&mut self) -> Result<()> {
let all_orig_names: Vec<SolIdent> = self
.functions
let all_orig_names: Vec<_> = self
.overloaded_items
.values()
.flatten()
.filter_map(|f| f.name.clone())
.filter_map(|f| f.name())
.collect();
let mut overloads_map = std::mem::take(&mut self.function_overloads);
let mut overloads_map = std::mem::take(&mut self.overloads);

// report all errors at the end
let mut errors = Vec::new();

for functions in self.functions.values().filter(|fs| fs.len() >= 2) {
for functions in self.overloaded_items.values().filter(|fs| fs.len() >= 2) {
// check for same parameters
for (i, &a) in functions.iter().enumerate() {
for &b in functions.iter().skip(i + 1) {
if a.arguments.types().eq(b.arguments.types()) {
let msg = "function with same name and parameter types defined twice";
if a.eq_by_types(b) {
let msg = format!(
"{} with same name and parameter types defined twice",
a.desc()
);
let mut err = syn::Error::new(a.span(), msg);

let msg = "other declaration is here";
Expand All @@ -193,15 +196,16 @@ impl ExpCtxt<'_> {
}
}

for (i, &function) in functions.iter().enumerate() {
let Some(old_name) = function.name.as_ref() else {
for (i, &item) in functions.iter().enumerate() {
let Some(old_name) = item.name() else {
continue
};
let new_name = format!("{old_name}_{i}");
if let Some(other) = all_orig_names.iter().find(|x| x.0 == new_name) {
let msg = format!(
"function `{old_name}` is overloaded, \
but the generated name `{new_name}` is already in use"
"{} `{old_name}` is overloaded, \
but the generated name `{new_name}` is already in use",
item.desc()
);
let mut err = syn::Error::new(old_name.span(), msg);

Expand All @@ -212,12 +216,12 @@ impl ExpCtxt<'_> {
errors.push(err);
}

overloads_map.insert(self.function_signature(function), new_name);
overloads_map.insert(item.signature(self), new_name);
}
}

utils::combine_errors(errors).map(|()| {
self.function_overloads = overloads_map;
self.overloads = overloads_map;
})
}
}
Expand All @@ -230,17 +234,81 @@ impl<'ast> Visit<'ast> for ExpCtxt<'ast> {

fn visit_item_function(&mut self, function: &'ast ItemFunction) {
if let Some(name) = &function.name {
self.functions
self.overloaded_items
.entry(name.as_string())
.or_default()
.push(function);
.push(OverloadedItem::Function(function));
}
ast::visit::visit_item_function(self, function);
}

fn visit_item_event(&mut self, event: &'ast ItemEvent) {
self.overloaded_items
.entry(event.name.as_string())
.or_default()
.push(OverloadedItem::Event(event));
ast::visit::visit_item_event(self, event);
}
}

#[derive(Clone, Copy)]
enum OverloadedItem<'a> {
Function(&'a ItemFunction),
Event(&'a ItemEvent),
}

impl<'ast> From<&'ast ItemFunction> for OverloadedItem<'ast> {
fn from(f: &'ast ItemFunction) -> Self {
Self::Function(f)
}
}

impl<'ast> From<&'ast ItemEvent> for OverloadedItem<'ast> {
fn from(e: &'ast ItemEvent) -> Self {
Self::Event(e)
}
}

impl<'a> OverloadedItem<'a> {
fn name(self) -> Option<&'a SolIdent> {
match self {
Self::Function(f) => f.name.as_ref(),
Self::Event(e) => Some(&e.name),
}
}

fn desc(&self) -> &'static str {
match self {
Self::Function(_) => "function",
Self::Event(_) => "event",
}
}

fn eq_by_types(self, other: Self) -> bool {
match (self, other) {
(Self::Function(a), Self::Function(b)) => a.arguments.types().eq(b.arguments.types()),
(Self::Event(a), Self::Event(b)) => a.param_types().eq(b.param_types()),
_ => false,
}
}

fn span(self) -> Span {
match self {
Self::Function(f) => f.span(),
Self::Event(e) => e.span(),
}
}

fn signature(self, cx: &ExpCtxt<'a>) -> String {
match self {
Self::Function(f) => cx.function_signature(f),
Self::Event(e) => cx.event_signature(e),
}
}
}

// utils
impl ExpCtxt<'_> {
impl<'ast> ExpCtxt<'ast> {
#[allow(dead_code)]
fn get_item(&self, name: &SolPath) -> &Item {
match self.try_get_item(name) {
Expand All @@ -266,59 +334,44 @@ impl ExpCtxt<'_> {

/// Returns the name of the function, adjusted for overloads.
fn function_name(&self, function: &ItemFunction) -> String {
let sig = self.function_signature(function);
match self.function_overloads.get(&sig) {
Some(name) => name.clone(),
None => function.name().as_string(),
}
self.overloaded_name(function.into()).as_string()
}

/// Returns the name of the function, adjusted for overloads.
fn function_name_ident(&self, function: &ItemFunction) -> SolIdent {
let sig = self.function_signature(function);
match self.function_overloads.get(&sig) {
Some(name) => SolIdent::new_spanned(name, function.name().span()),
None => function.name().clone(),
/// Returns the name of the given item, adjusted for overloads.
///
/// Use `.into()` to convert from `&ItemFunction` or `&ItemEvent`.
fn overloaded_name(&self, item: OverloadedItem<'ast>) -> SolIdent {
let original_ident = item.name().expect("item has no name");
let sig = item.signature(self);
match self.overloads.get(&sig) {
Some(name) => SolIdent::new_spanned(name, original_ident.span()),
None => original_ident.clone(),
}
}

fn raw_call_name(&self, function_name: impl quote::IdentFragment + std::fmt::Display) -> Ident {
format_ident!("{function_name}Call")
}

/// Returns the name of the function's call Rust struct.
fn call_name(&self, function: &ItemFunction) -> Ident {
let function_name = self.function_name(function);
self.raw_call_name(function_name)
}

fn raw_return_name(
&self,
function_name: impl quote::IdentFragment + std::fmt::Display,
) -> Ident {
format_ident!("{function_name}Return")
/// Formats the given name as a function's call Rust struct name.
fn raw_call_name(&self, function_name: impl quote::IdentFragment + std::fmt::Display) -> Ident {
format_ident!("{function_name}Call")
}

/// Returns the name of the function's return Rust struct.
fn return_name(&self, function: &ItemFunction) -> Ident {
let function_name = self.function_name(function);
self.raw_return_name(function_name)
}

fn signature<'a, I: IntoIterator<Item = &'a VariableDeclaration>>(
/// Formats the given name as a function's return Rust struct name.
fn raw_return_name(
&self,
mut name: String,
params: I,
) -> String {
name.push('(');
let mut first = true;
for param in params {
if !first {
name.push(',');
}
write!(name, "{}", ty::TypePrinter::new(self, &param.ty)).unwrap();
first = false;
}
name.push(')');
name
function_name: impl quote::IdentFragment + std::fmt::Display,
) -> Ident {
format_ident!("{function_name}Return")
}

fn function_signature(&self, function: &ItemFunction) -> String {
Expand Down Expand Up @@ -347,6 +400,25 @@ impl ExpCtxt<'_> {
utils::event_selector(self.event_signature(event))
}

/// Formats the name and parameters of the function as a Solidity signature.
fn signature<'a, I: IntoIterator<Item = &'a VariableDeclaration>>(
&self,
mut name: String,
params: I,
) -> String {
name.push('(');
let mut first = true;
for param in params {
if !first {
name.push(',');
}
write!(name, "{}", ty::TypePrinter::new(self, &param.ty)).unwrap();
first = false;
}
name.push(')');
name
}

/// Extends `attrs` with all possible derive attributes for the given type
/// if `#[sol(all_derives)]` was passed.
///
Expand Down Expand Up @@ -374,6 +446,7 @@ impl ExpCtxt<'_> {
self.type_derives(attrs, params.into_iter().map(|p| &p.ty), derive_default);
}

/// Implementation of [`derives`](Self::derives).
fn type_derives<T, I>(&self, attrs: &mut Vec<Attribute>, types: I, mut derive_default: bool)
where
I: IntoIterator<Item = T>,
Expand Down
33 changes: 0 additions & 33 deletions crates/sol-types/tests/ui/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,37 +72,4 @@ sol! {
function n() public pure returns (uint256,);
}

// OK
sol! {
function overloaded();
function overloaded(uint256);
function overloaded(uint256, address);
function overloaded(address);
function overloaded(address, string);
}

sol! {
function overloadTaken();
function overloadTaken(uint256);

function overloadTaken_0();
function overloadTaken_1();
function overloadTaken_2();
}

sol! {
function sameFnOverload();
function sameFnOverload();
}

sol! {
function sameFnTysOverload1(uint256[] memory a);
function sameFnTysOverload1(uint256[] storage b);
}

sol! {
function sameFnTysOverload2(string memory, string storage);
function sameFnTysOverload2(string storage b, string calldata);
}

fn main() {}
Loading