Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: sol! contracts #77

Merged
merged 6 commits into from
Jun 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions crates/sol-macro/src/expand/attr.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
use syn::Attribute;

pub(crate) fn docs(attrs: &[Attribute]) -> impl Iterator<Item = &Attribute> {
attrs.iter().filter(|attr| attr.path().is_ident("doc"))
}

pub(crate) fn derives(attrs: &[Attribute]) -> impl Iterator<Item = &Attribute> {
attrs.iter().filter(|attr| attr.path().is_ident("derive"))
}
188 changes: 183 additions & 5 deletions crates/sol-macro/src/expand/mod.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
//! Functions which generate Rust code from the Solidity AST.

use ast::{
File, Item, ItemError, ItemFunction, ItemStruct, ItemUdt, Parameters, SolIdent, Type,
VariableDeclaration, Visit,
File, Item, ItemContract, ItemError, ItemFunction, ItemStruct, ItemUdt, Parameters, SolIdent,
Type, VariableDeclaration, Visit,
};
use proc_macro2::{Ident, Span, TokenStream};
use quote::{format_ident, quote, IdentFragment};
use std::{collections::HashMap, fmt::Write};
use syn::{Error, Result, Token};
use syn::{parse_quote, Attribute, Error, Result, Token};

mod attr;

mod r#type;
pub use r#type::expand_type;
Expand Down Expand Up @@ -76,13 +78,180 @@ impl<'ast> ExpCtxt<'ast> {

fn expand_item(&self, item: &Item) -> Result<TokenStream> {
match item {
Item::Contract(contract) => self.expand_contract(contract),
Item::Error(error) => self.expand_error(error),
Item::Function(function) => self.expand_function(function),
Item::Struct(s) => self.expand_struct(s),
Item::Udt(udt) => self.expand_udt(udt),
}
}

fn expand_contract(&self, contract: &ItemContract) -> Result<TokenStream> {
let ItemContract {
attrs, name, body, ..
} = contract;

let mut functions = Vec::with_capacity(contract.body.len());
let mut errors = Vec::with_capacity(contract.body.len());
let mut item_tokens = TokenStream::new();
let d_attrs: Vec<Attribute> = attr::derives(attrs).cloned().collect();
for item in body {
match item {
Item::Function(function) => functions.push(function),
Item::Error(error) => errors.push(error),
_ => {}
}
item_tokens.extend(quote!(#(#d_attrs)*));
item_tokens.extend(self.expand_item(item)?);
}

let functions_enum = if functions.len() > 1 {
let mut attrs = d_attrs.clone();
let doc_str = format!("Container for all the [`{name}`] function calls.");
attrs.push(parse_quote!(#[doc = #doc_str]));
Some(self.expand_functions_enum(name, functions, &attrs))
} else {
None
};

let errors_enum = if errors.len() > 1 {
let mut attrs = d_attrs;
let doc_str = format!("Container for all the [`{name}`] custom errors.");
attrs.push(parse_quote!(#[doc = #doc_str]));
Some(self.expand_errors_enum(name, errors, &attrs))
} else {
None
};

let mod_attrs = attr::docs(attrs);
let tokens = quote! {
#(#mod_attrs)*
#[allow(non_camel_case_types, non_snake_case, clippy::style)]
pub mod #name {
#item_tokens
#functions_enum
#errors_enum
}
};
Ok(tokens)
}

fn expand_functions_enum(
&self,
name: &SolIdent,
functions: Vec<&ItemFunction>,
attrs: &[Attribute],
) -> TokenStream {
let name = format_ident!("{name}Calls");
let variants: Vec<_> = functions
.iter()
.map(|f| self.function_name_ident(f).0)
.collect();
let types: Vec<_> = variants.iter().map(|name| self.call_name(name)).collect();
let min_data_len = functions
.iter()
.map(|function| self.min_data_size(&function.arguments))
.max()
.unwrap();
let trt = Ident::new("SolCall", Span::call_site());
self.expand_call_like_enum(name, &variants, &types, min_data_len, trt, attrs)
}

fn expand_errors_enum(
&self,
name: &SolIdent,
errors: Vec<&ItemError>,
attrs: &[Attribute],
) -> TokenStream {
let name = format_ident!("{name}Errors");
let variants: Vec<_> = errors.iter().map(|error| error.name.0.clone()).collect();
let min_data_len = errors
.iter()
.map(|error| self.min_data_size(&error.fields))
.max()
.unwrap();
let trt = Ident::new("SolError", Span::call_site());
self.expand_call_like_enum(name, &variants, &variants, min_data_len, trt, attrs)
}

fn expand_call_like_enum(
&self,
name: Ident,
variants: &[Ident],
types: &[Ident],
min_data_len: usize,
trt: Ident,
attrs: &[Attribute],
) -> TokenStream {
assert_eq!(variants.len(), types.len());
let name_s = name.to_string();
let count = variants.len();
let min_data_len = min_data_len.min(4);
quote! {
#(#attrs)*
pub enum #name {#(
#variants(#types),
)*}

// TODO: Implement these functions using traits?
#[automatically_derived]
impl #name {
/// The number of variants.
pub const COUNT: usize = #count;

// no decode_raw is possible because we need the selector to know which variant to
// decode into

/// ABI-decodes the given data into one of the variants of `self`.
pub fn decode(data: &[u8], validate: bool) -> ::alloy_sol_types::Result<Self> {
if data.len() >= #min_data_len {
// TODO: Replace with `data.split_array_ref` once it's stable
let (selector, data) = data.split_at(4);
let selector: &[u8; 4] =
::core::convert::TryInto::try_into(selector).expect("unreachable");
match *selector {
#(<#types as ::alloy_sol_types::#trt>::SELECTOR => {
return <#types as ::alloy_sol_types::#trt>::decode_raw(data, validate)
.map(Self::#variants)
})*
_ => {}
}
}
::core::result::Result::Err(::alloy_sol_types::Error::type_check_fail(
data,
#name_s,
))
}

/// ABI-encodes `self` into the given buffer.
pub fn encode_raw(&self, out: &mut Vec<u8>) {
match self {#(
Self::#variants(inner) =>
<#types as ::alloy_sol_types::#trt>::encode_raw(inner, out),
)*}
}

/// ABI-encodes `self` into the given buffer.
#[inline]
pub fn encode(&self) -> Vec<u8> {
match self {#(
Self::#variants(inner) =>
<#types as ::alloy_sol_types::#trt>::encode(inner),
)*}
}
}

#(
#[automatically_derived]
impl From<#types> for #name {
fn from(value: #types) -> Self {
Self::#variants(value)
}
}
)*
}
}

