|
1 | 1 | //! Functions which generate Rust code from the Solidity AST.
|
2 | 2 |
|
3 | 3 | use ast::{
|
4 |
| - File, Item, ItemError, ItemFunction, ItemStruct, ItemUdt, Parameters, SolIdent, Type, |
5 |
| - VariableDeclaration, Visit, |
| 4 | + File, Item, ItemContract, ItemError, ItemFunction, ItemStruct, ItemUdt, Parameters, SolIdent, |
| 5 | + Type, VariableDeclaration, Visit, |
6 | 6 | };
|
7 | 7 | use proc_macro2::{Ident, Span, TokenStream};
|
8 | 8 | use quote::{format_ident, quote, IdentFragment};
|
9 | 9 | use std::{collections::HashMap, fmt::Write};
|
10 |
| -use syn::{Error, Result, Token}; |
| 10 | +use syn::{parse_quote, Attribute, Error, Result, Token}; |
| 11 | + |
| 12 | +mod attr; |
11 | 13 |
|
12 | 14 | mod r#type;
|
13 | 15 | pub use r#type::expand_type;
|
@@ -76,13 +78,180 @@ impl<'ast> ExpCtxt<'ast> {
|
76 | 78 |
|
77 | 79 | fn expand_item(&self, item: &Item) -> Result<TokenStream> {
|
78 | 80 | match item {
|
| 81 | + Item::Contract(contract) => self.expand_contract(contract), |
79 | 82 | Item::Error(error) => self.expand_error(error),
|
80 | 83 | Item::Function(function) => self.expand_function(function),
|
81 | 84 | Item::Struct(s) => self.expand_struct(s),
|
82 | 85 | Item::Udt(udt) => self.expand_udt(udt),
|
83 | 86 | }
|
84 | 87 | }
|
85 | 88 |
|
| 89 | + fn expand_contract(&self, contract: &ItemContract) -> Result<TokenStream> { |
| 90 | + let ItemContract { |
| 91 | + attrs, name, body, .. |
| 92 | + } = contract; |
| 93 | + |
| 94 | + let mut functions = Vec::with_capacity(contract.body.len()); |
| 95 | + let mut errors = Vec::with_capacity(contract.body.len()); |
| 96 | + let mut item_tokens = TokenStream::new(); |
| 97 | + let d_attrs: Vec<Attribute> = attr::derives(attrs).cloned().collect(); |
| 98 | + for item in body { |
| 99 | + match item { |
| 100 | + Item::Function(function) => functions.push(function), |
| 101 | + Item::Error(error) => errors.push(error), |
| 102 | + _ => {} |
| 103 | + } |
| 104 | + item_tokens.extend(quote!(#(#d_attrs)*)); |
| 105 | + item_tokens.extend(self.expand_item(item)?); |
| 106 | + } |
| 107 | + |
| 108 | + let functions_enum = if functions.len() > 1 { |
| 109 | + let mut attrs = d_attrs.clone(); |
| 110 | + let doc_str = format!("Container for all the [`{name}`] function calls."); |
| 111 | + attrs.push(parse_quote!(#[doc = #doc_str])); |
| 112 | + Some(self.expand_functions_enum(name, functions, &attrs)) |
| 113 | + } else { |
| 114 | + None |
| 115 | + }; |
| 116 | + |
| 117 | + let errors_enum = if errors.len() > 1 { |
| 118 | + let mut attrs = d_attrs; |
| 119 | + let doc_str = format!("Container for all the [`{name}`] custom errors."); |
| 120 | + attrs.push(parse_quote!(#[doc = #doc_str])); |
| 121 | + Some(self.expand_errors_enum(name, errors, &attrs)) |
| 122 | + } else { |
| 123 | + None |
| 124 | + }; |
| 125 | + |
| 126 | + let mod_attrs = attr::docs(attrs); |
| 127 | + let tokens = quote! { |
| 128 | + #(#mod_attrs)* |
| 129 | + #[allow(non_camel_case_types, non_snake_case, clippy::style)] |
| 130 | + pub mod #name { |
| 131 | + #item_tokens |
| 132 | + #functions_enum |
| 133 | + #errors_enum |
| 134 | + } |
| 135 | + }; |
| 136 | + Ok(tokens) |
| 137 | + } |
| 138 | + |
| 139 | + fn expand_functions_enum( |
| 140 | + &self, |
| 141 | + name: &SolIdent, |
| 142 | + functions: Vec<&ItemFunction>, |
| 143 | + attrs: &[Attribute], |
| 144 | + ) -> TokenStream { |
| 145 | + let name = format_ident!("{name}Calls"); |
| 146 | + let variants: Vec<_> = functions |
| 147 | + .iter() |
| 148 | + .map(|f| self.function_name_ident(f).0) |
| 149 | + .collect(); |
| 150 | + let types: Vec<_> = variants.iter().map(|name| self.call_name(name)).collect(); |
| 151 | + let min_data_len = functions |
| 152 | + .iter() |
| 153 | + .map(|function| self.min_data_size(&function.arguments)) |
| 154 | + .max() |
| 155 | + .unwrap(); |
| 156 | + let trt = Ident::new("SolCall", Span::call_site()); |
| 157 | + self.expand_call_like_enum(name, &variants, &types, min_data_len, trt, attrs) |
| 158 | + } |
| 159 | + |
| 160 | + fn expand_errors_enum( |
| 161 | + &self, |
| 162 | + name: &SolIdent, |
| 163 | + errors: Vec<&ItemError>, |
| 164 | + attrs: &[Attribute], |
| 165 | + ) -> TokenStream { |
| 166 | + let name = format_ident!("{name}Errors"); |
| 167 | + let variants: Vec<_> = errors.iter().map(|error| error.name.0.clone()).collect(); |
| 168 | + let min_data_len = errors |
| 169 | + .iter() |
| 170 | + .map(|error| self.min_data_size(&error.fields)) |
| 171 | + .max() |
| 172 | + .unwrap(); |
| 173 | + let trt = Ident::new("SolError", Span::call_site()); |
| 174 | + self.expand_call_like_enum(name, &variants, &variants, min_data_len, trt, attrs) |
| 175 | + } |
| 176 | + |
| 177 | + fn expand_call_like_enum( |
| 178 | + &self, |
| 179 | + name: Ident, |
| 180 | + variants: &[Ident], |
| 181 | + types: &[Ident], |
| 182 | + min_data_len: usize, |
| 183 | + trt: Ident, |
| 184 | + attrs: &[Attribute], |
| 185 | + ) -> TokenStream { |
| 186 | + assert_eq!(variants.len(), types.len()); |
| 187 | + let name_s = name.to_string(); |
| 188 | + let count = variants.len(); |
| 189 | + let min_data_len = min_data_len.min(4); |
| 190 | + quote! { |
| 191 | + #(#attrs)* |
| 192 | + pub enum #name {#( |
| 193 | + #variants(#types), |
| 194 | + )*} |
| 195 | + |
| 196 | + // TODO: Implement these functions using traits? |
| 197 | + #[automatically_derived] |
| 198 | + impl #name { |
| 199 | + /// The number of variants. |
| 200 | + pub const COUNT: usize = #count; |
| 201 | + |
| 202 | + // no decode_raw is possible because we need the selector to know which variant to |
| 203 | + // decode into |
| 204 | + |
| 205 | + /// ABI-decodes the given data into one of the variants of `self`. |
| 206 | + pub fn decode(data: &[u8], validate: bool) -> ::alloy_sol_types::Result<Self> { |
| 207 | + if data.len() >= #min_data_len { |
| 208 | + // TODO: Replace with `data.split_array_ref` once it's stable |
| 209 | + let (selector, data) = data.split_at(4); |
| 210 | + let selector: &[u8; 4] = |
| 211 | + ::core::convert::TryInto::try_into(selector).expect("unreachable"); |
| 212 | + match *selector { |
| 213 | + #(<#types as ::alloy_sol_types::#trt>::SELECTOR => { |
| 214 | + return <#types as ::alloy_sol_types::#trt>::decode_raw(data, validate) |
| 215 | + .map(Self::#variants) |
| 216 | + })* |
| 217 | + _ => {} |
| 218 | + } |
| 219 | + } |
| 220 | + ::core::result::Result::Err(::alloy_sol_types::Error::type_check_fail( |
| 221 | + data, |
| 222 | + #name_s, |
| 223 | + )) |
| 224 | + } |
| 225 | + |
| 226 | + /// ABI-encodes `self` into the given buffer. |
| 227 | + pub fn encode_raw(&self, out: &mut Vec<u8>) { |
| 228 | + match self {#( |
| 229 | + Self::#variants(inner) => |
| 230 | + <#types as ::alloy_sol_types::#trt>::encode_raw(inner, out), |
| 231 | + )*} |
| 232 | + } |
| 233 | + |
| 234 | + /// ABI-encodes `self` into the given buffer. |
| 235 | + #[inline] |
| 236 | + pub fn encode(&self) -> Vec<u8> { |
| 237 | + match self {#( |
| 238 | + Self::#variants(inner) => |
| 239 | + <#types as ::alloy_sol_types::#trt>::encode(inner), |
| 240 | + )*} |
| 241 | + } |
| 242 | + } |
| 243 | + |
| 244 | + #( |
| 245 | + #[automatically_derived] |
| 246 | + impl From<#types> for #name { |
| 247 | + fn from(value: #types) -> Self { |
| 248 | + Self::#variants(value) |
| 249 | + } |
| 250 | + } |
| 251 | + )* |
| 252 | + } |
| 253 | + } |
| 254 | + |
86 | 255 | fn expand_error(&self, error: &ItemError) -> Result<TokenStream> {
|
87 | 256 | let ItemError {
|
88 | 257 | fields,
|
@@ -344,7 +513,7 @@ impl<'ast> ExpCtxt<'ast> {
|
344 | 513 | ty.visit_mut(|ty| {
|
345 | 514 | let ty @ Type::Custom(_) = ty else { return };
|
346 | 515 | let Type::Custom(name) = &*ty else { unreachable!() };
|
347 |
| - let Some(resolved) = self.custom_types.get(name) else { return }; |
| 516 | + let Some(resolved) = self.custom_types.get(name.last_tmp()) else { return }; |
348 | 517 | ty.clone_from(resolved);
|
349 | 518 | any = true;
|
350 | 519 | });
|
@@ -431,6 +600,15 @@ impl<'ast> ExpCtxt<'ast> {
|
431 | 600 | }
|
432 | 601 | }
|
433 | 602 |
|
| 603 | + /// Returns the name of the function, adjusted for overloads. |
| 604 | + fn function_name_ident(&self, function: &ItemFunction) -> SolIdent { |
| 605 | + let sig = self.function_signature(function); |
| 606 | + match self.function_overloads.get(&sig) { |
| 607 | + Some(name) => SolIdent::new_spanned(name, function.name.span()), |
| 608 | + None => function.name.clone(), |
| 609 | + } |
| 610 | + } |
| 611 | + |
434 | 612 | fn call_name(&self, function_name: impl IdentFragment + std::fmt::Display) -> Ident {
|
435 | 613 | format_ident!("{function_name}Call")
|
436 | 614 | }
|
@@ -464,7 +642,7 @@ impl<'ast> ExpCtxt<'ast> {
|
464 | 642 | let mut errors = Vec::new();
|
465 | 643 | params.visit_types(|ty| {
|
466 | 644 | if let Type::Custom(name) = ty {
|
467 |
| - if !self.custom_types.contains_key(name) { |
| 645 | + if !self.custom_types.contains_key(name.last_tmp()) { |
468 | 646 | let e = syn::Error::new(name.span(), "unresolved type");
|
469 | 647 | errors.push(e);
|
470 | 648 | }
|
|
0 commit comments