Skip to content

Commit

Permalink
feat(sol-macro): improve type expansion (#302)
Browse files Browse the repository at this point in the history
* feat(sol-macro): improve type expansion

* chore: update test fixtures

* chore: unremove spans

* bless
  • Loading branch information
DaniPopes authored Sep 26, 2023
1 parent f418893 commit 6ff6cc8
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 77 deletions.
23 changes: 10 additions & 13 deletions crates/sol-macro/src/expand/event.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
//! [`ItemEvent`] expansion.
use super::{anon_name, expand_tuple_types, expand_type, ExpCtxt};
use crate::expand::ty::expand_event_tokenize_func;
use super::{anon_name, expand_tuple_types, expand_type, ty, ExpCtxt};
use ast::{EventParameter, ItemEvent, SolIdent, Spanned};
use proc_macro2::TokenStream;
use quote::{quote, quote_spanned};
Expand Down Expand Up @@ -93,7 +92,7 @@ pub(super) fn expand(cx: &ExpCtxt<'_>, event: &ItemEvent) -> Result<TokenStream>
.enumerate()
.map(|(i, p)| expand_event_topic_field(i, p, p.name.as_ref()));

let tokenize_body_impl = expand_event_tokenize_func(event.parameters.iter());
let tokenize_body_impl = ty::expand_event_tokenize_func(event.parameters.iter());

let encode_topics_impl = encode_first_topic
.into_iter()
Expand Down Expand Up @@ -176,15 +175,13 @@ fn expand_event_topic_field(
name: Option<&SolIdent>,
) -> TokenStream {
let name = anon_name((i, name));

if param.indexed_as_hash() {
quote! {
#name: <::alloy_sol_types::sol_data::FixedBytes<32> as ::alloy_sol_types::SolType>::RustType
}
let ty = if param.indexed_as_hash() {
ty::expand_rust_type(&ast::Type::FixedBytes(
name.span(),
core::num::NonZeroU16::new(32).unwrap(),
))
} else {
let ty = expand_type(&param.ty);
quote! {
#name: <#ty as ::alloy_sol_types::SolType>::RustType
}
}
ty::expand_rust_type(&param.ty)
};
quote!(#name: #ty)
}
59 changes: 26 additions & 33 deletions crates/sol-macro/src/expand/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
use crate::{
attr::{self, SolAttrs},
expand::ty::expand_rust_type,
utils::{self, ExprArray},
};
use ast::{
File, Item, ItemError, ItemEvent, ItemFunction, Parameters, SolIdent, SolPath, Spanned, Type,
VariableDeclaration, Visit,
};
use proc_macro2::{Ident, Span, TokenStream};
use quote::{format_ident, quote};
use proc_macro2::{Delimiter, Group, Ident, Punct, Spacing, Span, TokenStream, TokenTree};
use quote::{format_ident, quote, TokenStreamExt};
use std::{borrow::Borrow, collections::HashMap, fmt::Write};
use syn::{parse_quote, Attribute, Error, Result};

Expand Down Expand Up @@ -429,32 +430,21 @@ impl ExpCtxt<'_> {
}

// helper functions

/// Expands a list of parameters into a list of struct fields.
///
/// See [`expand_field`].
fn expand_fields<P>(params: &Parameters<P>) -> impl Iterator<Item = TokenStream> + '_ {
params
.iter()
.enumerate()
.map(|(i, var)| expand_field(i, &var.ty, var.name.as_ref(), var.attrs.as_ref()))
}

/// Expands a single parameter into a struct field.
fn expand_field(
i: usize,
ty: &Type,
name: Option<&SolIdent>,
attrs: &Vec<Attribute>,
) -> TokenStream {
let name = anon_name((i, name));
let ty = expand_type(ty);
quote! {
#(#attrs)*
pub #name: <#ty as ::alloy_sol_types::SolType>::RustType
}
params.iter().enumerate().map(|(i, var)| {
let name = anon_name((i, var.name.as_ref()));
let ty = expand_rust_type(&var.ty);
let attrs = &var.attrs;
quote! {
#(#attrs)*
pub #name: #ty
}
})
}

/// Generates an anonymous name from an integer. Used in `anon_name`
/// Generates an anonymous name from an integer. Used in [`anon_name`].
#[inline]
pub fn generate_name(i: usize) -> Ident {
format_ident!("_{i}")
Expand Down Expand Up @@ -516,18 +506,21 @@ fn expand_from_into_tuples<P>(name: &Ident, fields: &Parameters<P>) -> TokenStre
}
}

