Skip to content

Commit

Permalink
feat: sol! contracts (#77)
Browse files Browse the repository at this point in the history
* feat: `sol!` contracts

* chore: clippy

* docs: add note about parser design, leniency

* test: bless tests after updating to syn 2.0.19

* typo

* test: bless new test
  • Loading branch information
DaniPopes authored Jun 8, 2023
1 parent 0a00f01 commit 72d0c81
Show file tree
Hide file tree
Showing 19 changed files with 829 additions and 130 deletions.
9 changes: 9 additions & 0 deletions crates/sol-macro/src/expand/attr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
use syn::Attribute;

pub(crate) fn docs(attrs: &[Attribute]) -> impl Iterator<Item = &Attribute> {
attrs.iter().filter(|attr| attr.path().is_ident("doc"))
}

pub(crate) fn derives(attrs: &[Attribute]) -> impl Iterator<Item = &Attribute> {
attrs.iter().filter(|attr| attr.path().is_ident("derive"))
}
188 changes: 183 additions & 5 deletions crates/sol-macro/src/expand/mod.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
//! Functions which generate Rust code from the Solidity AST.

use ast::{
File, Item, ItemError, ItemFunction, ItemStruct, ItemUdt, Parameters, SolIdent, Type,
VariableDeclaration, Visit,
File, Item, ItemContract, ItemError, ItemFunction, ItemStruct, ItemUdt, Parameters, SolIdent,
Type, VariableDeclaration, Visit,
};
use proc_macro2::{Ident, Span, TokenStream};
use quote::{format_ident, quote, IdentFragment};
use std::{collections::HashMap, fmt::Write};
use syn::{Error, Result, Token};
use syn::{parse_quote, Attribute, Error, Result, Token};

mod attr;

mod r#type;
pub use r#type::expand_type;
Expand Down Expand Up @@ -76,13 +78,180 @@ impl<'ast> ExpCtxt<'ast> {

fn expand_item(&self, item: &Item) -> Result<TokenStream> {
match item {
Item::Contract(contract) => self.expand_contract(contract),
Item::Error(error) => self.expand_error(error),
Item::Function(function) => self.expand_function(function),
Item::Struct(s) => self.expand_struct(s),
Item::Udt(udt) => self.expand_udt(udt),
}
}

fn expand_contract(&self, contract: &ItemContract) -> Result<TokenStream> {
let ItemContract {
attrs, name, body, ..
} = contract;

let mut functions = Vec::with_capacity(contract.body.len());
let mut errors = Vec::with_capacity(contract.body.len());
let mut item_tokens = TokenStream::new();
let d_attrs: Vec<Attribute> = attr::derives(attrs).cloned().collect();
for item in body {
match item {
Item::Function(function) => functions.push(function),
Item::Error(error) => errors.push(error),
_ => {}
}
item_tokens.extend(quote!(#(#d_attrs)*));
item_tokens.extend(self.expand_item(item)?);
}

let functions_enum = if functions.len() > 1 {
let mut attrs = d_attrs.clone();
let doc_str = format!("Container for all the [`{name}`] function calls.");
attrs.push(parse_quote!(#[doc = #doc_str]));
Some(self.expand_functions_enum(name, functions, &attrs))
} else {
None
};

let errors_enum = if errors.len() > 1 {
let mut attrs = d_attrs;
let doc_str = format!("Container for all the [`{name}`] custom errors.");
attrs.push(parse_quote!(#[doc = #doc_str]));
Some(self.expand_errors_enum(name, errors, &attrs))
} else {
None
};

let mod_attrs = attr::docs(attrs);
let tokens = quote! {
#(#mod_attrs)*
#[allow(non_camel_case_types, non_snake_case, clippy::style)]
pub mod #name {
#item_tokens
#functions_enum
#errors_enum
}
};
Ok(tokens)
}

fn expand_functions_enum(
&self,
name: &SolIdent,
functions: Vec<&ItemFunction>,
attrs: &[Attribute],
) -> TokenStream {
let name = format_ident!("{name}Calls");
let variants: Vec<_> = functions
.iter()
.map(|f| self.function_name_ident(f).0)
.collect();
let types: Vec<_> = variants.iter().map(|name| self.call_name(name)).collect();
let min_data_len = functions
.iter()
.map(|function| self.min_data_size(&function.arguments))
.max()
.unwrap();
let trt = Ident::new("SolCall", Span::call_site());
self.expand_call_like_enum(name, &variants, &types, min_data_len, trt, attrs)
}

fn expand_errors_enum(
&self,
name: &SolIdent,
errors: Vec<&ItemError>,
attrs: &[Attribute],
) -> TokenStream {
let name = format_ident!("{name}Errors");
let variants: Vec<_> = errors.iter().map(|error| error.name.0.clone()).collect();
let min_data_len = errors
.iter()
.map(|error| self.min_data_size(&error.fields))
.max()
.unwrap();
let trt = Ident::new("SolError", Span::call_site());
self.expand_call_like_enum(name, &variants, &variants, min_data_len, trt, attrs)
}

fn expand_call_like_enum(
&self,
name: Ident,
variants: &[Ident],
types: &[Ident],
min_data_len: usize,
trt: Ident,
attrs: &[Attribute],
) -> TokenStream {
assert_eq!(variants.len(), types.len());
let name_s = name.to_string();
let count = variants.len();
let min_data_len = min_data_len.min(4);
quote! {
#(#attrs)*
pub enum #name {#(
#variants(#types),
)*}

// TODO: Implement these functions using traits?
#[automatically_derived]
impl #name {
/// The number of variants.
pub const COUNT: usize = #count;

// no decode_raw is possible because we need the selector to know which variant to
// decode into

/// ABI-decodes the given data into one of the variants of `self`.
pub fn decode(data: &[u8], validate: bool) -> ::alloy_sol_types::Result<Self> {
if data.len() >= #min_data_len {
// TODO: Replace with `data.split_array_ref` once it's stable
let (selector, data) = data.split_at(4);
let selector: &[u8; 4] =
::core::convert::TryInto::try_into(selector).expect("unreachable");
match *selector {
#(<#types as ::alloy_sol_types::#trt>::SELECTOR => {
return <#types as ::alloy_sol_types::#trt>::decode_raw(data, validate)
.map(Self::#variants)
})*
_ => {}
}
}
::core::result::Result::Err(::alloy_sol_types::Error::type_check_fail(
data,
#name_s,
))
}

/// ABI-encodes `self` into the given buffer.
pub fn encode_raw(&self, out: &mut Vec<u8>) {
match self {#(
Self::#variants(inner) =>
<#types as ::alloy_sol_types::#trt>::encode_raw(inner, out),
)*}
}

/// ABI-encodes `self` into the given buffer.
#[inline]
pub fn encode(&self) -> Vec<u8> {
match self {#(
Self::#variants(inner) =>
<#types as ::alloy_sol_types::#trt>::encode(inner),
)*}
}
}

#(
#[automatically_derived]
impl From<#types> for #name {
fn from(value: #types) -> Self {
Self::#variants(value)
}
}
)*
}
}

