diff --git a/scarb/src/compiler/db.rs b/scarb/src/compiler/db.rs index 06f0f0d79..5931e8cd1 100644 --- a/scarb/src/compiler/db.rs +++ b/scarb/src/compiler/db.rs @@ -11,33 +11,40 @@ use smol_str::SmolStr; use std::sync::Arc; use tracing::trace; -use crate::compiler::plugin::proc_macro::ProcMacroHost; +use crate::compiler::plugin::proc_macro::{ProcMacroHost, ProcMacroHostPlugin}; use crate::compiler::{CairoCompilationUnit, CompilationUnitAttributes, CompilationUnitComponent}; use crate::core::Workspace; use crate::DEFAULT_MODULE_MAIN_FILE; -// TODO(mkaput): ScarbDatabase? +pub struct ScarbDatabase { + pub db: RootDatabase, + pub proc_macro_host: Arc, +} + pub(crate) fn build_scarb_root_database( unit: &CairoCompilationUnit, ws: &Workspace<'_>, -) -> Result { +) -> Result { let mut b = RootDatabase::builder(); b.with_project_config(build_project_config(unit)?); b.with_cfg(unit.cfg_set.clone()); - load_plugins(unit, ws, &mut b)?; + let proc_macro_host = load_plugins(unit, ws, &mut b)?; if !unit.compiler_config.enable_gas { b.skip_auto_withdraw_gas(); } let mut db = b.build()?; inject_virtual_wrapper_lib(&mut db, unit)?; - Ok(db) + Ok(ScarbDatabase { + db, + proc_macro_host, + }) } fn load_plugins( unit: &CairoCompilationUnit, ws: &Workspace<'_>, builder: &mut RootDatabaseBuilder, -) -> Result<()> { +) -> Result> { let mut proc_macros = ProcMacroHost::default(); for plugin_info in &unit.cairo_plugins { if plugin_info.builtin { @@ -49,8 +56,9 @@ fn load_plugins( proc_macros.register(plugin_info.package.clone(), ws.config())?; } } - builder.with_plugin_suite(proc_macros.into_plugin_suite()); - Ok(()) + let macro_host = Arc::new(proc_macros.into_plugin()); + builder.with_plugin_suite(ProcMacroHostPlugin::build_plugin_suite(macro_host.clone())); + Ok(macro_host) } /// Generates a wrapper lib file for appropriate compilation units. diff --git a/scarb/src/compiler/plugin/proc_macro/host.rs b/scarb/src/compiler/plugin/proc_macro/host.rs index 4d3c36e9b..6e29f8615 100644 --- a/scarb/src/compiler/plugin/proc_macro/host.rs +++ b/scarb/src/compiler/plugin/proc_macro/host.rs @@ -1,6 +1,7 @@ use crate::compiler::plugin::proc_macro::{FromItemAst, ProcMacroInstance}; use crate::core::{Config, Package, PackageId}; use anyhow::Result; +use cairo_lang_defs::db::DefsGroup; use cairo_lang_defs::plugin::PluginDiagnostic; use cairo_lang_defs::plugin::{ DynGeneratedFileAuxData, GeneratedFileAuxData, MacroPlugin, MacroPluginMetadata, @@ -47,11 +48,25 @@ pub struct ProcMacroInput { } #[derive(Clone, Debug)] -pub struct ProcMacroAuxData(String); +pub struct ProcMacroAuxData { + value: String, + macro_id: ProcMacroId, + macro_package_id: PackageId, +} + +impl ProcMacroAuxData { + pub fn new(value: String, macro_id: ProcMacroId, macro_package_id: PackageId) -> Self { + Self { + value, + macro_id, + macro_package_id, + } + } +} impl From for AuxData { fn from(data: ProcMacroAuxData) -> Self { - Self::new(data.0) + Self::new(data.value) } } @@ -61,7 +76,8 @@ impl GeneratedFileAuxData for ProcMacroAuxData { } fn eq(&self, other: &dyn GeneratedFileAuxData) -> bool { - self.0 == other.as_any().downcast_ref::().unwrap().0 + self.value == other.as_any().downcast_ref::().unwrap().value + && self.macro_id == other.as_any().downcast_ref::().unwrap().macro_id } } @@ -128,6 +144,36 @@ impl ProcMacroHostPlugin { .find(|m| m.declared_attributes().contains(&name)) .map(|m| m.package_id()) } + + pub fn build_plugin_suite(macr_host: Arc) -> PluginSuite { + let mut suite = PluginSuite::default(); + suite.add_plugin_ex(macr_host); + suite + } + + #[tracing::instrument(level = "trace", skip_all)] + pub fn collect_aux_data(&self, db: &dyn DefsGroup) -> Result<()> { + let mut data = Vec::new(); + for crate_id in db.crates() { + let crate_modules = db.crate_modules(crate_id); + for module in crate_modules.iter() { + let file_infos = db.module_generated_file_infos(*module); + if let Ok(file_infos) = file_infos { + for file_info in file_infos.iter().flatten() { + let aux_data = file_info + .aux_data + .as_ref() + .and_then(|ad| ad.as_any().downcast_ref::()); + if let Some(aux_data) = aux_data { + data.push(aux_data.clone()); + } + } + } + } + } + let _aux_data = data.into_iter().into_group_map_by(|d| d.macro_package_id); + Ok(()) + } } impl MacroPlugin for ProcMacroHostPlugin { @@ -146,7 +192,7 @@ impl MacroPlugin for ProcMacroHostPlugin { let stable_ptr = item_ast.clone().stable_ptr().untyped(); let mut token_stream = TokenStream::from_item_ast(db, item_ast); - let mut aux_data: Option = None; + let mut aux_data: Option = None; let mut modified = false; let mut all_diagnostics: Vec = Vec::new(); for input in expansions { @@ -162,7 +208,13 @@ impl MacroPlugin for ProcMacroHostPlugin { diagnostics, } => { token_stream = new_token_stream; - aux_data = new_aux_data; + if let Some(new_aux_data) = new_aux_data { + aux_data = Some(ProcMacroAuxData::new( + new_aux_data.to_string(), + input.id, + input.macro_package_id, + )); + } modified = true; all_diagnostics.extend(diagnostics); } @@ -185,8 +237,7 @@ impl MacroPlugin for ProcMacroHostPlugin { name: "proc_macro".into(), content: token_stream.to_string(), code_mappings: Default::default(), - aux_data: aux_data - .map(|ad| DynGeneratedFileAuxData::new(ProcMacroAuxData(ad.to_string()))), + aux_data: aux_data.map(DynGeneratedFileAuxData::new), }), diagnostics: into_cairo_diagnostics(all_diagnostics, stable_ptr), remove_original_item: true, @@ -241,10 +292,7 @@ impl ProcMacroHost { Ok(()) } - pub fn into_plugin_suite(self) -> PluginSuite { - let macro_host = ProcMacroHostPlugin::new(self.macros); - let mut suite = PluginSuite::default(); - suite.add_plugin_ex(Arc::new(macro_host)); - suite + pub fn into_plugin(self) -> ProcMacroHostPlugin { + ProcMacroHostPlugin::new(self.macros) } } diff --git a/scarb/src/ops/compile.rs b/scarb/src/ops/compile.rs index d4fbf310d..eab526431 100644 --- a/scarb/src/ops/compile.rs +++ b/scarb/src/ops/compile.rs @@ -1,13 +1,14 @@ -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, Context, Result}; use cairo_lang_compiler::db::RootDatabase; use cairo_lang_compiler::diagnostics::DiagnosticsError; +use cairo_lang_utils::Upcast; use indoc::formatdoc; use itertools::Itertools; use scarb_ui::components::Status; use scarb_ui::HumanDuration; -use crate::compiler::db::{build_scarb_root_database, has_starknet_plugin}; +use crate::compiler::db::{build_scarb_root_database, has_starknet_plugin, ScarbDatabase}; use crate::compiler::helpers::build_compiler_config; use crate::compiler::plugin::proc_macro; use crate::compiler::{CairoCompilationUnit, CompilationUnit, CompilationUnitAttributes}; @@ -104,9 +105,16 @@ fn compile_unit(unit: CompilationUnit, ws: &Workspace<'_>) -> Result<()> { let result = match unit { CompilationUnit::ProcMacro(unit) => proc_macro::compile_unit(unit, ws), CompilationUnit::Cairo(unit) => { - let mut db = build_scarb_root_database(&unit, ws)?; + let ScarbDatabase { + mut db, + proc_macro_host, + } = build_scarb_root_database(&unit, ws)?; check_starknet_dependency(&unit, ws, &db, &package_name); - ws.config().compilers().compile(unit, &mut db, ws) + let result = ws.config().compilers().compile(unit, &mut db, ws); + proc_macro_host + .collect_aux_data(db.upcast()) + .context("procedural macro auxiliary data callback call failed")?; + result } }; @@ -126,28 +134,26 @@ fn check_unit(unit: CompilationUnit, ws: &Workspace<'_>) -> Result<()> { .ui() .print(Status::new("Checking", &unit.name())); - match unit { - CompilationUnit::ProcMacro(unit) => proc_macro::check_unit(unit, ws)?, + let result = match unit { + CompilationUnit::ProcMacro(unit) => proc_macro::check_unit(unit, ws), CompilationUnit::Cairo(unit) => { - let db = build_scarb_root_database(&unit, ws)?; - + let ScarbDatabase { db, .. } = build_scarb_root_database(&unit, ws)?; check_starknet_dependency(&unit, ws, &db, &package_name); - let mut compiler_config = build_compiler_config(&unit, ws); - compiler_config .diagnostics_reporter .ensure(&db) - .map_err(|err| { - let valid_error = err.into(); - if !suppress_error(&valid_error) { - ws.config().ui().anyhow(&valid_error); - } - - anyhow!("could not check `{package_name}` due to previous error") - })?; + .map_err(|err| err.into()) } - } + }; + + result.map_err(|err| { + if !suppress_error(&err) { + ws.config().ui().anyhow(&err); + } + + anyhow!("could not check `{package_name}` due to previous error") + })?; Ok(()) }