1 change: 1 addition & 0 deletions examples/calc/db.rs
@@ -48,6 +48,7 @@ impl CalcDatabaseImpl {
}

#[cfg(test)]
#[allow(unused)]
pub fn take_logs(&self) -> Vec<String> {
let mut logs = self.logs.lock().unwrap();
if let Some(logs) = &mut *logs {
13 changes: 9 additions & 4 deletions src/active_query.rs
@@ -5,15 +5,18 @@ use crate::accumulator::{
accumulated_map::{AccumulatedMap, AtomicInputAccumulatedValues, InputAccumulatedValues},
Accumulator,
};
use crate::cycle::{CycleHeads, IterationCount};
use crate::durability::Durability;
use crate::hash::FxIndexSet;
use crate::key::DatabaseKeyIndex;
use crate::runtime::Stamp;
use crate::sync::atomic::AtomicBool;
use crate::tracked_struct::{Disambiguator, DisambiguatorMap, IdentityHash, IdentityMap};
use crate::zalsa_local::{QueryEdge, QueryOrigin, QueryRevisions, QueryRevisionsExtra};
use crate::Revision;
use crate::{
cycle::{CycleHeads, IterationCount},
Id,
};
use crate::{durability::Durability, tracked_struct::Identity};

#[derive(Debug)]
pub(crate) struct ActiveQuery {
@@ -74,6 +77,7 @@ impl ActiveQuery {
changed_at: Revision,
edges: &[QueryEdge],
untracked_read: bool,
active_tracked_ids: &[(Identity, Id)],
) {
assert!(self.input_outputs.is_empty());

@@ -83,7 +87,8 @@
self.untracked_read |= untracked_read;

// Mark all tracked structs from the previous iteration as active.
self.tracked_struct_ids.mark_all_active();
self.tracked_struct_ids
.mark_all_active(active_tracked_ids.iter().copied());
}

pub(super) fn add_read(
@@ -408,7 +413,7 @@ pub(crate) struct CompletedQuery {

/// The identities and ids of any tracked structs that were created in a previous
/// execution of the query but not the current one, and that should be marked as stale.
pub(crate) stale_tracked_structs: Vec<DatabaseKeyIndex>,
pub(crate) stale_tracked_structs: Vec<(Identity, Id)>,
}

struct CapturedQuery {
5 changes: 3 additions & 2 deletions src/function/diff_outputs.rs
@@ -27,8 +27,9 @@ where

// Note that tracked structs are not stored as direct query outputs, but they are still outputs
// that need to be reported as stale.
for output in &completed_query.stale_tracked_structs {
Self::report_stale_output(zalsa, key, *output);
for (identity, id) in &completed_query.stale_tracked_structs {
let output = DatabaseKeyIndex::new(identity.ingredient_index(), *id);
Self::report_stale_output(zalsa, key, output);
}

let mut stale_outputs = output_edges(edges).collect::<FxIndexSet<_>>();
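Because `stale_tracked_structs` now carries `(Identity, Id)` pairs instead of ready-made keys, the reporting loop above rebuilds the `DatabaseKeyIndex` itself. A minimal, self-contained sketch of that conversion, using simplified stand-in types rather than salsa's internal ones:

```rust
// Stand-ins only: the real salsa types carry more data (hashes, disambiguators).
#[derive(Clone, Copy, Debug)]
struct IngredientIndex(u32);

#[derive(Clone, Copy, Debug)]
struct Id(u32);

#[derive(Clone, Copy, Debug)]
struct Identity {
    ingredient_index: IngredientIndex,
    // hash and disambiguator omitted in this sketch
}

impl Identity {
    fn ingredient_index(&self) -> IngredientIndex {
        self.ingredient_index
    }
}

#[derive(Clone, Copy, Debug)]
struct DatabaseKeyIndex {
    ingredient_index: IngredientIndex,
    key_index: Id,
}

impl DatabaseKeyIndex {
    fn new(ingredient_index: IngredientIndex, key_index: Id) -> Self {
        Self { ingredient_index, key_index }
    }
}

// Same shape as the loop in `diff_outputs` above: rebuild the key from the
// identity's ingredient index plus the struct's id before reporting it.
fn report_stale_tracked_structs(stale: &[(Identity, Id)]) {
    for (identity, id) in stale {
        let output = DatabaseKeyIndex::new(identity.ingredient_index(), *id);
        println!("stale output: {output:?}");
    }
}

fn main() {
    let node = Identity { ingredient_index: IngredientIndex(7) };
    report_stale_tracked_structs(&[(node, Id(402))]);
}
```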
22 changes: 16 additions & 6 deletions src/function/execute.rs
@@ -3,6 +3,7 @@ use crate::cycle::{CycleRecoveryStrategy, IterationCount};
use crate::function::memo::Memo;
use crate::function::{Configuration, IngredientImpl};
use crate::sync::atomic::{AtomicBool, Ordering};
use crate::tracked_struct::Identity;
use crate::zalsa::{MemoIngredientIndex, Zalsa, ZalsaDatabase};
use crate::zalsa_local::ActiveQueryGuard;
use crate::{Event, EventKind, Id};
@@ -134,13 +135,25 @@
let database_key_index = active_query.database_key_index;
let mut iteration_count = IterationCount::initial();
let mut fell_back = false;
let zalsa_local = db.zalsa_local();

// Our provisional value from the previous iteration, when doing fixpoint iteration.
// Initially it's set to None, because the initial provisional value is created lazily,
// only when a cycle is actually encountered.
let mut opt_last_provisional: Option<&Memo<'db, C>> = None;
let mut last_stale_tracked_ids: Vec<(Identity, Id)> = Vec::new();

loop {
let previous_memo = opt_last_provisional.or(opt_old_memo);

// Tracked struct ids that existed in the previous revision but weren't
// re-created in the last iteration. It's important to seed the next iteration
// with these ids: if the query re-creates the structs, they keep their previous
// ids, and if they are never re-created, they are correctly removed once the
// final iteration is reached.
active_query.seed_tracked_struct_ids(&last_stale_tracked_ids);

let (mut new_value, mut completed_query) =
Self::execute_query(db, zalsa, active_query, previous_memo, id);

@@ -239,10 +252,9 @@
),
memo_ingredient_index,
));
last_stale_tracked_ids = completed_query.stale_tracked_structs;

active_query = db
.zalsa_local()
.push_query(database_key_index, iteration_count);
active_query = zalsa_local.push_query(database_key_index, iteration_count);

continue;
}
@@ -280,9 +292,7 @@
if let Some(old_memo) = opt_old_memo {
// If we already executed this query once, then use the tracked-struct ids from the
// previous execution as the starting point for the new one.
if let Some(tracked_struct_ids) = old_memo.revisions.tracked_struct_ids() {
active_query.seed_tracked_struct_ids(tracked_struct_ids);
}
active_query.seed_tracked_struct_ids(old_memo.revisions.tracked_struct_ids());

// Copy over all inputs and outputs from a previous iteration.
// This is necessary to:
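To make the seeding above concrete, here is a small self-contained model of the id hand-off between fixpoint iterations. It is only a sketch under simplifying assumptions: `Identity` is modelled as a string key, ids are plain integers, and the toy map below mimics just the relevant behaviour of salsa's internal `IdentityMap` (seed as inactive, reuse-or-allocate on creation, split into active/stale afterwards); it is not the real implementation.

```rust
use std::collections::HashMap;

// Stand-ins: a tracked struct's `Identity` is modelled as a string key and
// its `Id` as a plain integer.
type Identity = &'static str;
type Id = u32;

// Toy identity map: which ids are known for which identities, and whether the
// current execution actually (re-)created them.
struct IdentityMap {
    entries: HashMap<Identity, (Id, bool)>, // (id, active)
    next_id: Id,
}

impl IdentityMap {
    fn new() -> Self {
        Self { entries: HashMap::new(), next_id: 0 }
    }

    // Like `seed_tracked_struct_ids`: known ids from a previous revision or
    // from the previous iteration's stale set, inserted as *inactive*.
    fn seed(&mut self, source: &[(Identity, Id)]) {
        for &(key, id) in source {
            self.entries.insert(key, (id, false));
        }
    }

    // Creating a tracked struct: reuse the seeded id if the identity matches,
    // otherwise allocate a fresh one; mark it active either way.
    fn create(&mut self, key: Identity) -> Id {
        if let Some((id, active)) = self.entries.get_mut(key) {
            *active = true;
            return *id;
        }
        let id = self.next_id;
        self.next_id += 1;
        self.entries.insert(key, (id, true));
        id
    }

    // Like `drain`: split into structs created by this execution and stale leftovers.
    fn drain(&mut self) -> (Vec<(Identity, Id)>, Vec<(Identity, Id)>) {
        let (mut active, mut stale) = (Vec::new(), Vec::new());
        for (key, (id, is_active)) in self.entries.drain() {
            if is_active {
                active.push((key, id));
            } else {
                stale.push((key, id));
            }
        }
        (active, stale)
    }
}

fn main() {
    // The previous revision created "a" (id 0) and "b" (id 1).
    let previous_revision = [("a", 0), ("b", 1)];

    let mut map = IdentityMap::new();
    map.next_id = 2; // fresh ids start after the ones handed out previously

    // Iteration 1: seeded from the old memo, but only re-creates "a".
    map.seed(&previous_revision);
    assert_eq!(map.create("a"), 0); // id reused
    let (created, stale) = map.drain();
    assert_eq!(created, vec![("a", 0)]);
    assert_eq!(stale, vec![("b", 1)]); // carried over, not discarded yet

    // Iteration 2: seeded with the previous iteration's *stale* ids (the
    // `last_stale_tracked_ids` hand-off above) and with the structs the
    // previous iteration created (`mark_all_active`).
    map.seed(&stale);
    for &(key, id) in &created {
        map.entries.insert(key, (id, true));
    }
    assert_eq!(map.create("b"), 1); // "b" reappears with its old id
    let (created, stale) = map.drain();
    assert_eq!(created.len(), 2);
    assert!(stale.is_empty()); // nothing stale once the fixpoint is reached
}
```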
2 changes: 1 addition & 1 deletion src/function/memo.rs
@@ -326,7 +326,7 @@ where
stale_output.remove_stale_output(zalsa, executor);
}

for (identity, id) in self.revisions.tracked_struct_ids().into_iter().flatten() {
for (identity, id) in self.revisions.tracked_struct_ids() {
let key = DatabaseKeyIndex::new(identity.ingredient_index(), *id);
key.remove_stale_output(zalsa, executor);
}
22 changes: 7 additions & 15 deletions src/tracked_struct.rs
@@ -255,19 +255,15 @@ pub(crate) struct IdentityMap {
impl IdentityMap {
/// Seeds the identity map with the IDs from a previous revision.
pub(crate) fn seed(&mut self, source: &[(Identity, Id)]) {
self.table.clear();
self.table
.reserve(source.len(), |entry| entry.identity.hash);

for &(key, id) in source {
self.insert_entry(key, id, false);
}
}

// Marks the given tracked structs as created (active) by the current query.
pub(crate) fn mark_all_active(&mut self) {
for entry in self.table.iter_mut() {
entry.active = true;
pub(crate) fn mark_all_active(&mut self, items: impl IntoIterator<Item = (Identity, Id)>) {
for (key, id) in items {
self.insert_entry(key, id, true);
}
}

@@ -330,7 +326,8 @@ impl IdentityMap {
/// The first entry contains the identity and IDs of any tracked structs that were
/// created by the current execution of the query, while the second entry contains any
/// tracked structs that were created in a previous execution but not the current one.
pub(crate) fn drain(&mut self) -> (ThinVec<(Identity, Id)>, Vec<DatabaseKeyIndex>) {
#[expect(clippy::type_complexity)]
pub(crate) fn drain(&mut self) -> (ThinVec<(Identity, Id)>, Vec<(Identity, Id)>) {
if self.table.is_empty() {
return (ThinVec::new(), Vec::new());
}
if entry.active {
active.push((entry.identity, entry.id));
} else {
stale.push(DatabaseKeyIndex::new(
entry.identity.ingredient_index(),
entry.id,
));
stale.push((entry.identity, entry.id));
}
}

// Removing a stale tracked struct ID shows up in the event logs, so make sure
// the order is stable here.
stale.sort_unstable_by(|a, b| {
a.ingredient_index()
.cmp(&b.ingredient_index())
.then(a.key_index().cmp(&b.key_index()))
(a.0.ingredient_index(), a.1).cmp(&(b.0.ingredient_index(), b.1))
});

(active, stale)
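The new comparator is the tuple form of the previous two-step comparison: project each stale entry to `(ingredient_index, id)` and compare lexicographically. A tiny stand-alone illustration with simplified stand-in types (not salsa's real `Identity`), showing the deterministic order the event-log comment asks for:

```rust
#[derive(Clone, Copy, Debug)]
struct Identity {
    ingredient_index: u32, // stand-in; the real type also carries a hash etc.
}

impl Identity {
    fn ingredient_index(&self) -> u32 {
        self.ingredient_index
    }
}

fn main() {
    type Id = u32;
    let ident = |i| Identity { ingredient_index: i };

    let mut stale: Vec<(Identity, Id)> =
        vec![(ident(2), 7), (ident(1), 9), (ident(2), 3), (ident(1), 1)];

    // Same shape as the comparator above: compare `(ingredient_index, id)`
    // tuples lexicographically so the discard order is stable across runs.
    stale.sort_unstable_by(|a, b| {
        (a.0.ingredient_index(), a.1).cmp(&(b.0.ingredient_index(), b.1))
    });

    let order: Vec<(u32, Id)> = stale
        .iter()
        .map(|(identity, id)| (identity.ingredient_index(), *id))
        .collect();
    assert_eq!(order, vec![(1, 1), (1, 9), (2, 3), (2, 7)]);
}
```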
10 changes: 5 additions & 5 deletions src/zalsa_local.rs
@@ -668,13 +668,13 @@ impl QueryRevisions {
}
}

/// Returns a reference to the `IdentityMap` for this query, or `None` if the map is empty.
pub fn tracked_struct_ids(&self) -> Option<&[(Identity, Id)]> {
/// Returns the ids of the tracked structs created when running this query.
pub fn tracked_struct_ids(&self) -> &[(Identity, Id)] {
self.extra
.0
.as_ref()
.map(|extra| &*extra.tracked_struct_ids)
.filter(|tracked_struct_ids| !tracked_struct_ids.is_empty())
.unwrap_or_default()
}

/// Returns a mutable reference to the `IdentityMap` for this query, or `None` if the map is empty.
@@ -1090,7 +1090,6 @@ impl ActiveQueryGuard<'_> {
#[cfg(debug_assertions)]
assert_eq!(stack.len(), self.push_len);
let frame = stack.last_mut().unwrap();
assert!(frame.tracked_struct_ids().is_empty());
frame.tracked_struct_ids_mut().seed(tracked_struct_ids);
})
}
previous.origin.as_ref(),
QueryOriginRef::DerivedUntracked(_)
);
let tracked_ids = previous.tracked_struct_ids();

// SAFETY: We do not access the query stack reentrantly.
unsafe {
self.local_state.with_query_stack_unchecked_mut(|stack| {
#[cfg(debug_assertions)]
assert_eq!(stack.len(), self.push_len);
let frame = stack.last_mut().unwrap();
frame.seed_iteration(durability, changed_at, edges, untracked_read);
frame.seed_iteration(durability, changed_at, edges, untracked_read, tracked_ids);
})
}
}
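The accessor change above means an absent `QueryRevisionsExtra` now reads as an empty slice rather than `None`, which is what lets `memo.rs` drop the `.into_iter().flatten()` and `execute.rs` drop its `if let Some(..)`. A small stand-alone illustration of the pattern; the types here are simplified stand-ins, not the real salsa structs:

```rust
// Simplified stand-ins for `QueryRevisions` / `QueryRevisionsExtra`.
struct Extra {
    tracked_struct_ids: Vec<(u32, u32)>, // stand-in for (Identity, Id)
}

struct Revisions {
    extra: Option<Extra>,
}

impl Revisions {
    // Absent extra data now reads as "no tracked structs" rather than `None`,
    // so callers can iterate directly without `Option` plumbing.
    fn tracked_struct_ids(&self) -> &[(u32, u32)] {
        self.extra
            .as_ref()
            .map(|extra| &*extra.tracked_struct_ids)
            .unwrap_or_default()
    }
}

fn main() {
    let empty = Revisions { extra: None };
    assert!(empty.tracked_struct_ids().is_empty());

    let some = Revisions {
        extra: Some(Extra { tracked_struct_ids: vec![(0, 1), (0, 2)] }),
    };
    // Callers iterate the slice directly, no `into_iter().flatten()` needed.
    for (ingredient, id) in some.tracked_struct_ids() {
        println!("tracked struct: ingredient {ingredient}, id {id}");
    }
}
```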
131 changes: 127 additions & 4 deletions tests/cycle_tracked.rs
@@ -1,8 +1,5 @@
#![cfg(feature = "inventory")]

//! Tests for cycles where the cycle head is stored on a tracked struct
//! and that tracked struct is freed in a later revision.

mod common;

use crate::common::{EventLoggerDatabase, LogDatabase};
@@ -45,6 +42,7 @@ struct Node<'db> {
#[salsa::input(debug)]
struct GraphInput {
simple: bool,
fixpoint_variant: usize,
}

#[salsa::tracked(returns(ref))]
@@ -125,11 +123,13 @@ fn cycle_recover(
CycleRecoveryAction::Iterate
}

/// Tests for cycles where the cycle head is stored on a tracked struct
/// and that tracked struct is freed in a later revision.
#[test]
fn main() {
let mut db = EventLoggerDatabase::default();

let input = GraphInput::new(&db, false);
let input = GraphInput::new(&db, false, 0);
let graph = create_graph(&db, input);
let c = graph.find_node(&db, "c").unwrap();

@@ -192,3 +192,126 @@
"WillCheckCancellation",
]"#]]);
}

#[salsa::tracked]
struct IterationNode<'db> {
#[returns(ref)]
name: String,
iteration: usize,
}

/// A cyclic query that creates more tracked structs in later fixpoint iterations.
///
/// The output depends on the input's fixpoint_variant:
/// - variant=0: Returns `[base]` (1 struct, no cycle)
/// - variant=1: Through fixpoint iteration, returns `[iter_0, iter_1, iter_2]` (3 structs)
/// - variant=2: Through fixpoint iteration, returns `[iter_0, iter_1]` (2 structs)
/// - variant>2: Through fixpoint iteration, returns `[iter_0, iter_1]` (2 structs, same as variant=2)
///
/// When variant > 0, the query creates a cycle by calling itself. The fixpoint iteration
/// proceeds as follows:
/// 1. Initial: returns empty vector
/// 2. First iteration: returns `[iter_0]`
/// 3. Second iteration: returns `[iter_0, iter_1]`
/// 4. Third iteration (only for variant=1): returns `[iter_0, iter_1, iter_2]`
/// 5. Further iterations: no change, fixpoint reached
#[salsa::tracked(cycle_fn=cycle_recover_with_structs, cycle_initial=initial_with_structs)]
fn create_tracked_in_cycle<'db>(
db: &'db dyn Database,
input: GraphInput,
) -> Vec<IterationNode<'db>> {
// Check if we should create more nodes based on the input.
let variant = input.fixpoint_variant(db);

if variant == 0 {
// Base case - no cycle, just return a single node.
vec![IterationNode::new(db, "base".to_string(), 0)]
} else {
// Create a cycle by calling ourselves.
let previous = create_tracked_in_cycle(db, input);

// In later iterations, create additional tracked structs.
if previous.is_empty() {
// First iteration - initial returns empty.
vec![IterationNode::new(db, "iter_0".to_string(), 0)]
} else {
// Limit based on variant: variant=1 allows 3 nodes, variant=2 allows 2 nodes.
let limit = if variant == 1 { 3 } else { 2 };

if previous.len() < limit {
// Subsequent iterations - add more nodes.
let mut nodes = previous;
nodes.push(IterationNode::new(
db,
format!("iter_{}", nodes.len()),
nodes.len(),
));
nodes
} else {
// Reached the limit.
previous
}
}
}
}

fn initial_with_structs(_db: &dyn Database, _input: GraphInput) -> Vec<IterationNode<'_>> {
vec![]
}

#[allow(clippy::ptr_arg)]
fn cycle_recover_with_structs<'db>(
_db: &'db dyn Database,
_value: &Vec<IterationNode<'db>>,
_iteration: u32,
_input: GraphInput,
) -> CycleRecoveryAction<Vec<IterationNode<'db>>> {
CycleRecoveryAction::Iterate
}

#[test]
fn test_cycle_with_fixpoint_structs() {
let mut db = EventLoggerDatabase::default();

// Create an input that will trigger the cyclic behavior.
let input = GraphInput::new(&db, false, 1);

// Initial query - this will create structs across multiple iterations.
let nodes = create_tracked_in_cycle(&db, input);
assert_eq!(nodes.len(), 3);
// First iteration: previous is empty [], so we get [iter_0]
// Second iteration: previous is [iter_0], so we get [iter_0, iter_1]
// Third iteration: previous is [iter_0, iter_1], so we get [iter_0, iter_1, iter_2]
assert_eq!(nodes[0].name(&db), "iter_0");
assert_eq!(nodes[1].name(&db), "iter_1");
assert_eq!(nodes[2].name(&db), "iter_2");

// Clear logs to focus on the change.
db.clear_logs();

// Change the input to force re-execution with a different variant.
// This will create 2 tracked structs instead of 3 (one fewer than before).
input.set_fixpoint_variant(&mut db).to(2);

// Re-query - this should handle the tracked struct changes properly.
let nodes = create_tracked_in_cycle(&db, input);
assert_eq!(nodes.len(), 2);
assert_eq!(nodes[0].name(&db), "iter_0");
assert_eq!(nodes[1].name(&db), "iter_1");

// Check the logs to ensure proper execution and struct management.
// We should see the third struct (iter_2) being discarded.
db.assert_logs(expect![[r#"
[
"DidSetCancellationFlag",
"WillCheckCancellation",
"WillExecute { database_key: create_tracked_in_cycle(Id(0)) }",
"WillCheckCancellation",
"WillIterateCycle { database_key: create_tracked_in_cycle(Id(0)), iteration_count: IterationCount(1), fell_back: false }",
"WillCheckCancellation",
"WillIterateCycle { database_key: create_tracked_in_cycle(Id(0)), iteration_count: IterationCount(2), fell_back: false }",
"WillCheckCancellation",
"WillDiscardStaleOutput { execute_key: create_tracked_in_cycle(Id(0)), output_key: IterationNode(Id(402)) }",
"DidDiscard { key: IterationNode(Id(402)) }",
]"#]]);
}