Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions crates/sol-macro-expander/src/expand/contract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,7 @@ pub(super) fn expand(cx: &mut ExpCtxt<'_>, contract: &ItemContract) -> Result<To
item_tokens.extend(cx.expand_item(&item)?);
}
}
cx.attrs = prev_cx_attrs;

let enum_expander = CallLikeExpander { cx, contract_name: name.clone(), extra_methods };
// Remove any `Default` derives.
let mut enum_attrs = item_attrs;
for attr in &mut enum_attrs {
Expand All @@ -166,14 +164,7 @@ pub(super) fn expand(cx: &mut ExpCtxt<'_>, contract: &ItemContract) -> Result<To
attr.meta = parse_quote! { derive(#(#derives),*) };
}

let functions_enum = (!functions.is_empty()).then(|| {
let mut attrs = enum_attrs.clone();
let doc_str = format!("Container for all the [`{name}`](self) function calls.");
attrs.push(parse_quote!(#[doc = #doc_str]));
attrs.push(parse_quote!(#[derive(Clone)]));
enum_expander.expand(ToExpand::Functions(&functions), attrs)
});

let enum_expander = CallLikeExpander { cx, contract_name: name.clone(), extra_methods };
let errors_enum = (!errors.is_empty()).then(|| {
let mut attrs = enum_attrs.clone();
let doc_str = format!("Container for all the [`{name}`](self) custom errors.");
Expand All @@ -183,13 +174,25 @@ pub(super) fn expand(cx: &mut ExpCtxt<'_>, contract: &ItemContract) -> Result<To
});

let events_enum = (!events.is_empty()).then(|| {
let mut attrs = enum_attrs;
let mut attrs = enum_attrs.clone();
let doc_str = format!("Container for all the [`{name}`](self) events.");
attrs.push(parse_quote!(#[doc = #doc_str]));
attrs.push(parse_quote!(#[derive(Clone)]));
enum_expander.expand(ToExpand::Events(&events), attrs)
});

// Do not propagate contract-level derives to the functions enum.
cx.attrs = prev_cx_attrs;

let functions_enum = (!functions.is_empty()).then(|| {
let mut attrs = enum_attrs;
let doc_str = format!("Container for all the [`{name}`](self) function calls.");
attrs.push(parse_quote!(#[doc = #doc_str]));
attrs.push(parse_quote!(#[derive(Clone)]));
let enum_expander = CallLikeExpander { cx, contract_name: name.clone(), extra_methods };
enum_expander.expand(ToExpand::Functions(&functions), attrs)
});
Comment on lines +187 to +194
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, this just moves the location and add thes derives

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it uses the flushed cx.


let mod_descr_doc = (docs && docs_str(&mod_attrs).trim().is_empty())
.then(|| mk_doc("Module containing a contract's types and functions."));
let mod_iface_doc = (docs && !docs_str(&mod_attrs).contains("```solidity\n"))
Expand Down
88 changes: 88 additions & 0 deletions crates/sol-types/tests/derives.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
//! Comprehensive test for contract-level derives applying to all generated types.
//! Tests all_derives and extra_derives on contracts with events and errors.

use alloy_primitives::{Address, U256};
use alloy_sol_types::sol;
use std::{collections::HashSet, hash::Hash};

#[test]
fn test_all_derives() {
sol! {
#[sol(all_derives)]
contract AllDerivesContract {
function transfer(address to, uint256 amount) external returns (bool);
event Transfer(address indexed from, address indexed to, uint256 value);
error InsufficientBalance(uint256 requested, uint256 available);
}
}

use AllDerivesContract::*;

let event1 = Transfer { from: Address::ZERO, to: Address::ZERO, value: U256::from(50) };
let event2 = Transfer { from: Address::ZERO, to: Address::ZERO, value: U256::from(50) };
let events_enum1 = AllDerivesContractEvents::Transfer(event1);
let events_enum2 = AllDerivesContractEvents::Transfer(event2);

let error1 = InsufficientBalance { requested: U256::from(100), available: U256::from(50) };
let error2 = InsufficientBalance { requested: U256::from(100), available: U256::from(50) };
let errors_enum1 = AllDerivesContractErrors::InsufficientBalance(error1);
let errors_enum2 = AllDerivesContractErrors::InsufficientBalance(error2);

// Test PartialEq and Debug
assert_eq!(errors_enum1, errors_enum2);
assert_eq!(events_enum1, events_enum2);

// Test Hash and Eq derives
let mut events_set = HashSet::new();
events_set.insert(events_enum1);
events_set.insert(events_enum2);
// Should not increase size since they're equal
assert_eq!(events_set.len(), 1);

let mut errors_set = HashSet::new();
errors_set.insert(errors_enum1);
errors_set.insert(errors_enum2);
// Should not increase size since they're equal
assert_eq!(errors_set.len(), 1);
}

#[test]
fn test_extra_derives() {
sol! {
#[sol(extra_derives(PartialEq, Eq, Hash, Debug))]
contract ExtraDerivesContract {
function transfer(address to, uint256 amount) external returns (bool);
event Transfer(address indexed from, address indexed to, uint256 value);
error InsufficientBalance(uint256 requested, uint256 available);
}
}

use ExtraDerivesContract::*;

let event1 = Transfer { from: Address::ZERO, to: Address::ZERO, value: U256::from(50) };
let event2 = Transfer { from: Address::ZERO, to: Address::ZERO, value: U256::from(50) };
let events_enum1 = ExtraDerivesContractEvents::Transfer(event1);
let events_enum2 = ExtraDerivesContractEvents::Transfer(event2);

let error1 = InsufficientBalance { requested: U256::from(100), available: U256::from(50) };
let error2 = InsufficientBalance { requested: U256::from(100), available: U256::from(50) };
let errors_enum1 = ExtraDerivesContractErrors::InsufficientBalance(error1);
let errors_enum2 = ExtraDerivesContractErrors::InsufficientBalance(error2);

// Test PartialEq and Debug
assert_eq!(errors_enum1, errors_enum2);
assert_eq!(events_enum1, events_enum2);

// Test Hash and Eq derives
let mut events_set = HashSet::new();
events_set.insert(events_enum1);
events_set.insert(events_enum2);
// Should not increase size since they're equal
assert_eq!(events_set.len(), 1);

let mut errors_set = HashSet::new();
errors_set.insert(errors_enum1);
errors_set.insert(errors_enum2);
// Should not increase size since they're equal
assert_eq!(errors_set.len(), 1);
}
Loading