fn expand_error(&self, error: &ItemError) -> Result<TokenStream> {
let ItemError {
fields,
Expand Down Expand Up @@ -344,7 +513,7 @@ impl<'ast> ExpCtxt<'ast> {
ty.visit_mut(|ty| {
let ty @ Type::Custom(_) = ty else { return };
let Type::Custom(name) = &*ty else { unreachable!() };
let Some(resolved) = self.custom_types.get(name) else { return };
let Some(resolved) = self.custom_types.get(name.last_tmp()) else { return };
ty.clone_from(resolved);
any = true;
});
Expand Down Expand Up @@ -431,6 +600,15 @@ impl<'ast> ExpCtxt<'ast> {
}
}

/// Returns the name of the function, adjusted for overloads.
fn function_name_ident(&self, function: &ItemFunction) -> SolIdent {
let sig = self.function_signature(function);
match self.function_overloads.get(&sig) {
Some(name) => SolIdent::new_spanned(name, function.name.span()),
None => function.name.clone(),
}
}

fn call_name(&self, function_name: impl IdentFragment + std::fmt::Display) -> Ident {
format_ident!("{function_name}Call")
}
Expand Down Expand Up @@ -464,7 +642,7 @@ impl<'ast> ExpCtxt<'ast> {
let mut errors = Vec::new();
params.visit_types(|ty| {
if let Type::Custom(name) = ty {
if !self.custom_types.contains_key(name) {
if !self.custom_types.contains_key(name.last_tmp()) {
let e = syn::Error::new(name.span(), "unresolved type");
errors.push(e);
}
Expand Down
21 changes: 15 additions & 6 deletions crates/sol-macro/src/expand/type.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::ExpCtxt;
use ast::{Item, SolArray, SolIdent, Type};
use ast::{Item, Parameters, SolArray, SolPath, Type};
use proc_macro2::{Literal, TokenStream};
use quote::{quote, quote_spanned, ToTokens};
use std::{fmt, num::NonZeroU16};
Expand Down Expand Up @@ -74,20 +74,29 @@ fn rec_expand_type(ty: &Type, tokens: &mut TokenStream) {
}

impl ExpCtxt<'_> {
fn get_item(&self, name: &SolIdent) -> &Item {
fn get_item(&self, name: &SolPath) -> &Item {
let name = name.last_tmp();
match self.all_items.iter().find(|item| item.name() == name) {
Some(item) => item,
None => panic!("unresolved item: {name}"),
}
}

fn custom_type(&self, name: &SolIdent) -> &Type {
match self.custom_types.get(name) {
fn custom_type(&self, name: &SolPath) -> &Type {
match self.custom_types.get(name.last_tmp()) {
Some(item) => item,
None => panic!("unresolved item: {name}"),
}
}

pub(super) fn min_data_size<P>(&self, params: &Parameters<P>) -> usize {
params
.iter()
.map(|param| self.type_base_data_size(&param.ty))
.max()
.unwrap_or(0)
}

/// Recursively calculates the base ABI-encoded size of `self` in bytes.
///
/// That is, the minimum number of bytes required to encode `self` without
Expand Down Expand Up @@ -127,7 +136,7 @@ impl ExpCtxt<'_> {
.map(|ty| self.type_base_data_size(ty))
.sum(),
Item::Udt(udt) => self.type_base_data_size(&udt.ty),
Item::Error(_) | Item::Function(_) => unreachable!(),
Item::Contract(_) | Item::Error(_) | Item::Function(_) => unreachable!(),
},
}
}
Expand Down Expand Up @@ -182,7 +191,7 @@ impl ExpCtxt<'_> {
Type::Custom(name) => match self.get_item(name) {
Item::Struct(strukt) => self.params_data_size(&strukt.fields, Some(field)),
Item::Udt(udt) => self.type_data_size(&udt.ty, field),
Item::Error(_) | Item::Function(_) => unreachable!(),
Item::Contract(_) | Item::Error(_) | Item::Function(_) => unreachable!(),
},
}
}
Expand Down
4 changes: 2 additions & 2 deletions crates/sol-types/src/types/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,7 @@ pub trait SolCall: Sized {
/// selector.
#[inline]
fn decode_raw(data: &[u8], validate: bool) -> Result<Self> {
let tuple = <Self::Tuple as SolType>::decode(data, validate)?;
Ok(Self::from_rust(tuple))
<Self::Tuple as SolType>::decode(data, validate).map(Self::from_rust)
}

/// ABI decode this call's arguments from the given slice, **with** the
Expand All @@ -52,6 +51,7 @@ pub trait SolCall: Sized {
/// ABI encode the call to the given buffer **without** its selector.
#[inline]
fn encode_raw(&self, out: &mut Vec<u8>) {
out.reserve(self.data_size());
out.extend(<Self::Tuple as SolType>::encode(self.to_rust()));
}

Expand Down
40 changes: 39 additions & 1 deletion crates/sol-types/tests/doc_contracts.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,42 @@
use alloy_primitives::{Address, U256};
use alloy_sol_types::{sol, SolCall};
use hex_literal::hex;

sol! {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is very cool

/// Interface of the ERC20 standard as defined in [the EIP].
///
/// [the EIP]: https://eips.ethereum.org/EIPS/eip-20
#[derive(Debug, PartialEq)]
interface IERC20 {
// TODO: Events
// event Transfer(address indexed from, address indexed to, uint256 value);
// event Approval(address indexed owner, address indexed spender, uint256 value);

function totalSupply() external view returns (uint256);
function balanceOf(address account) external view returns (uint256);
function transfer(address to, uint256 amount) external returns (bool);
function allowance(address owner, address spender) external view returns (uint256);
function approve(address spender, uint256 amount) external returns (bool);
function transferFrom(address from, address to, uint256 amount) external returns (bool);
}
}

#[test]
fn contracts() {
// TODO
// random mainnet ERC20 transfer
// https://etherscan.io/tx/0x947332ff624b5092fb92e8f02cdbb8a50314e861a4b39c29a286b3b75432165e
let data = hex!(
"a9059cbb"
"0000000000000000000000008bc47be1e3abbaba182069c89d08a61fa6c2b292"
"0000000000000000000000000000000000000000000000000000000253c51700"
);
let expected = IERC20::transferCall {
to: Address::from(hex!("8bc47be1e3abbaba182069c89d08a61fa6c2b292")),
amount: U256::from(9995360000_u64),
};

assert_eq!(data[..4], IERC20::transferCall::SELECTOR);
let decoded = IERC20::IERC20Calls::decode(&data, true).unwrap();
assert_eq!(decoded, IERC20::IERC20Calls::transfer(expected));
assert_eq!(decoded.encode(), data);
}
21 changes: 21 additions & 0 deletions crates/sol-types/tests/ui/contract.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
use alloy_sol_types::sol;

sol! {
contract C {
contract Nested {}
}
}

sol! {
interface C {
library Nested {}
}
}

sol! {
abstract contract C {
interface Nested {}
}
}

fn main() {}
Loading