Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add a derive for generate contract #58

Merged
merged 3 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions crates/rs-macro/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ The `abigen!` macro takes 2 or 3 inputs:
3. Optional parameters:
- `output_path`: if provided, the content will be generated in the given file instead of being expanded at the location of the macro invocation.
- `type_aliases`: to avoid type name conflicts between components / contracts, you can rename some type by providing an alias for the full type path. It is important to give the **full** type path to ensure aliases are applied correctly.
- `derive`: to specify the derive for the generated structs/enums.
- `contract_derives`: to specify the derive for the generated contract type.

```rust
use cainome::rs::abigen;
Expand All @@ -66,6 +68,14 @@ abigen!(
},
);

// Example with custom derives:
abigen!(
MyContract,
"./contracts/abi/components.abi.json",
derive(Debug, Clone),
contract_derives(Debug, Clone)
);

fn main() {
// ... use the generated types here, which all of them
// implement CairoSerde trait.
Expand Down
2 changes: 2 additions & 0 deletions crates/rs-macro/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ fn abigen_internal(input: TokenStream) -> TokenStream {
&abi_tokens,
contract_abi.execution_version,
&contract_abi.derives,
&contract_abi.contract_derives,
);

if let Some(out_path) = contract_abi.output_path {
Expand Down Expand Up @@ -66,6 +67,7 @@ fn abigen_internal_legacy(input: TokenStream) -> TokenStream {
&abi_tokens,
cainome_rs::ExecutionVersion::V1,
&contract_abi.derives,
&contract_abi.contract_derives,
);

if let Some(out_path) = contract_abi.output_path {
Expand Down
12 changes: 12 additions & 0 deletions crates/rs-macro/src/macro_inputs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ pub(crate) struct ContractAbi {
pub type_aliases: HashMap<String, String>,
pub execution_version: ExecutionVersion,
pub derives: Vec<String>,
pub contract_derives: Vec<String>,
}

impl Parse for ContractAbi {
Expand Down Expand Up @@ -92,6 +93,7 @@ impl Parse for ContractAbi {
let mut execution_version = ExecutionVersion::V1;
let mut type_aliases = HashMap::new();
let mut derives = Vec::new();
let mut contract_derives = Vec::new();

loop {
if input.parse::<Token![,]>().is_err() {
Expand Down Expand Up @@ -153,6 +155,15 @@ impl Parse for ContractAbi {
derives.push(derive.to_token_stream().to_string());
}
}
"contract_derives" => {
let content;
parenthesized!(content in input);
let parsed = content.parse_terminated(Spanned::<Type>::parse, Token![,])?;

for derive in parsed {
contract_derives.push(derive.to_token_stream().to_string());
}
}
_ => emit_error!(name.span(), format!("unexpected named parameter `{name}`")),
}
}
Expand All @@ -164,6 +175,7 @@ impl Parse for ContractAbi {
type_aliases,
execution_version,
derives,
contract_derives,
})
}
}
Expand Down
12 changes: 12 additions & 0 deletions crates/rs-macro/src/macro_inputs_legacy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pub(crate) struct ContractAbiLegacy {
pub output_path: Option<String>,
pub type_aliases: HashMap<String, String>,
pub derives: Vec<String>,
pub contract_derives: Vec<String>,
}

impl Parse for ContractAbiLegacy {
Expand Down Expand Up @@ -89,6 +90,7 @@ impl Parse for ContractAbiLegacy {
let mut output_path: Option<String> = None;
let mut type_aliases = HashMap::new();
let mut derives = Vec::new();
let mut contract_derives = Vec::new();

loop {
if input.parse::<Token![,]>().is_err() {
Expand Down Expand Up @@ -142,6 +144,15 @@ impl Parse for ContractAbiLegacy {
derives.push(derive.to_token_stream().to_string());
}
}
"contract_derives" => {
let content;
parenthesized!(content in input);
let parsed = content.parse_terminated(Spanned::<Type>::parse, Token![,])?;

for derive in parsed {
contract_derives.push(derive.to_token_stream().to_string());
}
}
_ => emit_error!(name.span(), format!("unexpected named parameter `{name}`")),
}
}
Expand All @@ -152,6 +163,7 @@ impl Parse for ContractAbiLegacy {
output_path,
type_aliases,
derives,
contract_derives,
})
}
}
Expand Down
12 changes: 9 additions & 3 deletions crates/rs/src/expand/contract.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,22 @@ use super::utils;
pub struct CairoContract;

impl CairoContract {
pub fn expand(contract_name: Ident) -> TokenStream2 {
pub fn expand(contract_name: Ident, contract_derives: &[String]) -> TokenStream2 {
let reader = utils::str_to_ident(format!("{}Reader", contract_name).as_str());

let snrs_types = utils::snrs_types();
let snrs_accounts = utils::snrs_accounts();
let snrs_providers = utils::snrs_providers();

let mut internal_derives = vec![];

for d in contract_derives {
internal_derives.push(utils::str_to_type(d));
}

let q = quote! {

#[derive(Debug)]
#[derive(#(#internal_derives,)*)]
pub struct #contract_name<A: #snrs_accounts::ConnectedAccount + Sync> {
pub address: #snrs_types::Felt,
pub account: A,
Expand Down Expand Up @@ -45,7 +51,7 @@ impl CairoContract {
}
}

#[derive(Debug)]
#[derive(#(#internal_derives,)*)]
pub struct #reader<P: #snrs_providers::Provider + Sync> {
pub address: #snrs_types::Felt,
pub provider: P,
Expand Down
23 changes: 22 additions & 1 deletion crates/rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ pub struct Abigen {
pub execution_version: ExecutionVersion,
/// Derives to be added to the generated types.
pub derives: Vec<String>,
/// Derives to be added to the generated contract.
pub contract_derives: Vec<String>,
}

impl Abigen {
Expand All @@ -90,6 +92,7 @@ impl Abigen {
types_aliases: HashMap::new(),
execution_version: ExecutionVersion::V1,
derives: vec![],
contract_derives: vec![],
}
}

Expand Down Expand Up @@ -123,6 +126,16 @@ impl Abigen {
self
}

/// Sets the derives to be added to the generated contract.
///
/// # Arguments
///
/// * `derives` - Derives to be added to the generated contract.
pub fn with_contract_derives(mut self, derives: Vec<String>) -> Self {
self.contract_derives = derives;
self
}

/// Generates the contract bindings.
pub fn generate(&self) -> Result<ContractBindings> {
let file_content = std::fs::read_to_string(&self.abi_source)?;
Expand All @@ -134,6 +147,7 @@ impl Abigen {
&tokens,
self.execution_version,
&self.derives,
&self.contract_derives,
);

Ok(ContractBindings {
Expand All @@ -157,17 +171,24 @@ impl Abigen {
///
/// * `contract_name` - Name of the contract.
/// * `abi_tokens` - Tokenized ABI.
/// * `execution_version` - The version of transaction to be executed.
/// * `derives` - Derives to be added to the generated types.
/// * `contract_derives` - Derives to be added to the generated contract.
pub fn abi_to_tokenstream(
contract_name: &str,
abi_tokens: &TokenizedAbi,
execution_version: ExecutionVersion,
derives: &[String],
contract_derives: &[String],
) -> TokenStream2 {
let contract_name = utils::str_to_ident(contract_name);

let mut tokens: Vec<TokenStream2> = vec![];

tokens.push(CairoContract::expand(contract_name.clone()));
tokens.push(CairoContract::expand(
contract_name.clone(),
contract_derives,
));

let mut sorted_structs = abi_tokens.structs.clone();
sorted_structs.sort_by(|a, b| {
Expand Down
4 changes: 3 additions & 1 deletion examples/abigen_generate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ async fn main() {
"MyContract",
"./contracts/target/dev/contracts_simple_get_set.contract_class.json",
)
.with_types_aliases(aliases);
.with_types_aliases(aliases)
.with_derives(vec!["Debug".to_string(), "PartialEq".to_string()])
.with_contract_derives(vec!["Debug".to_string(), "Clone".to_string()]);

abigen
.generate()
Expand Down
4 changes: 3 additions & 1 deletion examples/simple_get_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ const KATANA_CHAIN_ID: &str = "0x4b4154414e41";
// Or you can use the extracted abi entries with jq in contracts/abi/.
abigen!(
MyContract,
"./contracts/target/dev/contracts_simple_get_set.contract_class.json"
"./contracts/target/dev/contracts_simple_get_set.contract_class.json",
derives(Debug, PartialEq),
contract_derives(Debug, Clone)
);
//abigen!(MyContract, "./contracts/abi/simple_get_set.abi.json");

Expand Down
5 changes: 5 additions & 0 deletions src/bin/cli/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ pub struct CainomeArgs {
#[arg(value_name = "DERIVES")]
#[arg(help = "Derives to be added to the generated types.")]
pub derives: Option<Vec<String>>,

#[arg(long)]
#[arg(value_name = "CONTRACT_DERIVES")]
#[arg(help = "Derives to be added to the generated contract.")]
pub contract_derives: Option<Vec<String>>,
}

#[derive(Debug, Args, Clone)]
Expand Down
1 change: 1 addition & 0 deletions src/bin/cli/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ async fn main() -> CainomeCliResult<()> {
contracts,
execution_version: args.execution_version,
derives: args.derives.unwrap_or_default(),
contract_derives: args.contract_derives.unwrap_or_default(),
})
.await?;

Expand Down
1 change: 1 addition & 0 deletions src/bin/cli/plugins/builtins/rust.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ impl BuiltinPlugin for RustPlugin {
&contract.tokens,
input.execution_version,
&input.derives,
&input.contract_derives,
);
let filename = format!(
"{}.rs",
Expand Down
1 change: 1 addition & 0 deletions src/bin/cli/plugins/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub struct PluginInput {
pub contracts: Vec<ContractData>,
pub execution_version: ExecutionVersion,
pub derives: Vec<String>,
pub contract_derives: Vec<String>,
}

#[derive(Debug)]
Expand Down
Loading