/// Returns
/// - `(#(#expanded,)*)`
/// - `(#(<#expanded as ::alloy_sol_types::SolType>::RustType,)*)`
/// Returns `(sol_tuple, rust_tuple)`
fn expand_tuple_types<'a, I: IntoIterator<Item = &'a Type>>(
types: I,
) -> (TokenStream, TokenStream) {
let mut sol_tuple = TokenStream::new();
let mut rust_tuple = TokenStream::new();
let mut sol = TokenStream::new();
let mut rust = TokenStream::new();
let comma = Punct::new(',', Spacing::Alone);
for ty in types {
let expanded = expand_type(ty);
sol_tuple.extend(quote!(#expanded,));
rust_tuple.extend(quote!(<#expanded as ::alloy_sol_types::SolType>::RustType,));
ty::rec_expand_type(ty, &mut sol);
sol.append(comma.clone());

ty::rec_expand_rust_type(ty, &mut rust);
rust.append(comma.clone());
}
(quote!((#sol_tuple)), quote!((#rust_tuple)))
let wrap_in_parens =
|stream| TokenStream::from(TokenTree::Group(Group::new(Delimiter::Parenthesis, stream)));
(wrap_in_parens(sol), wrap_in_parens(rust))
}
6 changes: 3 additions & 3 deletions crates/sol-macro/src/expand/struct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use super::{
expand_fields, expand_from_into_tuples, expand_type, ty::expand_tokenize_func, ExpCtxt,
};
use ast::{Item, ItemStruct, Spanned, Type, VariableDeclaration};
use ast::{Item, ItemStruct, Spanned, Type};
use proc_macro2::TokenStream;
use quote::quote;
use std::num::NonZeroU16;
Expand Down Expand Up @@ -48,8 +48,8 @@ pub(super) fn expand(cx: &ExpCtxt<'_>, s: &ItemStruct) -> Result<TokenStream> {
let encode_data_impl = match fields.len() {
0 => unreachable!("struct with zero fields"),
1 => {
let VariableDeclaration { ty, name, .. } = fields.first().unwrap();
let ty = expand_type(ty);
let name = *field_names.first().unwrap();
let ty = field_types.first().unwrap();
quote!(<#ty as ::alloy_sol_types::SolType>::eip712_data_word(&self.#name).0.to_vec())
}
_ => quote! {
Expand Down
116 changes: 98 additions & 18 deletions crates/sol-macro/src/expand/ty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,29 @@
use super::ExpCtxt;
use crate::expand::generate_name;
use ast::{EventParameter, Item, Parameters, Spanned, Type, TypeArray, VariableDeclaration};
use proc_macro2::{Literal, TokenStream};
use proc_macro2::{Ident, Literal, TokenStream};
use quote::{quote, quote_spanned, ToTokens};
use std::{fmt, num::NonZeroU16};

/// Expands a single [`Type`] recursively.
/// Expands a single [`Type`] recursively to its `alloy_sol_types::sol_data`
/// equivalent.
pub fn expand_type(ty: &Type) -> TokenStream {
let mut tokens = TokenStream::new();
rec_expand_type(ty, &mut tokens);
tokens
}

/// Expands a single [`Type`] recursively to its Rust type equivalent.
///
/// This is the same as `<#expand_type(ty) as SolType>::RustType`, but generates
/// nicer code for documentation and IDE/LSP support when the type is not
/// ambiguous.
pub fn expand_rust_type(ty: &Type) -> TokenStream {
let mut tokens = TokenStream::new();
rec_expand_rust_type(ty, &mut tokens);
tokens
}

/// Expands a [`VariableDeclaration`] into an invocation of its types tokenize
/// method.
fn expand_tokenize_statement(var: &VariableDeclaration, i: usize) -> TokenStream {
Expand Down Expand Up @@ -59,35 +71,31 @@ pub fn expand_event_tokenize_func<'a>(
}

/// The [`expand_type`] recursive implementation.
fn rec_expand_type(ty: &Type, tokens: &mut TokenStream) {
pub fn rec_expand_type(ty: &Type, tokens: &mut TokenStream) {
let tts = match *ty {
Type::Address(span, _) => quote_spanned! {span=> ::alloy_sol_types::sol_data::Address },
Type::Bool(span) => quote_spanned! {span=> ::alloy_sol_types::sol_data::Bool },
Type::String(span) => quote_spanned! {span=> ::alloy_sol_types::sol_data::String },
Type::Bytes(span) => quote_spanned! {span=> ::alloy_sol_types::sol_data::Bytes },

Type::FixedBytes(span, size) => {
debug_assert!(size.get() <= 32);
assert!(size.get() <= 32);
let size = Literal::u16_unsuffixed(size.get());
quote_spanned! {span=>
::alloy_sol_types::sol_data::FixedBytes<#size>
}
quote_spanned! {span=> ::alloy_sol_types::sol_data::FixedBytes<#size> }
}
Type::Int(span, size) | Type::Uint(span, size) => {
let name = match ty {
Type::Int(..) => "Int",
Type::Uint(..) => "Uint",
_ => unreachable!(),
};
let name = syn::Ident::new(name, span);
let name = Ident::new(name, span);

let size = size.map_or(256, NonZeroU16::get);
debug_assert!(size <= 256 && size % 8 == 0);
assert!(size <= 256 && size % 8 == 0);
let size = Literal::u16_unsuffixed(size);

quote_spanned! {span=>
::alloy_sol_types::sol_data::#name<#size>
}
quote_spanned! {span=> ::alloy_sol_types::sol_data::#name<#size> }
}

Type::Tuple(ref tuple) => {
Expand All @@ -103,13 +111,9 @@ fn rec_expand_type(ty: &Type, tokens: &mut TokenStream) {
let ty = expand_type(&array.ty);
let span = array.span();
if let Some(size) = array.size() {
quote_spanned! {span=>
::alloy_sol_types::sol_data::FixedArray<#ty, #size>
}
quote_spanned! {span=> ::alloy_sol_types::sol_data::FixedArray<#ty, #size> }
} else {
quote_spanned! {span=>
::alloy_sol_types::sol_data::Array<#ty>
}
quote_spanned! {span=> ::alloy_sol_types::sol_data::Array<#ty> }
}
}
Type::Function(ref function) => quote_spanned! {function.span()=>
Expand All @@ -124,6 +128,82 @@ fn rec_expand_type(ty: &Type, tokens: &mut TokenStream) {
tokens.extend(tts);
}

// IMPORTANT: Keep in sync with `sol-types/src/types/data_type.rs`
/// The [`expand_rust_type`] recursive implementation.
pub fn rec_expand_rust_type(ty: &Type, tokens: &mut TokenStream) {
// Display sizes that match with the Rust type, otherwise we lose information
// (e.g. `uint24` displays the same as `uint32` because both use `u32`)
fn allowed_int_size(size: Option<NonZeroU16>) -> bool {
matches!(
size.map_or(256, NonZeroU16::get),
8 | 16 | 32 | 64 | 128 | 256
)
}

let tts = match *ty {
Type::Address(span, _) => quote_spanned! {span=> ::alloy_sol_types::private::Address },
Type::Bool(span) => return Ident::new("bool", span).to_tokens(tokens),
Type::String(span) => quote_spanned! {span=> ::alloy_sol_types::private::String },
Type::Bytes(span) => quote_spanned! {span=> ::alloy_sol_types::private::Vec<u8> },

Type::FixedBytes(span, size) => {
assert!(size.get() <= 32);
let size = Literal::u16_unsuffixed(size.get());
quote_spanned! {span=> ::alloy_sol_types::private::FixedBytes<#size> }
}
Type::Int(span, size) | Type::Uint(span, size) if allowed_int_size(size) => {
let size = size.map_or(256, NonZeroU16::get);
if size <= 128 {
let name = match ty {
Type::Int(..) => "i",
Type::Uint(..) => "u",
_ => unreachable!(),
};
return Ident::new(&format!("{name}{size}"), span).to_tokens(tokens)
}
assert_eq!(size, 256);
match ty {
Type::Int(..) => quote_spanned! {span=> ::alloy_sol_types::private::I256 },
Type::Uint(..) => quote_spanned! {span=> ::alloy_sol_types::private::U256 },
_ => unreachable!(),
}
}

Type::Tuple(ref tuple) => {
return tuple.paren_token.surround(tokens, |tokens| {
for pair in tuple.types.pairs() {
let (ty, comma) = pair.into_tuple();
rec_expand_rust_type(ty, tokens);
comma.to_tokens(tokens);
}
})
}
Type::Array(ref array) => {
let ty = expand_rust_type(&array.ty);
let span = array.span();
if let Some(size) = array.size() {
quote_spanned! {span=> [#ty; #size] }
} else {
quote_spanned! {span=> ::alloy_sol_types::private::Vec<#ty> }
}
}
Type::Function(ref function) => quote_spanned! {function.span()=>
::alloy_sol_types::private::Function
},
Type::Mapping(ref mapping) => quote_spanned! {mapping.span()=>
::core::compile_error!("Mapping types are not supported here")
},

// Exhaustive fallback to `SolType::RustType`
ref ty @ (Type::Int(..) | Type::Uint(..) | Type::Custom(_)) => {
let span = ty.span();
let ty = expand_type(ty);
quote_spanned! {span=> <#ty as ::alloy_sol_types::SolType>::RustType }
}
};
tokens.extend(tts);
}

/// Calculates the base ABI-encoded size of the given parameters in bytes.
///
/// See [`type_base_data_size`] for more information.
Expand Down
4 changes: 3 additions & 1 deletion crates/sol-types/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,9 @@ pub mod private {
string::{String, ToString},
vec::Vec,
};
pub use alloy_primitives::{bytes, keccak256, Bytes, FixedBytes, B256, U256};
pub use alloy_primitives::{
bytes, keccak256, Address, Bytes, FixedBytes, Function, Signed, Uint, B256, I256, U256,
};
pub use core::{convert::From, default::Default, option::Option, result::Result};

pub use Option::{None, Some};
Expand Down
3 changes: 3 additions & 0 deletions crates/sol-types/src/types/data_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ use alloy_primitives::{
};
use core::{borrow::Borrow, fmt::*, hash::Hash, marker::PhantomData, ops::*};

// IMPORTANT: Keep in sync with `rec_expand_rust_type` in
// `sol-macro/src/expand/ty.rs`

/// Bool - `bool`
pub struct Bool;

Expand Down
13 changes: 8 additions & 5 deletions crates/sol-types/tests/sol.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use alloy_primitives::{keccak256, Address, B256, U256};
use alloy_primitives::{keccak256, Address, B256, I256, U256};
use alloy_sol_types::{eip712_domain, sol, SolCall, SolError, SolStruct, SolType};
use serde::Serialize;
use serde_json::Value;
Expand Down Expand Up @@ -133,15 +133,18 @@ fn function() {
#[test]
fn error() {
sol! {
error SomeError(uint256 a);
error SomeError(int a, bool b);
}

let sig = "SomeError(uint256)";
let sig = "SomeError(int256,bool)";
assert_eq!(SomeError::SIGNATURE, sig);
assert_eq!(SomeError::SELECTOR, keccak256(sig)[..4]);

let e = SomeError { a: U256::from(1) };
assert_eq!(e.encoded_size(), 32);
let e = SomeError {
a: I256::ZERO,
b: false,
};
assert_eq!(e.encoded_size(), 64);
}

// https://github.com/alloy-rs/core/issues/158
Expand Down
Loading

0 comments on commit 6ff6cc8

Please sign in to comment.