fn expand_error(&self, error: &ItemError) -> Result<TokenStream> {
let ItemError {
fields,
Expand Down Expand Up @@ -344,7 +513,7 @@ impl<'ast> ExpCtxt<'ast> {
ty.visit_mut(|ty| {
let ty @ Type::Custom(_) = ty else { return };
let Type::Custom(name) = &*ty else { unreachable!() };
let Some(resolved) = self.custom_types.get(name) else { return };
let Some(resolved) = self.custom_types.get(name.last_tmp()) else { return };
ty.clone_from(resolved);
any = true;
});
Expand Down Expand Up @@ -431,6 +600,15 @@ impl<'ast> ExpCtxt<'ast> {
}
}

/// Returns the name of the function, adjusted for overloads.
fn function_name_ident(&self, function: &ItemFunction) -> SolIdent {
let sig = self.function_signature(function);
match self.function_overloads.get(&sig) {
Some(name) => SolIdent::new_spanned(name, function.name.span()),
None => function.name.clone(),
}
}

fn call_name(&self, function_name: impl IdentFragment + std::fmt::Display) -> Ident {
format_ident!("{function_name}Call")
}
Expand Down Expand Up @@ -464,7 +642,7 @@ impl<'ast> ExpCtxt<'ast> {
let mut errors = Vec::new();
params.visit_types(|ty| {
if let Type::Custom(name) = ty {
if !self.custom_types.contains_key(name) {
if !self.custom_types.contains_key(name.last_tmp()) {
let e = syn::Error::new(name.span(), "unresolved type");
errors.push(e);
}
Expand Down
21 changes: 15 additions & 6 deletions crates/sol-macro/src/expand/type.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::ExpCtxt;
use ast::{Item, SolArray, SolIdent, Type};
use ast::{Item, Parameters, SolArray, SolPath, Type};
use proc_macro2::{Literal, TokenStream};
use quote::{quote, quote_spanned, ToTokens};
use std::{fmt, num::NonZeroU16};
Expand Down Expand Up @@ -74,20 +74,29 @@ fn rec_expand_type(ty: &Type, tokens: &mut TokenStream) {
}

impl ExpCtxt<'_> {
fn get_item(&self, name: &SolIdent) -> &Item {
fn get_item(&self, name: &SolPath) -> &Item {
let name = name.last_tmp();
match self.all_items.iter().find(|item| item.name() == name) {
Some(item) => item,
None => panic!("unresolved item: {name}"),
}
}

fn custom_type(&self, name: &SolIdent) -> &Type {
match self.custom_types.get(name) {
fn custom_type(&self, name: &SolPath) -> &Type {
match self.custom_types.get(name.last_tmp()) {
Some(item) => item,
None => panic!("unresolved item: {name}"),
}
}

pub(super) fn min_data_size<P>(&self, params: &Parameters<P>) -> usize {
params
.iter()
.map(|param| self.type_base_data_size(&param.ty))
.max()
.unwrap_or(0)
}

/// Recursively calculates the base ABI-encoded size of `self` in bytes.
///
/// That is, the minimum number of bytes required to encode `self` without
Expand Down Expand Up @@ -127,7 +136,7 @@ impl ExpCtxt<'_> {
.map(|ty| self.type_base_data_size(ty))
.sum(),
Item::Udt(udt) => self.type_base_data_size(&udt.ty),
Item::Error(_) | Item::Function(_) => unreachable!(),
Item::Contract(_) | Item::Error(_) | Item::Function(_) => unreachable!(),
},
}
}
Expand Down Expand Up @@ -182,7 +191,7 @@ impl ExpCtxt<'_> {
Type::Custom(name) => match self.get_item(name) {
Item::Struct(strukt) => self.params_data_size(&strukt.fields, Some(field)),
Item::Udt(udt) => self.type_data_size(&udt.ty, field),
Item::Error(_) | Item::Function(_) => unreachable!(),
Item::Contract(_) | Item::Error(_) | Item::Function(_) => unreachable!(),
},
}
}
Expand Down
4 changes: 2 additions & 2 deletions crates/sol-types/src/types/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ pub trait SolCall: Sized {
/// selector.
#[inline]
fn decode_raw(data: &[u8], validate: bool) -> Result<Self> {
let tuple = <Self::Tuple as SolType>::decode(data, validate)?;
Ok(Self::from_rust(tuple))
<Self::Tuple as SolType>::decode(data, validate).map(Self::from_rust)
}

/// ABI decode this call's arguments from the given slice, **with** the
Expand All @@ -52,6 +51,7 @@ pub trait SolCall: Sized {
/// ABI encode the call to the given buffer **without** its selector.
#[inline]
fn encode_raw(&self, out: &mut Vec<u8>) {
out.reserve(self.data_size());
out.extend(<Self::Tuple as SolType>::encode(self.to_rust()));
}

Expand Down
40 changes: 39 additions & 1 deletion crates/sol-types/tests/doc_contracts.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,42 @@
use alloy_primitives::{Address, U256};
use alloy_sol_types::{sol, SolCall};
use hex_literal::hex;

sol! {
/// Interface of the ERC20 standard as defined in [the EIP].
///
/// [the EIP]: https://eips.ethereum.org/EIPS/eip-20
#[derive(Debug, PartialEq)]
interface IERC20 {
// TODO: Events
// event Transfer(address indexed from, address indexed to, uint256 value);
// event Approval(address indexed owner, address indexed spender, uint256 value);

function totalSupply() external view returns (uint256);
function balanceOf(address account) external view returns (uint256);
function transfer(address to, uint256 amount) external returns (bool);
function allowance(address owner, address spender) external view returns (uint256);
function approve(address spender, uint256 amount) external returns (bool);
function transferFrom(address from, address to, uint256 amount) external returns (bool);
}
}

#[test]
fn contracts() {
// TODO
// random mainnet ERC20 transfer
// https://etherscan.io/tx/0x947332ff624b5092fb92e8f02cdbb8a50314e861a4b39c29a286b3b75432165e
let data = hex!(
"a9059cbb"
"0000000000000000000000008bc47be1e3abbaba182069c89d08a61fa6c2b292"
"0000000000000000000000000000000000000000000000000000000253c51700"
);
let expected = IERC20::transferCall {
to: Address::from(hex!("8bc47be1e3abbaba182069c89d08a61fa6c2b292")),
amount: U256::from(9995360000_u64),
};

assert_eq!(data[..4], IERC20::transferCall::SELECTOR);
let decoded = IERC20::IERC20Calls::decode(&data, true).unwrap();
assert_eq!(decoded, IERC20::IERC20Calls::transfer(expected));
assert_eq!(decoded.encode(), data);
}
21 changes: 21 additions & 0 deletions crates/sol-types/tests/ui/contract.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
use alloy_sol_types::sol;

sol! {
contract C {
contract Nested {}
}
}

sol! {
interface C {
library Nested {}
}
}

sol! {
abstract contract C {
interface Nested {}
}
}

fn main() {}
Loading

0 comments on commit 72d0c81

Please sign in to comment.