diff --git a/plugins/cairo-lang-macro-attributes/src/lib.rs b/plugins/cairo-lang-macro-attributes/src/lib.rs index 236032dde..3457f9c4f 100644 --- a/plugins/cairo-lang-macro-attributes/src/lib.rs +++ b/plugins/cairo-lang-macro-attributes/src/lib.rs @@ -24,6 +24,16 @@ pub fn inline_macro(_args: TokenStream, input: TokenStream) -> TokenStream { macro_helper(input, quote!(::cairo_lang_macro::ExpansionKind::Inline)) } +/// Constructs the derive macro implementation. +/// +/// This macro hides the conversion to stable ABI structs from the user. +/// +/// Note, that this macro can be used multiple times, to define multiple independent attribute macros. +#[proc_macro_attribute] +pub fn derive_macro(_args: TokenStream, input: TokenStream) -> TokenStream { + macro_helper(input, quote!(::cairo_lang_macro::ExpansionKind::Derive)) +} + fn macro_helper(input: TokenStream, kind: impl ToTokens) -> TokenStream { let item: ItemFn = parse_macro_input!(input as ItemFn); let original_item_name = item.sig.ident.to_string(); diff --git a/scarb/src/compiler/plugin/proc_macro/host.rs b/scarb/src/compiler/plugin/proc_macro/host.rs index fa419d78d..74e6d9117 100644 --- a/scarb/src/compiler/plugin/proc_macro/host.rs +++ b/scarb/src/compiler/plugin/proc_macro/host.rs @@ -2,6 +2,7 @@ use crate::compiler::plugin::proc_macro::{Expansion, FromSyntaxNode, ProcMacroIn use crate::core::{Config, Package, PackageId}; use anyhow::{ensure, Result}; use cairo_lang_defs::ids::{ModuleItemId, TopLevelLanguageElementId}; +use cairo_lang_defs::patcher::PatchBuilder; use cairo_lang_defs::plugin::{ DynGeneratedFileAuxData, GeneratedFileAuxData, MacroPlugin, MacroPluginMetadata, PluginGeneratedFile, PluginResult, @@ -15,21 +16,24 @@ use cairo_lang_semantic::db::SemanticGroup; use cairo_lang_semantic::items::attribute::SemanticQueryAttrs; use cairo_lang_semantic::plugin::PluginSuite; use cairo_lang_syntax::attribute::structured::{ - Attribute, AttributeArgVariant, AttributeListStructurize, + Attribute, AttributeArgVariant, AttributeStructurize, }; -use cairo_lang_syntax::node::ast::Expr; +use cairo_lang_syntax::node::ast::{Expr, PathSegment}; use cairo_lang_syntax::node::db::SyntaxGroup; +use cairo_lang_syntax::node::helpers::QueryAttrs; use cairo_lang_syntax::node::ids::SyntaxStablePtrId; -use cairo_lang_syntax::node::{ast, TypedStablePtr, TypedSyntaxNode}; +use cairo_lang_syntax::node::{ast, Terminal, TypedStablePtr, TypedSyntaxNode}; use itertools::Itertools; use scarb_stable_hash::short_hash; use std::any::Any; use std::collections::HashMap; +use std::fmt::Debug; use std::sync::{Arc, RwLock}; use std::vec::IntoIter; use tracing::{debug, trace_span}; const FULL_PATH_MARKER_KEY: &str = "macro::full_path_marker"; +const DERIVE_ATTR: &str = "derive"; /// A Cairo compiler plugin controlling the procedural macro execution. /// @@ -88,6 +92,10 @@ impl GeneratedFileAuxData for EmittedAuxData { } impl EmittedAuxData { + pub fn new(aux_data: ProcMacroAuxData) -> Self { + Self(vec![aux_data]) + } + pub fn push(&mut self, aux_data: ProcMacroAuxData) { self.0.push(aux_data); } @@ -142,37 +150,245 @@ impl ProcMacroHostPlugin { }) } - /// Handle `#[proc_macro_name]` attribute. - fn handle_attribute( + /// Find first attribute procedural macros that should be expanded. + /// + /// Remove the attribute from the code. + fn parse_attribute( &self, db: &dyn SyntaxGroup, item_ast: ast::ModuleItem, - ) -> Vec { - let attrs = match item_ast { - ast::ModuleItem::Struct(struct_ast) => Some(struct_ast.attributes(db)), - ast::ModuleItem::Enum(enum_ast) => Some(enum_ast.attributes(db)), - ast::ModuleItem::ExternType(extern_type_ast) => Some(extern_type_ast.attributes(db)), + ) -> (Option, TokenStream) { + let mut item_builder = PatchBuilder::new(db); + let input = match item_ast { + ast::ModuleItem::Struct(struct_ast) => { + let attrs = struct_ast.attributes(db).elements(db); + let expansion = self.parse_attrs(db, &mut item_builder, attrs); + item_builder.add_node(struct_ast.visibility(db).as_syntax_node()); + item_builder.add_node(struct_ast.struct_kw(db).as_syntax_node()); + item_builder.add_node(struct_ast.name(db).as_syntax_node()); + item_builder.add_node(struct_ast.generic_params(db).as_syntax_node()); + item_builder.add_node(struct_ast.lbrace(db).as_syntax_node()); + item_builder.add_node(struct_ast.members(db).as_syntax_node()); + item_builder.add_node(struct_ast.rbrace(db).as_syntax_node()); + expansion + } + ast::ModuleItem::Enum(enum_ast) => { + let attrs = enum_ast.attributes(db).elements(db); + let expansion = self.parse_attrs(db, &mut item_builder, attrs); + item_builder.add_node(enum_ast.visibility(db).as_syntax_node()); + item_builder.add_node(enum_ast.enum_kw(db).as_syntax_node()); + item_builder.add_node(enum_ast.name(db).as_syntax_node()); + item_builder.add_node(enum_ast.generic_params(db).as_syntax_node()); + item_builder.add_node(enum_ast.lbrace(db).as_syntax_node()); + item_builder.add_node(enum_ast.variants(db).as_syntax_node()); + item_builder.add_node(enum_ast.rbrace(db).as_syntax_node()); + expansion + } + ast::ModuleItem::ExternType(extern_type_ast) => { + let attrs = extern_type_ast.attributes(db).elements(db); + let expansion = self.parse_attrs(db, &mut item_builder, attrs); + item_builder.add_node(extern_type_ast.visibility(db).as_syntax_node()); + item_builder.add_node(extern_type_ast.extern_kw(db).as_syntax_node()); + item_builder.add_node(extern_type_ast.type_kw(db).as_syntax_node()); + item_builder.add_node(extern_type_ast.name(db).as_syntax_node()); + item_builder.add_node(extern_type_ast.generic_params(db).as_syntax_node()); + item_builder.add_node(extern_type_ast.semicolon(db).as_syntax_node()); + expansion + } ast::ModuleItem::ExternFunction(extern_func_ast) => { - Some(extern_func_ast.attributes(db)) + let attrs = extern_func_ast.attributes(db).elements(db); + let expansion = self.parse_attrs(db, &mut item_builder, attrs); + item_builder.add_node(extern_func_ast.visibility(db).as_syntax_node()); + item_builder.add_node(extern_func_ast.extern_kw(db).as_syntax_node()); + item_builder.add_node(extern_func_ast.declaration(db).as_syntax_node()); + item_builder.add_node(extern_func_ast.semicolon(db).as_syntax_node()); + expansion + } + ast::ModuleItem::FreeFunction(free_func_ast) => { + let attrs = free_func_ast.attributes(db).elements(db); + let expansion = self.parse_attrs(db, &mut item_builder, attrs); + item_builder.add_node(free_func_ast.visibility(db).as_syntax_node()); + item_builder.add_node(free_func_ast.declaration(db).as_syntax_node()); + item_builder.add_node(free_func_ast.body(db).as_syntax_node()); + expansion + } + _ => None, + }; + let token_stream = TokenStream::new(item_builder.code); + (input, token_stream) + } + + fn parse_attrs( + &self, + db: &dyn SyntaxGroup, + builder: &mut PatchBuilder<'_>, + attrs: Vec, + ) -> Option { + let mut expansion = None; + for attr in attrs { + if expansion.is_none() { + let structured_attr = attr.clone().structurize(db); + let found = self.find_expansion(&Expansion::new( + structured_attr.id.clone(), + ExpansionKind::Attr, + )); + if found.is_some() { + expansion = found; + // Do not add the attribute for found expansion. + continue; + } } - ast::ModuleItem::FreeFunction(free_func_ast) => Some(free_func_ast.attributes(db)), + builder.add_node(attr.as_syntax_node()); + } + expansion + } + + /// Handle `#[derive(...)]` attribute. + /// + /// Returns a list of expansions that this plugin should apply. + fn parse_derive(&self, db: &dyn SyntaxGroup, item_ast: ast::ModuleItem) -> Vec { + let attrs = match item_ast { + ast::ModuleItem::Struct(struct_ast) => Some(struct_ast.query_attr(db, DERIVE_ATTR)), + ast::ModuleItem::Enum(enum_ast) => Some(enum_ast.query_attr(db, DERIVE_ATTR)), _ => None, }; attrs - .map(|attrs| attrs.structurize(db)) .unwrap_or_default() .iter() + .map(|attr| attr.clone().structurize(db)) + .flat_map(|attr| attr.args.into_iter()) .filter_map(|attr| { - self.find_expansion(&Expansion::new(attr.id.clone(), ExpansionKind::Attr)) + let AttributeArgVariant::Unnamed { value, .. } = attr.clone().variant else { + return None; + }; + let Expr::Path(path) = value else { + return None; + }; + let path = path.elements(db); + let path = path.last()?; + let PathSegment::Simple(segment) = path else { + return None; + }; + let ident = segment.ident(db); + let value = ident.text(db).to_string(); + + self.find_expansion(&Expansion::new( + camel_to_snake(value), + ExpansionKind::Derive, + )) }) .collect_vec() } - /// Handle `#[derive(...)]` attribute. - fn handle_derive(&self, _db: &dyn SyntaxGroup, _item_ast: ast::ModuleItem) -> Vec { - // Todo(maciektr): Implement. - Vec::new() + fn expand_derives( + &self, + db: &dyn SyntaxGroup, + item_ast: ast::ModuleItem, + stream_metadata: TokenStreamMetadata, + ) -> Option { + let stable_ptr = item_ast.clone().stable_ptr().untyped(); + let token_stream = + TokenStream::from_item_ast(db, item_ast.clone()).with_metadata(stream_metadata.clone()); + + let mut aux_data = EmittedAuxData::default(); + let mut all_diagnostics: Vec = Vec::new(); + + // All derives to be applied. + let derives = self.parse_derive(db, item_ast.clone()); + let any_derives = !derives.is_empty(); + + let mut derived_code = PatchBuilder::new(db); + for derive in derives { + let result = self + .instance(derive.package_id) + .generate_code(derive.expansion.name.clone(), token_stream.clone()); + + // Register diagnostics. + all_diagnostics.extend(result.diagnostics); + + // Register aux data. + if let Some(new_aux_data) = result.aux_data { + aux_data.push(ProcMacroAuxData::new( + new_aux_data.into(), + ProcMacroId::new(derive.package_id, derive.expansion.clone()), + )); + } + + if result.token_stream.is_empty() { + // No code has been generated. + // We do not need to do anything. + continue; + } + + derived_code.add_str(result.token_stream.to_string().as_str()); + } + + if any_derives { + return Some(PluginResult { + code: if derived_code.code.is_empty() { + None + } else { + Some(PluginGeneratedFile { + name: "proc_macro_derive".into(), + content: derived_code.code.to_string(), + code_mappings: Default::default(), + aux_data: if aux_data.is_empty() { + None + } else { + Some(DynGeneratedFileAuxData::new(aux_data)) + }, + }) + }, + diagnostics: into_cairo_diagnostics(all_diagnostics, stable_ptr), + // Note that we don't remove the original item here, unlike for attributes. + // We do not add the original code to the generated file either. + remove_original_item: false, + }); + } + + None + } + + fn expand_attribute( + &self, + input: ProcMacroId, + token_stream: TokenStream, + stable_ptr: SyntaxStablePtrId, + ) -> PluginResult { + let result = self + .instance(input.package_id) + .generate_code(input.expansion.name.clone(), token_stream.clone()); + + // Handle token stream. + if result.token_stream.is_empty() { + // Remove original code + return PluginResult { + diagnostics: into_cairo_diagnostics(result.diagnostics, stable_ptr), + code: None, + remove_original_item: true, + }; + } + + // Full path markers require code modification. + self.register_full_path_markers(input.package_id, result.full_path_markers.clone()); + + let file_name = format!("proc_macro_{}", input.expansion.name); + PluginResult { + code: Some(PluginGeneratedFile { + name: file_name.into(), + content: result.token_stream.to_string(), + code_mappings: Default::default(), + aux_data: result.aux_data.map(|new_aux_data| { + DynGeneratedFileAuxData::new(EmittedAuxData::new(ProcMacroAuxData::new( + new_aux_data.into(), + ProcMacroId::new(input.package_id, input.expansion.clone()), + ))) + }), + }), + diagnostics: into_cairo_diagnostics(result.diagnostics, stable_ptr), + remove_original_item: true, + } } fn find_expansion(&self, expansion: &Expansion) -> Option { @@ -328,6 +544,22 @@ impl ProcMacroHostPlugin { .find(|m| m.package_id() == package_id) .expect("procedural macro must be registered in proc macro host") } + + fn register_full_path_markers(&self, package_id: PackageId, markers: Vec) { + self.full_path_markers + .write() + .unwrap() + .entry(package_id) + .and_modify(|markers| markers.extend(markers.clone())) + .or_insert(markers); + } + + fn calculate_metadata(db: &dyn SyntaxGroup, item_ast: ast::ModuleItem) -> TokenStreamMetadata { + let stable_ptr = item_ast.clone().stable_ptr().untyped(); + let file_path = stable_ptr.file_id(db).full_path(db.upcast()); + let file_id = short_hash(file_path.clone()); + TokenStreamMetadata::new(file_path, file_id) + } } impl MacroPlugin for ProcMacroHostPlugin { @@ -337,76 +569,28 @@ impl MacroPlugin for ProcMacroHostPlugin { item_ast: ast::ModuleItem, _metadata: &MacroPluginMetadata<'_>, ) -> PluginResult { - // Apply expansion to `item_ast` where needed. - let expansions = self - .handle_attribute(db, item_ast.clone()) - .into_iter() - .chain(self.handle_derive(db, item_ast.clone())); - let stable_ptr = item_ast.clone().stable_ptr().untyped(); - let file_path = stable_ptr.file_id(db).full_path(db.upcast()); - let file_id = short_hash(file_path.clone()); - - let mut token_stream = TokenStream::from_item_ast(db, item_ast) - .with_metadata(TokenStreamMetadata::new(file_path, file_id)); - let mut aux_data = EmittedAuxData::default(); - let mut modified = false; - let mut all_diagnostics: Vec = Vec::new(); - for input in expansions { - let result = self - .instance(input.package_id) - .generate_code(input.expansion.name.clone(), token_stream.clone()); + let stream_metadata = Self::calculate_metadata(db, item_ast.clone()); + + // Expand first attribute. + // Note that we only expand the first attribute, as we assume that the rest of the attributes + // will be handled by a subsequent call to this function. + if let (Some(input), token_stream) = self.parse_attribute(db, item_ast.clone()) { + let token_stream = token_stream.with_metadata(stream_metadata.clone()); + let stable_ptr = item_ast.clone().stable_ptr().untyped(); + return self.expand_attribute(input, token_stream, stable_ptr); + } - // Handle diagnostics. - all_diagnostics.extend(result.diagnostics); - // Handle aux data. - if let Some(new_aux_data) = result.aux_data { - aux_data.push(ProcMacroAuxData::new( - new_aux_data.into(), - ProcMacroId::new(input.package_id, input.expansion.clone()), - )); - } - // Handle token stream. - let new_token_stream = result.token_stream.clone(); - if new_token_stream.is_empty() { - // Remove original code - return PluginResult { - diagnostics: into_cairo_diagnostics(all_diagnostics, stable_ptr), - code: None, - remove_original_item: true, - }; - } - // Replace original code. - modified = new_token_stream.to_string() != token_stream.to_string(); - token_stream = new_token_stream; - // Full path markers require code modification. - self.full_path_markers - .write() - .unwrap() - .entry(input.package_id) - .and_modify(|markers| markers.extend(result.full_path_markers.clone().into_iter())) - .or_insert(result.full_path_markers); + // Expand all derives. + // Note that all proc macro attributes should be already expanded at this point. + if let Some(result) = self.expand_derives(db, item_ast.clone(), stream_metadata.clone()) { + return result; } - if modified { - PluginResult { - code: Some(PluginGeneratedFile { - name: "proc_macro".into(), - content: token_stream.to_string(), - code_mappings: Default::default(), - aux_data: if aux_data.is_empty() { - None - } else { - Some(DynGeneratedFileAuxData::new(aux_data)) - }, - }), - diagnostics: into_cairo_diagnostics(all_diagnostics, stable_ptr), - remove_original_item: true, - } - } else { - PluginResult { - code: None, - diagnostics: into_cairo_diagnostics(all_diagnostics, stable_ptr), - remove_original_item: false, - } + + // No expansions can be applied. + PluginResult { + code: None, + diagnostics: Vec::new(), + remove_original_item: false, } } @@ -529,3 +713,35 @@ impl ProcMacroHost { ProcMacroHostPlugin::try_new(self.macros) } } + +fn camel_to_snake(name: String) -> String { + let mut result = String::with_capacity(name.len()); + for (i, c) in name.chars().enumerate() { + if c.is_uppercase() { + if i > 0 { + result.push('_'); + } + result.push(c.to_ascii_lowercase()); + } else { + result.push(c); + } + } + result +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_camel_to_snake() { + assert_eq!(camel_to_snake("CamelCase".to_string()), "camel_case"); + assert_eq!(camel_to_snake("Camel".to_string()), "camel"); + assert_eq!(camel_to_snake("camel".to_string()), "camel"); + assert_eq!(camel_to_snake("CAMEL".to_string()), "c_a_m_e_l"); + assert_eq!( + camel_to_snake("CamelCaseCase".to_string()), + "camel_case_case" + ); + } +} diff --git a/scarb/tests/build_cairo_plugin.rs b/scarb/tests/build_cairo_plugin.rs index 07eacc861..5bac4e43b 100644 --- a/scarb/tests/build_cairo_plugin.rs +++ b/scarb/tests/build_cairo_plugin.rs @@ -845,17 +845,14 @@ fn can_resolve_full_path_markers() { #[attribute_macro] pub fn some(token_stream: TokenStream) -> ProcMacroResult { - let token_stream = TokenStream::new( - token_stream - .to_string() - // Remove macro call to avoid infinite loop. - .replace("#[some]", r#"#[macro::full_path_marker("some-key")]"#) - .replace("12", "34") - ); - let full_path_markers = vec!["some-key".to_string()]; - ProcMacroResult::new(token_stream) + let code = format!( + r#"#[macro::full_path_marker("some-key")] {}"#, + token_stream.to_string().replace("12", "34") + ); + + ProcMacroResult::new(TokenStream::new(code)) .with_full_path_markers(full_path_markers) } @@ -978,3 +975,156 @@ fn empty_inline_macro_result() { error: could not compile `hello` due to previous error "#}); } + +#[test] +fn can_implement_derive_macro() { + let temp = TempDir::new().unwrap(); + let t = temp.child("some"); + CairoPluginProjectBuilder::default() + .lib_rs(indoc! {r##" + use cairo_lang_macro::{derive_macro, ProcMacroResult, TokenStream}; + + #[derive_macro] + pub fn custom_derive(token_stream: TokenStream) -> ProcMacroResult { + let name = token_stream + .clone() + .to_string() + .lines() + .find(|l| l.starts_with("struct")) + .unwrap() + .to_string() + .replace("struct", "") + .replace("}", "") + .replace("{", "") + .trim() + .to_string(); + + let token_stream = TokenStream::new(indoc::formatdoc!{r#" + impl SomeImpl of Hello<{name}> {{ + fn world(self: @{name}) -> u32 {{ + 32 + }} + }} + "#}); + + ProcMacroResult::new(token_stream) + } + "##}) + .add_dep(r#"indoc = "*""#) + .build(&t); + + let project = temp.child("hello"); + ProjectBuilder::start() + .name("hello") + .version("1.0.0") + .dep("some", &t) + .lib_cairo(indoc! {r#" + trait Hello { + fn world(self: @T) -> u32; + } + + #[derive(CustomDerive, Drop)] + struct SomeType {} + + fn main() -> u32 { + let a = SomeType {}; + a.world() + } + "#}) + .build(&project); + + Scarb::quick_snapbox() + .arg("cairo-run") + // Disable output from Cargo. + .env("CARGO_TERM_QUIET", "true") + .current_dir(&project) + .assert() + .success() + .stdout_matches(indoc! {r#" + [..] Compiling some v1.0.0 ([..]Scarb.toml) + [..] Compiling hello v1.0.0 ([..]Scarb.toml) + [..]Finished release target(s) in [..] + [..]Running hello + Run completed successfully, returning [32] + "#}); +} + +#[test] +fn can_use_both_derive_and_attr() { + let temp = TempDir::new().unwrap(); + let t = temp.child("some"); + CairoPluginProjectBuilder::default() + .lib_rs(indoc! {r##" + use cairo_lang_macro::{derive_macro, attribute_macro, ProcMacroResult, TokenStream}; + + #[attribute_macro] + pub fn first_attribute(token_stream: TokenStream) -> ProcMacroResult { + ProcMacroResult::new(TokenStream::new( + token_stream.to_string() + .replace("SomeType", "OtherType") + )) + } + + #[attribute_macro] + pub fn second_attribute(token_stream: TokenStream) -> ProcMacroResult { + let token_stream = TokenStream::new( + token_stream.to_string().replace("OtherType", "RenamedStruct") + ); + ProcMacroResult::new(TokenStream::new( + format!("#[derive(Drop)]\n{token_stream}") + )) + } + + #[derive_macro] + pub fn custom_derive(_token_stream: TokenStream) -> ProcMacroResult { + ProcMacroResult::new(TokenStream::new( + indoc::formatdoc!{r#" + impl SomeImpl of Hello {{ + fn world(self: @RenamedStruct) -> u32 {{ + 32 + }} + }} + "#} + )) + } + "##}) + .add_dep(r#"indoc = "*""#) + .build(&t); + + let project = temp.child("hello"); + ProjectBuilder::start() + .name("hello") + .version("1.0.0") + .dep("some", &t) + .lib_cairo(indoc! {r#" + trait Hello { + fn world(self: @T) -> u32; + } + + #[first_attribute] + #[derive(CustomDerive)] + #[second_attribute] + struct SomeType {} + + fn main() -> u32 { + let a = RenamedStruct {}; + a.world() + } + "#}) + .build(&project); + + Scarb::quick_snapbox() + .arg("cairo-run") + // Disable output from Cargo. + .env("CARGO_TERM_QUIET", "true") + .current_dir(&project) + .assert() + .success() + .stdout_matches(indoc! {r#" + [..] Compiling some v1.0.0 ([..]Scarb.toml) + [..] Compiling hello v1.0.0 ([..]Scarb.toml) + [..]Finished release target(s) in [..] + [..]Running hello + Run completed successfully, returning [32] + "#}); +}