Skip to content

Commit 72d0c81

Browse files
authored
feat: sol! contracts (#77)
* feat: `sol!` contracts * chore: clippy * docs: add note about parser design, leniency * test: bless tests after updating to syn 2.0.19 * typo * test: bless new test
1 parent 0a00f01 commit 72d0c81

File tree

19 files changed

+829
-130
lines changed

19 files changed

+829
-130
lines changed

crates/sol-macro/src/expand/attr.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
use syn::Attribute;
2+
3+
pub(crate) fn docs(attrs: &[Attribute]) -> impl Iterator<Item = &Attribute> {
4+
attrs.iter().filter(|attr| attr.path().is_ident("doc"))
5+
}
6+
7+
pub(crate) fn derives(attrs: &[Attribute]) -> impl Iterator<Item = &Attribute> {
8+
attrs.iter().filter(|attr| attr.path().is_ident("derive"))
9+
}

crates/sol-macro/src/expand/mod.rs

Lines changed: 183 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
//! Functions which generate Rust code from the Solidity AST.
22
33
use ast::{
4-
File, Item, ItemError, ItemFunction, ItemStruct, ItemUdt, Parameters, SolIdent, Type,
5-
VariableDeclaration, Visit,
4+
File, Item, ItemContract, ItemError, ItemFunction, ItemStruct, ItemUdt, Parameters, SolIdent,
5+
Type, VariableDeclaration, Visit,
66
};
77
use proc_macro2::{Ident, Span, TokenStream};
88
use quote::{format_ident, quote, IdentFragment};
99
use std::{collections::HashMap, fmt::Write};
10-
use syn::{Error, Result, Token};
10+
use syn::{parse_quote, Attribute, Error, Result, Token};
11+
12+
mod attr;
1113

1214
mod r#type;
1315
pub use r#type::expand_type;
@@ -76,13 +78,180 @@ impl<'ast> ExpCtxt<'ast> {
7678

7779
fn expand_item(&self, item: &Item) -> Result<TokenStream> {
7880
match item {
81+
Item::Contract(contract) => self.expand_contract(contract),
7982
Item::Error(error) => self.expand_error(error),
8083
Item::Function(function) => self.expand_function(function),
8184
Item::Struct(s) => self.expand_struct(s),
8285
Item::Udt(udt) => self.expand_udt(udt),
8386
}
8487
}
8588

89+
fn expand_contract(&self, contract: &ItemContract) -> Result<TokenStream> {
90+
let ItemContract {
91+
attrs, name, body, ..
92+
} = contract;
93+
94+
let mut functions = Vec::with_capacity(contract.body.len());
95+
let mut errors = Vec::with_capacity(contract.body.len());
96+
let mut item_tokens = TokenStream::new();
97+
let d_attrs: Vec<Attribute> = attr::derives(attrs).cloned().collect();
98+
for item in body {
99+
match item {
100+
Item::Function(function) => functions.push(function),
101+
Item::Error(error) => errors.push(error),
102+
_ => {}
103+
}
104+
item_tokens.extend(quote!(#(#d_attrs)*));
105+
item_tokens.extend(self.expand_item(item)?);
106+
}
107+
108+
let functions_enum = if functions.len() > 1 {
109+
let mut attrs = d_attrs.clone();
110+
let doc_str = format!("Container for all the [`{name}`] function calls.");
111+
attrs.push(parse_quote!(#[doc = #doc_str]));
112+
Some(self.expand_functions_enum(name, functions, &attrs))
113+
} else {
114+
None
115+
};
116+
117+
let errors_enum = if errors.len() > 1 {
118+
let mut attrs = d_attrs;
119+
let doc_str = format!("Container for all the [`{name}`] custom errors.");
120+
attrs.push(parse_quote!(#[doc = #doc_str]));
121+
Some(self.expand_errors_enum(name, errors, &attrs))
122+
} else {
123+
None
124+
};
125+
126+
let mod_attrs = attr::docs(attrs);
127+
let tokens = quote! {
128+
#(#mod_attrs)*
129+
#[allow(non_camel_case_types, non_snake_case, clippy::style)]
130+
pub mod #name {
131+
#item_tokens
132+
#functions_enum
133+
#errors_enum
134+
}
135+
};
136+
Ok(tokens)
137+
}
138+
139+
fn expand_functions_enum(
140+
&self,
141+
name: &SolIdent,
142+
functions: Vec<&ItemFunction>,
143+
attrs: &[Attribute],
144+
) -> TokenStream {
145+
let name = format_ident!("{name}Calls");
146+
let variants: Vec<_> = functions
147+
.iter()
148+
.map(|f| self.function_name_ident(f).0)
149+
.collect();
150+
let types: Vec<_> = variants.iter().map(|name| self.call_name(name)).collect();
151+
let min_data_len = functions
152+
.iter()
153+
.map(|function| self.min_data_size(&function.arguments))
154+
.max()
155+
.unwrap();
156+
let trt = Ident::new("SolCall", Span::call_site());
157+
self.expand_call_like_enum(name, &variants, &types, min_data_len, trt, attrs)
158+
}
159+
160+
fn expand_errors_enum(
161+
&self,
162+
name: &SolIdent,
163+
errors: Vec<&ItemError>,
164+
attrs: &[Attribute],
165+
) -> TokenStream {
166+
let name = format_ident!("{name}Errors");
167+
let variants: Vec<_> = errors.iter().map(|error| error.name.0.clone()).collect();
168+
let min_data_len = errors
169+
.iter()
170+
.map(|error| self.min_data_size(&error.fields))
171+
.max()
172+
.unwrap();
173+
let trt = Ident::new("SolError", Span::call_site());
174+
self.expand_call_like_enum(name, &variants, &variants, min_data_len, trt, attrs)
175+
}
176+
177+
fn expand_call_like_enum(
178+
&self,
179+
name: Ident,
180+
variants: &[Ident],
181+
types: &[Ident],
182+
min_data_len: usize,
183+
trt: Ident,
184+
attrs: &[Attribute],
185+
) -> TokenStream {
186+
assert_eq!(variants.len(), types.len());
187+
let name_s = name.to_string();
188+
let count = variants.len();
189+
let min_data_len = min_data_len.min(4);
190+
quote! {
191+
#(#attrs)*
192+
pub enum #name {#(
193+
#variants(#types),
194+
)*}
195+
196+
// TODO: Implement these functions using traits?
197+
#[automatically_derived]
198+
impl #name {
199+
/// The number of variants.
200+
pub const COUNT: usize = #count;
201+
202+
// no decode_raw is possible because we need the selector to know which variant to
203+
// decode into
204+
205+
/// ABI-decodes the given data into one of the variants of `self`.
206+
pub fn decode(data: &[u8], validate: bool) -> ::alloy_sol_types::Result<Self> {
207+
if data.len() >= #min_data_len {
208+
// TODO: Replace with `data.split_array_ref` once it's stable
209+
let (selector, data) = data.split_at(4);
210+
let selector: &[u8; 4] =
211+
::core::convert::TryInto::try_into(selector).expect("unreachable");
212+
match *selector {
213+
#(<#types as ::alloy_sol_types::#trt>::SELECTOR => {
214+
return <#types as ::alloy_sol_types::#trt>::decode_raw(data, validate)
215+
.map(Self::#variants)
216+
})*
217+
_ => {}
218+
}
219+
}
220+
::core::result::Result::Err(::alloy_sol_types::Error::type_check_fail(
221+
data,
222+
#name_s,
223+
))
224+
}
225+
226+
/// ABI-encodes `self` into the given buffer.
227+
pub fn encode_raw(&self, out: &mut Vec<u8>) {
228+
match self {#(
229+
Self::#variants(inner) =>
230+
<#types as ::alloy_sol_types::#trt>::encode_raw(inner, out),
231+
)*}
232+
}
233+
234+
/// ABI-encodes `self` into the given buffer.
235+
#[inline]
236+
pub fn encode(&self) -> Vec<u8> {
237+
match self {#(
238+
Self::#variants(inner) =>
239+
<#types as ::alloy_sol_types::#trt>::encode(inner),
240+
)*}
241+
}
242+
}
243+
244+
#(
245+
#[automatically_derived]
246+
impl From<#types> for #name {
247+
fn from(value: #types) -> Self {
248+
Self::#variants(value)
249+
}
250+
}
251+
)*
252+
}
253+
}
254+
86255
fn expand_error(&self, error: &ItemError) -> Result<TokenStream> {
87256
let ItemError {
88257
fields,
@@ -344,7 +513,7 @@ impl<'ast> ExpCtxt<'ast> {
344513
ty.visit_mut(|ty| {
345514
let ty @ Type::Custom(_) = ty else { return };
346515
let Type::Custom(name) = &*ty else { unreachable!() };
347-
let Some(resolved) = self.custom_types.get(name) else { return };
516+
let Some(resolved) = self.custom_types.get(name.last_tmp()) else { return };
348517
ty.clone_from(resolved);
349518
any = true;
350519
});
@@ -431,6 +600,15 @@ impl<'ast> ExpCtxt<'ast> {
431600
}
432601
}
433602

603+
/// Returns the name of the function, adjusted for overloads.
604+
fn function_name_ident(&self, function: &ItemFunction) -> SolIdent {
605+
let sig = self.function_signature(function);
606+
match self.function_overloads.get(&sig) {
607+
Some(name) => SolIdent::new_spanned(name, function.name.span()),
608+
None => function.name.clone(),
609+
}
610+
}
611+
434612
fn call_name(&self, function_name: impl IdentFragment + std::fmt::Display) -> Ident {
435613
format_ident!("{function_name}Call")
436614
}
@@ -464,7 +642,7 @@ impl<'ast> ExpCtxt<'ast> {
464642
let mut errors = Vec::new();
465643
params.visit_types(|ty| {
466644
if let Type::Custom(name) = ty {
467-
if !self.custom_types.contains_key(name) {
645+
if !self.custom_types.contains_key(name.last_tmp()) {
468646
let e = syn::Error::new(name.span(), "unresolved type");
469647
errors.push(e);
470648
}

crates/sol-macro/src/expand/type.rs

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use super::ExpCtxt;
2-
use ast::{Item, SolArray, SolIdent, Type};
2+
use ast::{Item, Parameters, SolArray, SolPath, Type};
33
use proc_macro2::{Literal, TokenStream};
44
use quote::{quote, quote_spanned, ToTokens};
55
use std::{fmt, num::NonZeroU16};
@@ -74,20 +74,29 @@ fn rec_expand_type(ty: &Type, tokens: &mut TokenStream) {
7474
}
7575

7676
impl ExpCtxt<'_> {
77-
fn get_item(&self, name: &SolIdent) -> &Item {
77+
fn get_item(&self, name: &SolPath) -> &Item {
78+
let name = name.last_tmp();
7879
match self.all_items.iter().find(|item| item.name() == name) {
7980
Some(item) => item,
8081
None => panic!("unresolved item: {name}"),
8182
}
8283
}
8384

84-
fn custom_type(&self, name: &SolIdent) -> &Type {
85-
match self.custom_types.get(name) {
85+
fn custom_type(&self, name: &SolPath) -> &Type {
86+
match self.custom_types.get(name.last_tmp()) {
8687
Some(item) => item,
8788
None => panic!("unresolved item: {name}"),
8889
}
8990
}
9091

92+
pub(super) fn min_data_size<P>(&self, params: &Parameters<P>) -> usize {
93+
params
94+
.iter()
95+
.map(|param| self.type_base_data_size(&param.ty))
96+
.max()
97+
.unwrap_or(0)
98+
}
99+
91100
/// Recursively calculates the base ABI-encoded size of `self` in bytes.
92101
///
93102
/// That is, the minimum number of bytes required to encode `self` without
@@ -127,7 +136,7 @@ impl ExpCtxt<'_> {
127136
.map(|ty| self.type_base_data_size(ty))
128137
.sum(),
129138
Item::Udt(udt) => self.type_base_data_size(&udt.ty),
130-
Item::Error(_) | Item::Function(_) => unreachable!(),
139+
Item::Contract(_) | Item::Error(_) | Item::Function(_) => unreachable!(),
131140
},
132141
}
133142
}
@@ -182,7 +191,7 @@ impl ExpCtxt<'_> {
182191
Type::Custom(name) => match self.get_item(name) {
183192
Item::Struct(strukt) => self.params_data_size(&strukt.fields, Some(field)),
184193
Item::Udt(udt) => self.type_data_size(&udt.ty, field),
185-
Item::Error(_) | Item::Function(_) => unreachable!(),
194+
Item::Contract(_) | Item::Error(_) | Item::Function(_) => unreachable!(),
186195
},
187196
}
188197
}

crates/sol-types/src/types/call.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,7 @@ pub trait SolCall: Sized {
3535
/// selector.
3636
#[inline]
3737
fn decode_raw(data: &[u8], validate: bool) -> Result<Self> {
38-
let tuple = <Self::Tuple as SolType>::decode(data, validate)?;
39-
Ok(Self::from_rust(tuple))
38+
<Self::Tuple as SolType>::decode(data, validate).map(Self::from_rust)
4039
}
4140

4241
/// ABI decode this call's arguments from the given slice, **with** the
@@ -52,6 +51,7 @@ pub trait SolCall: Sized {
5251
/// ABI encode the call to the given buffer **without** its selector.
5352
#[inline]
5453
fn encode_raw(&self, out: &mut Vec<u8>) {
54+
out.reserve(self.data_size());
5555
out.extend(<Self::Tuple as SolType>::encode(self.to_rust()));
5656
}
5757

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,42 @@
1+
use alloy_primitives::{Address, U256};
2+
use alloy_sol_types::{sol, SolCall};
3+
use hex_literal::hex;
4+
5+
sol! {
6+
/// Interface of the ERC20 standard as defined in [the EIP].
7+
///
8+
/// [the EIP]: https://eips.ethereum.org/EIPS/eip-20
9+
#[derive(Debug, PartialEq)]
10+
interface IERC20 {
11+
// TODO: Events
12+
// event Transfer(address indexed from, address indexed to, uint256 value);
13+
// event Approval(address indexed owner, address indexed spender, uint256 value);
14+
15+
function totalSupply() external view returns (uint256);
16+
function balanceOf(address account) external view returns (uint256);
17+
function transfer(address to, uint256 amount) external returns (bool);
18+
function allowance(address owner, address spender) external view returns (uint256);
19+
function approve(address spender, uint256 amount) external returns (bool);
20+
function transferFrom(address from, address to, uint256 amount) external returns (bool);
21+
}
22+
}
23+
124
#[test]
225
fn contracts() {
3-
// TODO
26+
// random mainnet ERC20 transfer
27+
// https://etherscan.io/tx/0x947332ff624b5092fb92e8f02cdbb8a50314e861a4b39c29a286b3b75432165e
28+
let data = hex!(
29+
"a9059cbb"
30+
"0000000000000000000000008bc47be1e3abbaba182069c89d08a61fa6c2b292"
31+
"0000000000000000000000000000000000000000000000000000000253c51700"
32+
);
33+
let expected = IERC20::transferCall {
34+
to: Address::from(hex!("8bc47be1e3abbaba182069c89d08a61fa6c2b292")),
35+
amount: U256::from(9995360000_u64),
36+
};
37+
38+
assert_eq!(data[..4], IERC20::transferCall::SELECTOR);
39+
let decoded = IERC20::IERC20Calls::decode(&data, true).unwrap();
40+
assert_eq!(decoded, IERC20::IERC20Calls::transfer(expected));
41+
assert_eq!(decoded.encode(), data);
442
}

crates/sol-types/tests/ui/contract.rs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
use alloy_sol_types::sol;
2+
3+
sol! {
4+
contract C {
5+
contract Nested {}
6+
}
7+
}
8+
9+
sol! {
10+
interface C {
11+
library Nested {}
12+
}
13+
}
14+
15+
sol! {
16+
abstract contract C {
17+
interface Nested {}
18+
}
19+
}
20+
21+
fn main() {}

0 commit comments

Comments
 (0)