diff --git a/components/salsa-macro-rules/src/setup_tracked_fn.rs b/components/salsa-macro-rules/src/setup_tracked_fn.rs index 37cf48e46..6d1cd07af 100644 --- a/components/salsa-macro-rules/src/setup_tracked_fn.rs +++ b/components/salsa-macro-rules/src/setup_tracked_fn.rs @@ -55,6 +55,9 @@ macro_rules! setup_tracked_fn { // If true, the input needs an interner (because it has >1 argument). needs_interner: $needs_interner:tt, + // The function used to implement `C::heap_size`. + heap_size_fn: $($heap_size_fn:path)?, + // LRU capacity (a literal, maybe 0) lru: $lru:tt, @@ -196,6 +199,12 @@ macro_rules! setup_tracked_fn { $($values_equal)+ + $( + fn heap_size(value: &Self::Output<'_>) -> usize { + $heap_size_fn(value) + } + )? + fn execute<$db_lt>($db: &$db_lt Self::DbView, ($($input_id),*): ($($interned_input_ty),*)) -> Self::Output<$db_lt> { $($assert_return_type_is_update)* diff --git a/components/salsa-macros/src/accumulator.rs b/components/salsa-macros/src/accumulator.rs index d1890c04a..531ba83f4 100644 --- a/components/salsa-macros/src/accumulator.rs +++ b/components/salsa-macros/src/accumulator.rs @@ -45,6 +45,7 @@ impl AllowedOptions for Accumulator { const CONSTRUCTOR_NAME: bool = false; const ID: bool = false; const REVISIONS: bool = false; + const HEAP_SIZE: bool = false; } struct StructMacro { diff --git a/components/salsa-macros/src/input.rs b/components/salsa-macros/src/input.rs index 6b948d437..17df4267c 100644 --- a/components/salsa-macros/src/input.rs +++ b/components/salsa-macros/src/input.rs @@ -64,6 +64,8 @@ impl crate::options::AllowedOptions for InputStruct { const ID: bool = false; const REVISIONS: bool = false; + + const HEAP_SIZE: bool = false; } impl SalsaStructAllowedOptions for InputStruct { diff --git a/components/salsa-macros/src/interned.rs b/components/salsa-macros/src/interned.rs index 38287128c..27229167e 100644 --- a/components/salsa-macros/src/interned.rs +++ b/components/salsa-macros/src/interned.rs @@ -64,6 +64,8 @@ impl crate::options::AllowedOptions for InternedStruct { const ID: bool = true; const REVISIONS: bool = true; + + const HEAP_SIZE: bool = false; } impl SalsaStructAllowedOptions for InternedStruct { diff --git a/components/salsa-macros/src/options.rs b/components/salsa-macros/src/options.rs index c69be9791..e26c0163f 100644 --- a/components/salsa-macros/src/options.rs +++ b/components/salsa-macros/src/options.rs @@ -99,6 +99,12 @@ pub(crate) struct Options { /// This is stored as a `syn::Expr` to support `usize::MAX`. pub revisions: Option, + /// The `heap_size = ` option can be used to track heap memory usage of memoized + /// values. + /// + /// If this is `Some`, the value is the provided `heap_size` function. + pub heap_size_fn: Option, + /// Remember the `A` parameter, which plays no role after parsing. phantom: PhantomData, } @@ -123,6 +129,7 @@ impl Default for Options { singleton: Default::default(), id: Default::default(), revisions: Default::default(), + heap_size_fn: Default::default(), } } } @@ -145,6 +152,7 @@ pub(crate) trait AllowedOptions { const CONSTRUCTOR_NAME: bool; const ID: bool; const REVISIONS: bool; + const HEAP_SIZE: bool; } type Equals = syn::Token![=]; @@ -392,6 +400,22 @@ impl syn::parse::Parse for Options { "`revisions` option not allowed here", )); } + } else if ident == "heap_size" { + if A::HEAP_SIZE { + let _eq = Equals::parse(input)?; + let path = syn::Path::parse(input)?; + if let Some(old) = options.heap_size_fn.replace(path) { + return Err(syn::Error::new( + old.span(), + "option `heap_size` provided twice", + )); + } + } else { + return Err(syn::Error::new( + ident.span(), + "`heap_size` option not allowed here", + )); + } } else { return Err(syn::Error::new( ident.span(), diff --git a/components/salsa-macros/src/tracked_fn.rs b/components/salsa-macros/src/tracked_fn.rs index ad7b869a3..7d0380fa2 100644 --- a/components/salsa-macros/src/tracked_fn.rs +++ b/components/salsa-macros/src/tracked_fn.rs @@ -57,6 +57,8 @@ impl crate::options::AllowedOptions for TrackedFn { const ID: bool = false; const REVISIONS: bool = false; + + const HEAP_SIZE: bool = true; } struct Macro { @@ -97,6 +99,7 @@ impl Macro { self.cycle_recovery()?; let is_specifiable = self.args.specify.is_some(); let requires_update = self.args.non_update_return_type.is_none(); + let heap_size_fn = self.args.heap_size_fn.iter(); let eq = if let Some(token) = &self.args.no_eq { if self.args.cycle_fn.is_some() { return Err(syn::Error::new_spanned( @@ -217,6 +220,7 @@ impl Macro { is_specifiable: #is_specifiable, values_equal: {#eq}, needs_interner: #needs_interner, + heap_size_fn: #(#heap_size_fn)*, lru: #lru, return_mode: #return_mode, assert_return_type_is_update: { #assert_return_type_is_update }, diff --git a/components/salsa-macros/src/tracked_struct.rs b/components/salsa-macros/src/tracked_struct.rs index 999082f43..c3d2b9211 100644 --- a/components/salsa-macros/src/tracked_struct.rs +++ b/components/salsa-macros/src/tracked_struct.rs @@ -60,6 +60,8 @@ impl crate::options::AllowedOptions for TrackedStruct { const ID: bool = false; const REVISIONS: bool = false; + + const HEAP_SIZE: bool = false; } impl SalsaStructAllowedOptions for TrackedStruct { diff --git a/src/database.rs b/src/database.rs index a92b57913..594deb0a1 100644 --- a/src/database.rs +++ b/src/database.rs @@ -135,7 +135,10 @@ impl dyn Database { } #[cfg(feature = "salsa_unstable")] -pub use memory_usage::{IngredientInfo, SlotInfo}; +pub use memory_usage::IngredientInfo; + +#[cfg(feature = "salsa_unstable")] +pub(crate) use memory_usage::{MemoInfo, SlotInfo}; #[cfg(feature = "salsa_unstable")] mod memory_usage { @@ -171,8 +174,8 @@ mod memory_usage { /// Returns information about any memoized Salsa queries. /// /// The returned map holds memory usage information for memoized values of a given query, keyed - /// by its `(input, output)` type names. - pub fn queries_info(&self) -> HashMap<(&'static str, &'static str), IngredientInfo> { + /// by the query function name. + pub fn queries_info(&self) -> HashMap<&'static str, IngredientInfo> { let mut queries = HashMap::new(); for input_ingredient in self.zalsa().ingredients() { @@ -181,17 +184,15 @@ mod memory_usage { }; for input in input_info { - for output in input.memos { - let info = queries - .entry((input.debug_name, output.debug_name)) - .or_insert(IngredientInfo { - debug_name: output.debug_name, - ..Default::default() - }); + for memo in input.memos { + let info = queries.entry(memo.debug_name).or_insert(IngredientInfo { + debug_name: memo.output.debug_name, + ..Default::default() + }); info.count += 1; - info.size_of_fields += output.size_of_fields; - info.size_of_metadata += output.size_of_metadata; + info.size_of_fields += memo.output.size_of_fields; + info.size_of_metadata += memo.output.size_of_metadata; } } } @@ -236,6 +237,12 @@ mod memory_usage { pub(crate) debug_name: &'static str, pub(crate) size_of_metadata: usize, pub(crate) size_of_fields: usize, - pub(crate) memos: Vec, + pub(crate) memos: Vec, + } + + /// Memory usage information about a particular memo. + pub struct MemoInfo { + pub(crate) debug_name: &'static str, + pub(crate) output: SlotInfo, } } diff --git a/src/function.rs b/src/function.rs index be88f7e39..76b2abf7d 100644 --- a/src/function.rs +++ b/src/function.rs @@ -36,7 +36,7 @@ mod memo; mod specify; mod sync; -pub type Memo = memo::Memo<::Output<'static>>; +pub type Memo = memo::Memo<'static, C>; pub trait Configuration: Any { const DEBUG_NAME: &'static str; @@ -72,6 +72,11 @@ pub trait Configuration: Any { /// This is a no-op if the input to the function is a salsa struct. fn id_to_input(db: &Self::DbView, key: Id) -> Self::Input<'_>; + /// Returns the size of any heap allocations in the output value, in bytes. + fn heap_size(_value: &Self::Output<'_>) -> usize { + 0 + } + /// Invoked when we need to compute the value for the given key, either because we've never /// computed it before or because the old one relied on inputs that have changed. /// @@ -181,8 +186,8 @@ where /// only cleared with `&mut self`. unsafe fn extend_memo_lifetime<'this>( &'this self, - memo: &memo::Memo>, - ) -> &'this memo::Memo> { + memo: &memo::Memo<'this, C>, + ) -> &'this memo::Memo<'this, C> { // SAFETY: the caller must guarantee that the memo will not be released before `&self` unsafe { std::mem::transmute(memo) } } @@ -191,9 +196,9 @@ where &'db self, zalsa: &'db Zalsa, id: Id, - mut memo: memo::Memo>, + mut memo: memo::Memo<'db, C>, memo_ingredient_index: MemoIngredientIndex, - ) -> &'db memo::Memo> { + ) -> &'db memo::Memo<'db, C> { if let Some(tracked_struct_ids) = memo.revisions.tracked_struct_ids_mut() { tracked_struct_ids.shrink_to_fit(); } diff --git a/src/function/backdate.rs b/src/function/backdate.rs index 5ec652c7b..873041597 100644 --- a/src/function/backdate.rs +++ b/src/function/backdate.rs @@ -12,7 +12,7 @@ where /// on an old memo when a new memo has been produced to check whether there have been changed. pub(super) fn backdate_if_appropriate<'db>( &self, - old_memo: &Memo>, + old_memo: &Memo<'db, C>, index: DatabaseKeyIndex, revisions: &mut QueryRevisions, value: &C::Output<'db>, diff --git a/src/function/delete.rs b/src/function/delete.rs index 77b5c0564..d061917b0 100644 --- a/src/function/delete.rs +++ b/src/function/delete.rs @@ -7,7 +7,7 @@ use crate::function::Configuration; /// once the next revision starts. See the comment on the field /// `deleted_entries` of [`FunctionIngredient`][] for more details. pub(super) struct DeletedEntries { - memos: boxcar::Vec>>>, + memos: boxcar::Vec>>, } #[allow(clippy::undocumented_unsafe_blocks)] // TODO(#697) document safety @@ -27,13 +27,10 @@ impl DeletedEntries { /// # Safety /// /// The memo must be valid and safe to free when the `DeletedEntries` list is cleared or dropped. - pub(super) unsafe fn push(&self, memo: NonNull>>) { + pub(super) unsafe fn push(&self, memo: NonNull>) { // Safety: The memo must be valid and safe to free when the `DeletedEntries` list is cleared or dropped. - let memo = unsafe { - std::mem::transmute::>>, NonNull>>>( - memo, - ) - }; + let memo = + unsafe { std::mem::transmute::>, NonNull>>(memo) }; self.memos.push(SharedBox(memo)); } diff --git a/src/function/diff_outputs.rs b/src/function/diff_outputs.rs index 74e4a7fbb..b1d17b75a 100644 --- a/src/function/diff_outputs.rs +++ b/src/function/diff_outputs.rs @@ -18,7 +18,7 @@ where &self, zalsa: &Zalsa, key: DatabaseKeyIndex, - old_memo: &Memo>, + old_memo: &Memo<'_, C>, revisions: &mut QueryRevisions, ) { let (QueryOriginRef::Derived(edges) | QueryOriginRef::DerivedUntracked(edges)) = diff --git a/src/function/execute.rs b/src/function/execute.rs index b2166d8cb..13ecca561 100644 --- a/src/function/execute.rs +++ b/src/function/execute.rs @@ -24,8 +24,8 @@ where &'db self, db: &'db C::DbView, active_query: ActiveQueryGuard<'db>, - opt_old_memo: Option<&Memo>>, - ) -> &'db Memo> { + opt_old_memo: Option<&Memo<'db, C>>, + ) -> &'db Memo<'db, C> { let database_key_index = active_query.database_key_index; let id = database_key_index.key_index(); @@ -121,7 +121,7 @@ where &'db self, db: &'db C::DbView, mut active_query: ActiveQueryGuard<'db>, - opt_old_memo: Option<&Memo>>, + opt_old_memo: Option<&Memo<'db, C>>, zalsa: &'db Zalsa, id: Id, memo_ingredient_index: MemoIngredientIndex, @@ -133,7 +133,7 @@ where // Our provisional value from the previous iteration, when doing fixpoint iteration. // Initially it's set to None, because the initial provisional value is created lazily, // only when a cycle is actually encountered. - let mut opt_last_provisional: Option<&Memo<::Output<'db>>> = None; + let mut opt_last_provisional: Option<&Memo<'db, C>> = None; loop { let previous_memo = opt_last_provisional.or(opt_old_memo); let (mut new_value, mut revisions) = Self::execute_query( @@ -257,7 +257,7 @@ where fn execute_query<'db>( db: &'db C::DbView, active_query: ActiveQueryGuard<'db>, - opt_old_memo: Option<&Memo>>, + opt_old_memo: Option<&Memo<'db, C>>, current_revision: Revision, id: Id, ) -> (C::Output<'db>, QueryRevisions) { diff --git a/src/function/fetch.rs b/src/function/fetch.rs index bfd5ffedc..6c7819f81 100644 --- a/src/function/fetch.rs +++ b/src/function/fetch.rs @@ -43,7 +43,7 @@ where zalsa: &'db Zalsa, zalsa_local: &'db ZalsaLocal, id: Id, - ) -> &'db Memo> { + ) -> &'db Memo<'db, C> { let memo_ingredient_index = self.memo_ingredient_index(zalsa, id); loop { if let Some(memo) = self @@ -63,7 +63,7 @@ where zalsa: &'db Zalsa, id: Id, memo_ingredient_index: MemoIngredientIndex, - ) -> Option<&'db Memo>> { + ) -> Option<&'db Memo<'db, C>> { let memo = self.get_memo_from_table_for(zalsa, id, memo_ingredient_index)?; memo.value.as_ref()?; @@ -91,7 +91,7 @@ where db: &'db C::DbView, id: Id, memo_ingredient_index: MemoIngredientIndex, - ) -> Option<&'db Memo>> { + ) -> Option<&'db Memo<'db, C>> { let memo = self.fetch_cold(zalsa, zalsa_local, db, id, memo_ingredient_index)?; // If we get back a provisional cycle memo, and it's provisional on any cycle heads @@ -117,7 +117,7 @@ where db: &'db C::DbView, id: Id, memo_ingredient_index: MemoIngredientIndex, - ) -> Option<&'db Memo>> { + ) -> Option<&'db Memo<'db, C>> { let database_key_index = self.database_key_index(id); // Try to claim this query: if someone else has claimed it already, go back and start again. let claim_guard = match self.sync_table.try_claim(zalsa, id) { diff --git a/src/function/maybe_changed_after.rs b/src/function/maybe_changed_after.rs index 96f5eae5f..9d0ca4c44 100644 --- a/src/function/maybe_changed_after.rs +++ b/src/function/maybe_changed_after.rs @@ -192,7 +192,7 @@ where &self, zalsa: &Zalsa, database_key_index: DatabaseKeyIndex, - memo: &Memo>, + memo: &Memo<'_, C>, ) -> ShallowUpdate { tracing::debug!( "{database_key_index:?}: shallow_verify_memo(memo = {memo:#?})", @@ -227,7 +227,7 @@ where &self, zalsa: &Zalsa, database_key_index: DatabaseKeyIndex, - memo: &Memo>, + memo: &Memo<'_, C>, update: ShallowUpdate, ) { if let ShallowUpdate::HigherDurability = update { @@ -247,7 +247,7 @@ where zalsa: &Zalsa, zalsa_local: &ZalsaLocal, database_key_index: DatabaseKeyIndex, - memo: &Memo>, + memo: &Memo<'_, C>, ) -> bool { !memo.may_be_provisional() || self.validate_provisional(zalsa, database_key_index, memo) @@ -261,7 +261,7 @@ where &self, zalsa: &Zalsa, database_key_index: DatabaseKeyIndex, - memo: &Memo>, + memo: &Memo<'_, C>, ) -> bool { tracing::trace!( "{database_key_index:?}: validate_provisional(memo = {memo:#?})", @@ -322,7 +322,7 @@ where zalsa: &Zalsa, zalsa_local: &ZalsaLocal, database_key_index: DatabaseKeyIndex, - memo: &Memo>, + memo: &Memo<'_, C>, ) -> bool { tracing::trace!( "{database_key_index:?}: validate_same_iteration(memo = {memo:#?})", @@ -373,7 +373,7 @@ where &self, db: &C::DbView, zalsa: &Zalsa, - old_memo: &Memo>, + old_memo: &Memo<'_, C>, database_key_index: DatabaseKeyIndex, cycle_heads: &mut CycleHeads, ) -> VerifyResult { diff --git a/src/function/memo.rs b/src/function/memo.rs index 77efe8d6b..8f8952e5b 100644 --- a/src/function/memo.rs +++ b/src/function/memo.rs @@ -22,23 +22,20 @@ impl IngredientImpl { &self, zalsa: &'db Zalsa, id: Id, - memo: NonNull>>, + memo: NonNull>, memo_ingredient_index: MemoIngredientIndex, - ) -> Option>>> { + ) -> Option>> { // SAFETY: The table stores 'static memos (to support `Any`), the memos are in fact valid // for `'db` though as we delay their dropping to the end of a revision. - let static_memo = unsafe { - transmute::>>, NonNull>>>(memo) - }; + let static_memo = + unsafe { transmute::>, NonNull>>(memo) }; let old_static_memo = zalsa .memo_table_for(id) .insert(memo_ingredient_index, static_memo)?; // SAFETY: The table stores 'static memos (to support `Any`), the memos are in fact valid // for `'db` though as we delay their dropping to the end of a revision. Some(unsafe { - transmute::>>, NonNull>>>( - old_static_memo, - ) + transmute::>, NonNull>>(old_static_memo) }) } @@ -50,13 +47,11 @@ impl IngredientImpl { zalsa: &'db Zalsa, id: Id, memo_ingredient_index: MemoIngredientIndex, - ) -> Option<&'db Memo>> { + ) -> Option<&'db Memo<'db, C>> { let static_memo = zalsa.memo_table_for(id).get(memo_ingredient_index)?; // SAFETY: The table stores 'static memos (to support `Any`), the memos are in fact valid // for `'db` though as we delay their dropping to the end of a revision. - Some(unsafe { - transmute::<&Memo>, &'db Memo>>(static_memo.as_ref()) - }) + Some(unsafe { transmute::<&Memo<'static, C>, &'db Memo<'db, C>>(static_memo.as_ref()) }) } /// Evicts the existing memo for the given key, replacing it @@ -66,7 +61,7 @@ impl IngredientImpl { table: MemoTableWithTypesMut<'_>, memo_ingredient_index: MemoIngredientIndex, ) { - let map = |memo: &mut Memo>| { + let map = |memo: &mut Memo<'static, C>| { match memo.revisions.origin.as_ref() { QueryOriginRef::Assigned(_) | QueryOriginRef::DerivedUntracked(_) @@ -88,9 +83,9 @@ impl IngredientImpl { } #[derive(Debug)] -pub struct Memo { +pub struct Memo<'db, C: Configuration> { /// The result of the query, if we decide to memoize it. - pub(super) value: Option, + pub(super) value: Option>, /// Last revision when this memo was verified; this begins /// as the current revision. @@ -100,14 +95,12 @@ pub struct Memo { pub(super) revisions: QueryRevisions, } -// Memo's are stored a lot, make sure their size is doesn't randomly increase. -#[cfg(not(feature = "shuttle"))] -#[cfg(target_pointer_width = "64")] -const _: [(); std::mem::size_of::>()] = - [(); std::mem::size_of::<[usize; 6]>()]; - -impl Memo { - pub(super) fn new(value: Option, revision_now: Revision, revisions: QueryRevisions) -> Self { +impl<'db, C: Configuration> Memo<'db, C> { + pub(super) fn new( + value: Option>, + revision_now: Revision, + revisions: QueryRevisions, + ) -> Self { debug_assert!( !revisions.verified_final.load(Ordering::Relaxed) || revisions.cycle_heads().is_empty(), "Memo must be finalized if it has no cycle heads" @@ -286,12 +279,12 @@ impl Memo { } } - pub(super) fn tracing_debug(&self) -> impl std::fmt::Debug + use<'_, V> { - struct TracingDebug<'a, T> { - memo: &'a Memo, + pub(super) fn tracing_debug(&self) -> impl std::fmt::Debug + use<'_, 'db, C> { + struct TracingDebug<'memo, 'db, C: Configuration> { + memo: &'memo Memo<'db, C>, } - impl std::fmt::Debug for TracingDebug<'_, T> { + impl std::fmt::Debug for TracingDebug<'_, '_, C> { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("Memo") .field( @@ -312,20 +305,27 @@ impl Memo { } } -impl crate::table::memo::Memo for Memo { +impl crate::table::memo::Memo for Memo<'static, C> +where + C::Output<'static>: Send + Sync + Any, +{ fn origin(&self) -> QueryOriginRef<'_> { self.revisions.origin.as_ref() } #[cfg(feature = "salsa_unstable")] - fn memory_usage(&self) -> crate::SlotInfo { - let size_of = std::mem::size_of::>() + self.revisions.allocation_size(); - - crate::SlotInfo { - size_of_metadata: size_of - std::mem::size_of::(), - debug_name: std::any::type_name::(), - size_of_fields: std::mem::size_of::(), - memos: Vec::new(), + fn memory_usage(&self) -> crate::database::MemoInfo { + let size_of = std::mem::size_of::>() + self.revisions.allocation_size(); + let heap_size = self.value.as_ref().map(C::heap_size).unwrap_or(0); + + crate::database::MemoInfo { + debug_name: C::DEBUG_NAME, + output: crate::database::SlotInfo { + size_of_metadata: size_of - std::mem::size_of::>(), + debug_name: std::any::type_name::>(), + size_of_fields: std::mem::size_of::>() + heap_size, + memos: Vec::new(), + }, } } } @@ -445,3 +445,70 @@ impl<'me> Iterator for TryClaimCycleHeadsIter<'me> { } } } + +#[cfg(all(not(feature = "shuttle"), target_pointer_width = "64"))] +mod _memory_usage { + use crate::cycle::CycleRecoveryStrategy; + use crate::ingredient::Location; + use crate::plumbing::{IngredientIndices, MemoIngredientSingletonIndex, SalsaStructInDb}; + use crate::zalsa::Zalsa; + use crate::{CycleRecoveryAction, Database, Id}; + + use std::any::TypeId; + use std::num::NonZeroUsize; + + // Memo's are stored a lot, make sure their size is doesn't randomly increase. + const _: [(); std::mem::size_of::>()] = + [(); std::mem::size_of::<[usize; 6]>()]; + + struct DummyStruct; + + impl SalsaStructInDb for DummyStruct { + type MemoIngredientMap = MemoIngredientSingletonIndex; + + fn lookup_or_create_ingredient_index(_: &Zalsa) -> IngredientIndices { + unimplemented!() + } + + fn cast(_: Id, _: TypeId) -> Option { + unimplemented!() + } + } + + struct DummyConfiguration; + + impl super::Configuration for DummyConfiguration { + const DEBUG_NAME: &'static str = ""; + const LOCATION: Location = Location { file: "", line: 0 }; + type DbView = dyn Database; + type SalsaStruct<'db> = DummyStruct; + type Input<'db> = (); + type Output<'db> = NonZeroUsize; + const CYCLE_STRATEGY: CycleRecoveryStrategy = CycleRecoveryStrategy::Panic; + + fn values_equal<'db>(_: &Self::Output<'db>, _: &Self::Output<'db>) -> bool { + unimplemented!() + } + + fn id_to_input(_: &Self::DbView, _: Id) -> Self::Input<'_> { + unimplemented!() + } + + fn execute<'db>(_: &'db Self::DbView, _: Self::Input<'db>) -> Self::Output<'db> { + unimplemented!() + } + + fn cycle_initial<'db>(_: &'db Self::DbView, _: Self::Input<'db>) -> Self::Output<'db> { + unimplemented!() + } + + fn recover_from_cycle<'db>( + _: &'db Self::DbView, + _: &Self::Output<'db>, + _: u32, + _: Self::Input<'db>, + ) -> CycleRecoveryAction> { + unimplemented!() + } + } +} diff --git a/src/ingredient.rs b/src/ingredient.rs index a5e233df9..ff4837694 100644 --- a/src/ingredient.rs +++ b/src/ingredient.rs @@ -185,7 +185,7 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync { /// Returns memory usage information about any instances of the ingredient, /// if applicable. #[cfg(feature = "salsa_unstable")] - fn memory_usage(&self, _db: &dyn Database) -> Option> { + fn memory_usage(&self, _db: &dyn Database) -> Option> { None } } diff --git a/src/input.rs b/src/input.rs index 814b6e9a4..fe72c7e16 100644 --- a/src/input.rs +++ b/src/input.rs @@ -244,7 +244,7 @@ impl Ingredient for IngredientImpl { /// Returns memory usage information about any inputs. #[cfg(feature = "salsa_unstable")] - fn memory_usage(&self, db: &dyn Database) -> Option> { + fn memory_usage(&self, db: &dyn Database) -> Option> { let memory_usage = self .entries(db) // SAFETY: The memo table belongs to a value that we allocated, so it @@ -303,11 +303,11 @@ where /// /// The `MemoTable` must belong to a `Value` of the correct type. #[cfg(feature = "salsa_unstable")] - unsafe fn memory_usage(&self, memo_table_types: &MemoTableTypes) -> crate::SlotInfo { + unsafe fn memory_usage(&self, memo_table_types: &MemoTableTypes) -> crate::database::SlotInfo { // SAFETY: The caller guarantees this is the correct types table. let memos = unsafe { memo_table_types.attach_memos(&self.memos) }; - crate::SlotInfo { + crate::database::SlotInfo { debug_name: C::DEBUG_NAME, size_of_metadata: std::mem::size_of::() - std::mem::size_of::(), size_of_fields: std::mem::size_of::(), diff --git a/src/interned.rs b/src/interned.rs index 18f0d56cb..2138494c0 100644 --- a/src/interned.rs +++ b/src/interned.rs @@ -198,14 +198,14 @@ where /// The `MemoTable` must belong to a `Value` of the correct type. Additionally, the /// lock must be held for the shard containing the value. #[cfg(all(not(feature = "shuttle"), feature = "salsa_unstable"))] - unsafe fn memory_usage(&self, memo_table_types: &MemoTableTypes) -> crate::SlotInfo { + unsafe fn memory_usage(&self, memo_table_types: &MemoTableTypes) -> crate::database::SlotInfo { // SAFETY: The caller guarantees we hold the lock for the shard containing the value, so we // have at-least read-only access to the value's memos. let memos = unsafe { &*self.memos.get() }; // SAFETY: The caller guarantees this is the correct types table. let memos = unsafe { memo_table_types.attach_memos(memos) }; - crate::SlotInfo { + crate::database::SlotInfo { debug_name: C::DEBUG_NAME, size_of_metadata: std::mem::size_of::() - std::mem::size_of::>(), size_of_fields: std::mem::size_of::>(), @@ -855,7 +855,7 @@ where /// Returns memory usage information about any interned values. #[cfg(all(not(feature = "shuttle"), feature = "salsa_unstable"))] - fn memory_usage(&self, db: &dyn Database) -> Option> { + fn memory_usage(&self, db: &dyn Database) -> Option> { use parking_lot::lock_api::RawMutex; for shard in self.shards.iter() { diff --git a/src/lib.rs b/src/lib.rs index bf12206fc..83e600771 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -40,7 +40,7 @@ pub use parallel::{join, par_map}; pub use salsa_macros::{accumulator, db, input, interned, tracked, Supertype, Update}; #[cfg(feature = "salsa_unstable")] -pub use self::database::{IngredientInfo, SlotInfo}; +pub use self::database::IngredientInfo; pub use self::accumulator::Accumulator; pub use self::active_query::Backtrace; diff --git a/src/table/memo.rs b/src/table/memo.rs index 821f7c4ee..2ee10134f 100644 --- a/src/table/memo.rs +++ b/src/table/memo.rs @@ -24,7 +24,7 @@ pub trait Memo: Any + Send + Sync { /// Returns memory usage information about the memoized value. #[cfg(feature = "salsa_unstable")] - fn memory_usage(&self) -> crate::SlotInfo; + fn memory_usage(&self) -> crate::database::MemoInfo; } /// Data for a memoized entry. @@ -112,12 +112,15 @@ impl Memo for DummyMemo { } #[cfg(feature = "salsa_unstable")] - fn memory_usage(&self) -> crate::SlotInfo { - crate::SlotInfo { + fn memory_usage(&self) -> crate::database::MemoInfo { + crate::database::MemoInfo { debug_name: "dummy", - size_of_metadata: 0, - size_of_fields: 0, - memos: Vec::new(), + output: crate::database::SlotInfo { + debug_name: "dummy", + size_of_metadata: 0, + size_of_fields: 0, + memos: Vec::new(), + }, } } } @@ -279,7 +282,7 @@ impl MemoTableWithTypes<'_> { } #[cfg(feature = "salsa_unstable")] - pub(crate) fn memory_usage(&self) -> Vec { + pub(crate) fn memory_usage(&self) -> Vec { let mut memory_usage = Vec::new(); let memos = self.memos.memos.read(); for (index, memo) in memos.iter().enumerate() { diff --git a/src/tracked_struct.rs b/src/tracked_struct.rs index 83817021d..ef7f9926f 100644 --- a/src/tracked_struct.rs +++ b/src/tracked_struct.rs @@ -855,7 +855,7 @@ where /// Returns memory usage information about any tracked structs. #[cfg(feature = "salsa_unstable")] - fn memory_usage(&self, db: &dyn Database) -> Option> { + fn memory_usage(&self, db: &dyn Database) -> Option> { let memory_usage = self .entries(db) // SAFETY: The memo table belongs to a value that we allocated, so it @@ -929,11 +929,11 @@ where /// /// The `MemoTable` must belong to a `Value` of the correct type. #[cfg(feature = "salsa_unstable")] - unsafe fn memory_usage(&self, memo_table_types: &MemoTableTypes) -> crate::SlotInfo { + unsafe fn memory_usage(&self, memo_table_types: &MemoTableTypes) -> crate::database::SlotInfo { // SAFETY: The caller guarantees this is the correct types table. let memos = unsafe { memo_table_types.attach_memos(&self.memos) }; - crate::SlotInfo { + crate::database::SlotInfo { debug_name: C::DEBUG_NAME, size_of_metadata: mem::size_of::() - mem::size_of::>(), size_of_fields: mem::size_of::>(), diff --git a/tests/compile-fail/accumulator_incompatibles.rs b/tests/compile-fail/accumulator_incompatibles.rs index b6be5deaa..35deb971e 100644 --- a/tests/compile-fail/accumulator_incompatibles.rs +++ b/tests/compile-fail/accumulator_incompatibles.rs @@ -25,4 +25,7 @@ struct AccWithRevisions(u32); #[salsa::accumulator(constructor = Constructor)] struct AccWithConstructor(u32); +#[salsa::accumulator(heap_size = size)] +struct AccWithHeapSize(u32); + fn main() {} diff --git a/tests/compile-fail/accumulator_incompatibles.stderr b/tests/compile-fail/accumulator_incompatibles.stderr index 1d336dc6e..6d885c487 100644 --- a/tests/compile-fail/accumulator_incompatibles.stderr +++ b/tests/compile-fail/accumulator_incompatibles.stderr @@ -51,3 +51,9 @@ error: `constructor` option not allowed here | 25 | #[salsa::accumulator(constructor = Constructor)] | ^^^^^^^^^^^ + +error: `heap_size` option not allowed here + --> tests/compile-fail/accumulator_incompatibles.rs:28:22 + | +28 | #[salsa::accumulator(heap_size = size)] + | ^^^^^^^^^ diff --git a/tests/compile-fail/input_struct_incompatibles.rs b/tests/compile-fail/input_struct_incompatibles.rs index 31ca9abb8..98cdb916d 100644 --- a/tests/compile-fail/input_struct_incompatibles.rs +++ b/tests/compile-fail/input_struct_incompatibles.rs @@ -25,4 +25,7 @@ struct InputWithTrackedField { field: u32, } +#[salsa::input(heap_size = size)] +struct InputWithHeapSize(u32); + fn main() {} diff --git a/tests/compile-fail/input_struct_incompatibles.stderr b/tests/compile-fail/input_struct_incompatibles.stderr index a1b94e9aa..9fe025275 100644 --- a/tests/compile-fail/input_struct_incompatibles.stderr +++ b/tests/compile-fail/input_struct_incompatibles.stderr @@ -47,6 +47,12 @@ error: `#[tracked]` cannot be used with `#[salsa::input]` 25 | | field: u32, | |______________^ +error: `heap_size` option not allowed here + --> tests/compile-fail/input_struct_incompatibles.rs:28:16 + | +28 | #[salsa::input(heap_size = size)] + | ^^^^^^^^^ + error: cannot find attribute `tracked` in this scope --> tests/compile-fail/input_struct_incompatibles.rs:24:7 | diff --git a/tests/compile-fail/interned_struct_incompatibles.rs b/tests/compile-fail/interned_struct_incompatibles.rs index 435335b18..b8d504282 100644 --- a/tests/compile-fail/interned_struct_incompatibles.rs +++ b/tests/compile-fail/interned_struct_incompatibles.rs @@ -39,4 +39,9 @@ struct InternedWithZeroRevisions { field: u32, } +#[salsa::interned(heap_size = size)] +struct AccWithHeapSize { + field: u32, +} + fn main() {} diff --git a/tests/compile-fail/interned_struct_incompatibles.stderr b/tests/compile-fail/interned_struct_incompatibles.stderr index 482e38b46..76ccc7f8b 100644 --- a/tests/compile-fail/interned_struct_incompatibles.stderr +++ b/tests/compile-fail/interned_struct_incompatibles.stderr @@ -41,6 +41,12 @@ error: `#[tracked]` cannot be used with `#[salsa::interned]` 34 | | field: u32, | |______________^ +error: `heap_size` option not allowed here + --> tests/compile-fail/interned_struct_incompatibles.rs:42:19 + | +42 | #[salsa::interned(heap_size = size)] + | ^^^^^^^^^ + error: cannot find attribute `tracked` in this scope --> tests/compile-fail/interned_struct_incompatibles.rs:33:7 | diff --git a/tests/compile-fail/tracked_struct_incompatibles.rs b/tests/compile-fail/tracked_struct_incompatibles.rs index 5abd62dcc..eff1eebd1 100644 --- a/tests/compile-fail/tracked_struct_incompatibles.rs +++ b/tests/compile-fail/tracked_struct_incompatibles.rs @@ -33,4 +33,9 @@ struct TrackedStructWithRevisions { field: u32, } +#[salsa::tracked(heap_size = size)] +struct TrackedStructWithHeapSize { + field: u32, +} + fn main() {} diff --git a/tests/compile-fail/tracked_struct_incompatibles.stderr b/tests/compile-fail/tracked_struct_incompatibles.stderr index 928bbb126..e27777ca0 100644 --- a/tests/compile-fail/tracked_struct_incompatibles.stderr +++ b/tests/compile-fail/tracked_struct_incompatibles.stderr @@ -39,3 +39,9 @@ error: `revisions` option not allowed here | 31 | #[salsa::tracked(revisions = 12)] | ^^^^^^^^^ + +error: `heap_size` option not allowed here + --> tests/compile-fail/tracked_struct_incompatibles.rs:36:18 + | +36 | #[salsa::tracked(heap_size = size)] + | ^^^^^^^^^ diff --git a/tests/memory-usage.rs b/tests/memory-usage.rs index c16a94643..a990ff6a3 100644 --- a/tests/memory-usage.rs +++ b/tests/memory-usage.rs @@ -25,6 +25,20 @@ fn input_to_tracked<'db>(db: &'db dyn salsa::Database, input: MyInput) -> MyTrac MyTracked::new(db, input.field(db)) } +#[salsa::tracked] +fn input_to_string<'db>(_db: &'db dyn salsa::Database) -> String { + "a".repeat(1000) +} + +#[salsa::tracked(heap_size = string_heap_size)] +fn input_to_string_get_size<'db>(_db: &'db dyn salsa::Database) -> String { + "a".repeat(1000) +} + +fn string_heap_size(x: &String) -> usize { + x.capacity() +} + #[salsa::tracked] fn input_to_tracked_tuple<'db>( db: &'db dyn salsa::Database, @@ -53,6 +67,9 @@ fn test() { let _interned2 = input_to_interned(&db, input2); let _interned3 = input_to_interned(&db, input3); + let _string1 = input_to_string(&db); + let _string2 = input_to_string_get_size(&db); + let structs_info = ::structs_info(&db); let expected = expect![[r#" @@ -75,6 +92,18 @@ fn test() { size_of_metadata: 156, size_of_fields: 12, }, + IngredientInfo { + debug_name: "input_to_string::interned_arguments", + count: 1, + size_of_metadata: 56, + size_of_fields: 0, + }, + IngredientInfo { + debug_name: "input_to_string_get_size::interned_arguments", + count: 1, + size_of_metadata: 56, + size_of_fields: 0, + }, ]"#]]; expected.assert_eq(&format!("{structs_info:#?}")); @@ -87,34 +116,34 @@ fn test() { let expected = expect![[r#" [ ( - ( - "MyInput", - "(memory_usage::MyTracked, memory_usage::MyTracked)", - ), + "input_to_interned", IngredientInfo { - debug_name: "(memory_usage::MyTracked, memory_usage::MyTracked)", - count: 1, - size_of_metadata: 132, - size_of_fields: 16, + debug_name: "memory_usage::MyInterned", + count: 3, + size_of_metadata: 192, + size_of_fields: 24, }, ), ( - ( - "MyInput", - "memory_usage::MyInterned", - ), + "input_to_string", IngredientInfo { - debug_name: "memory_usage::MyInterned", - count: 3, - size_of_metadata: 192, + debug_name: "alloc::string::String", + count: 1, + size_of_metadata: 40, size_of_fields: 24, }, ), ( - ( - "MyInput", - "memory_usage::MyTracked", - ), + "input_to_string_get_size", + IngredientInfo { + debug_name: "alloc::string::String", + count: 1, + size_of_metadata: 40, + size_of_fields: 1024, + }, + ), + ( + "input_to_tracked", IngredientInfo { debug_name: "memory_usage::MyTracked", count: 2, @@ -122,6 +151,15 @@ fn test() { size_of_fields: 16, }, ), + ( + "input_to_tracked_tuple", + IngredientInfo { + debug_name: "(memory_usage::MyTracked, memory_usage::MyTracked)", + count: 1, + size_of_metadata: 132, + size_of_fields: 16, + }, + ), ]"#]]; expected.assert_eq(&format!("{queries_info:#?}"));