Skip to content

Commit 664750a

Browse files
authored
Track cycle function dependencies as part of the cyclic query (#1018)
* Track cycle function dependenciees as part of the cyclic query * Add regression test * Discard changes to src/function/backdate.rs * Update comment * Fix merge error * Refine comment
1 parent c762869 commit 664750a

File tree

6 files changed

+149
-19
lines changed

6 files changed

+149
-19
lines changed

src/active_query.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,10 @@ impl ActiveQuery {
9191
.mark_all_active(active_tracked_ids.iter().copied());
9292
}
9393

94+
pub(super) fn take_cycle_heads(&mut self) -> CycleHeads {
95+
std::mem::take(&mut self.cycle_heads)
96+
}
97+
9498
pub(super) fn add_read(
9599
&mut self,
96100
input: DatabaseKeyIndex,

src/cycle.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,4 +490,8 @@ impl<'db> ProvisionalStatus<'db> {
490490
_ => empty_cycle_heads(),
491491
}
492492
}
493+
494+
pub(crate) const fn is_provisional(&self) -> bool {
495+
matches!(self, ProvisionalStatus::Provisional { .. })
496+
}
493497
}

src/function/execute.rs

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -56,20 +56,25 @@ where
5656
});
5757

5858
let (new_value, mut completed_query) = match C::CYCLE_STRATEGY {
59-
CycleRecoveryStrategy::Panic => Self::execute_query(
60-
db,
61-
zalsa,
62-
zalsa_local.push_query(database_key_index, IterationCount::initial()),
63-
opt_old_memo,
64-
),
59+
CycleRecoveryStrategy::Panic => {
60+
let (new_value, active_query) = Self::execute_query(
61+
db,
62+
zalsa,
63+
zalsa_local.push_query(database_key_index, IterationCount::initial()),
64+
opt_old_memo,
65+
);
66+
(new_value, active_query.pop())
67+
}
6568
CycleRecoveryStrategy::FallbackImmediate => {
66-
let (mut new_value, mut completed_query) = Self::execute_query(
69+
let (mut new_value, active_query) = Self::execute_query(
6770
db,
6871
zalsa,
6972
zalsa_local.push_query(database_key_index, IterationCount::initial()),
7073
opt_old_memo,
7174
);
7275

76+
let mut completed_query = active_query.pop();
77+
7378
if let Some(cycle_heads) = completed_query.revisions.cycle_heads_mut() {
7479
// Did the new result we got depend on our own provisional value, in a cycle?
7580
if cycle_heads.contains(&database_key_index) {
@@ -198,9 +203,10 @@ where
198203

199204
let _poison_guard =
200205
PoisonProvisionalIfPanicking::new(self, zalsa, id, memo_ingredient_index);
201-
let mut active_query = zalsa_local.push_query(database_key_index, iteration_count);
202206

203207
let (new_value, completed_query) = loop {
208+
let active_query = zalsa_local.push_query(database_key_index, iteration_count);
209+
204210
// Tracked struct ids that existed in the previous revision
205211
// but weren't recreated in the last iteration. It's important that we seed the next
206212
// query with these ids because the query might re-create them as part of the next iteration.
@@ -209,29 +215,32 @@ where
209215
// if they aren't recreated when reaching the final iteration.
210216
active_query.seed_tracked_struct_ids(&last_stale_tracked_ids);
211217

212-
let (mut new_value, mut completed_query) = Self::execute_query(
218+
let (mut new_value, mut active_query) = Self::execute_query(
213219
db,
214220
zalsa,
215221
active_query,
216222
last_provisional_memo.or(opt_old_memo),
217223
);
218224

219-
// If there are no cycle heads, break out of the loop (`cycle_heads_mut` returns `None` if the cycle head list is empty)
220-
let Some(cycle_heads) = completed_query.revisions.cycle_heads_mut() else {
225+
// Take the cycle heads to not-fight-rust's-borrow-checker.
226+
let mut cycle_heads = active_query.take_cycle_heads();
227+
228+
// If there are no cycle heads, break out of the loop.
229+
if cycle_heads.is_empty() {
221230
iteration_count = iteration_count.increment().unwrap_or_else(|| {
222231
tracing::warn!("{database_key_index:?}: execute: too many cycle iterations");
223232
panic!("{database_key_index:?}: execute: too many cycle iterations")
224233
});
234+
235+
let mut completed_query = active_query.pop();
225236
completed_query
226237
.revisions
227238
.update_iteration_count_mut(database_key_index, iteration_count);
228239

229240
claim_guard.set_release_mode(ReleaseMode::SelfOnly);
230241
break (new_value, completed_query);
231-
};
242+
}
232243

233-
// Take the cycle heads to not-fight-rust's-borrow-checker.
234-
let mut cycle_heads = std::mem::take(cycle_heads);
235244
let mut missing_heads: SmallVec<[(DatabaseKeyIndex, IterationCount); 1]> =
236245
SmallVec::new_const();
237246
let mut max_iteration_count = iteration_count;
@@ -262,6 +271,11 @@ where
262271
.provisional_status(zalsa, head.database_key_index.key_index())
263272
.expect("cycle head memo must have been created during the execution");
264273

274+
// A query should only ever depend on other heads that are provisional.
275+
// If this invariant is violated, it means that this query participates in a cycle,
276+
// but it wasn't executed in the last iteration of said cycle.
277+
assert!(provisional_status.is_provisional());
278+
265279
for nested_head in provisional_status.cycle_heads() {
266280
let nested_as_tuple = (
267281
nested_head.database_key_index,
@@ -298,6 +312,8 @@ where
298312
claim_guard.set_release_mode(ReleaseMode::SelfOnly);
299313
}
300314

315+
let mut completed_query = active_query.pop();
316+
*completed_query.revisions.verified_final.get_mut() = false;
301317
completed_query.revisions.set_cycle_heads(cycle_heads);
302318

303319
iteration_count = iteration_count.increment().unwrap_or_else(|| {
@@ -378,8 +394,17 @@ where
378394
this_converged = C::values_equal(&new_value, last_provisional_value);
379395
}
380396
}
397+
398+
let new_cycle_heads = active_query.take_cycle_heads();
399+
for head in new_cycle_heads {
400+
if !cycle_heads.contains(&head.database_key_index) {
401+
panic!("Cycle recovery function for {database_key_index:?} introduced a cycle, depending on {:?}. This is not allowed.", head.database_key_index);
402+
}
403+
}
381404
}
382405

406+
let mut completed_query = active_query.pop();
407+
383408
if let Some(outer_cycle) = outer_cycle {
384409
tracing::info!(
385410
"Detected nested cycle {database_key_index:?}, iterate it as part of the outer cycle {outer_cycle:?}"
@@ -390,6 +415,7 @@ where
390415
completed_query
391416
.revisions
392417
.set_cycle_converged(this_converged);
418+
*completed_query.revisions.verified_final.get_mut() = false;
393419

394420
// Transfer ownership of this query to the outer cycle, so that it can claim it
395421
// and other threads don't compete for the same lock.
@@ -428,9 +454,9 @@ where
428454
}
429455

430456
*completed_query.revisions.verified_final.get_mut() = true;
431-
432457
break (new_value, completed_query);
433458
}
459+
*completed_query.revisions.verified_final.get_mut() = false;
434460

435461
// The fixpoint iteration hasn't converged. Iterate again...
436462
iteration_count = iteration_count.increment().unwrap_or_else(|| {
@@ -484,7 +510,6 @@ where
484510
last_provisional_memo = Some(new_memo);
485511

486512
last_stale_tracked_ids = completed_query.stale_tracked_structs;
487-
active_query = zalsa_local.push_query(database_key_index, iteration_count);
488513

489514
continue;
490515
};
@@ -503,7 +528,7 @@ where
503528
zalsa: &'db Zalsa,
504529
active_query: ActiveQueryGuard<'db>,
505530
opt_old_memo: Option<&Memo<'db, C>>,
506-
) -> (C::Output<'db>, CompletedQuery) {
531+
) -> (C::Output<'db>, ActiveQueryGuard<'db>) {
507532
if let Some(old_memo) = opt_old_memo {
508533
// If we already executed this query once, then use the tracked-struct ids from the
509534
// previous execution as the starting point for the new one.
@@ -528,7 +553,7 @@ where
528553
C::id_to_input(zalsa, active_query.database_key_index.key_index()),
529554
);
530555

531-
(new_value, active_query.pop())
556+
(new_value, active_query)
532557
}
533558
}
534559

src/function/maybe_changed_after.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -592,7 +592,10 @@ where
592592
cycle_heads.append_heads(&mut child_cycle_heads);
593593

594594
match input_result {
595-
VerifyResult::Changed => return VerifyResult::changed(),
595+
VerifyResult::Changed => {
596+
cycle_heads.remove_head(database_key_index);
597+
return VerifyResult::changed();
598+
}
596599
#[cfg(feature = "accumulator")]
597600
VerifyResult::Unchanged { accumulated } => {
598601
inputs |= accumulated;

src/zalsa_local.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1213,6 +1213,18 @@ impl ActiveQueryGuard<'_> {
12131213
}
12141214
}
12151215

1216+
pub(crate) fn take_cycle_heads(&mut self) -> CycleHeads {
1217+
// SAFETY: We do not access the query stack reentrantly.
1218+
unsafe {
1219+
self.local_state.with_query_stack_unchecked_mut(|stack| {
1220+
#[cfg(debug_assertions)]
1221+
assert_eq!(stack.len(), self.push_len);
1222+
let frame = stack.last_mut().unwrap();
1223+
frame.take_cycle_heads()
1224+
})
1225+
}
1226+
}
1227+
12161228
/// Invoked when the query has successfully completed execution.
12171229
fn complete(self) -> CompletedQuery {
12181230
// SAFETY: We do not access the query stack reentrantly.
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
#![cfg(feature = "inventory")]
2+
3+
//! Queries or inputs read within the cycle recovery function
4+
//! are tracked on the cycle function and don't "leak" into the
5+
//! function calling the query with cycle handling.
6+
7+
use expect_test::expect;
8+
use salsa::Setter as _;
9+
10+
use crate::common::LogDatabase;
11+
12+
mod common;
13+
14+
#[salsa::input]
15+
struct Input {
16+
value: u32,
17+
}
18+
19+
#[salsa::tracked]
20+
fn entry(db: &dyn salsa::Database, input: Input) -> u32 {
21+
query(db, input)
22+
}
23+
24+
#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=cycle_initial)]
25+
fn query(db: &dyn salsa::Database, input: Input) -> u32 {
26+
let val = query(db, input);
27+
if val < 5 {
28+
val + 1
29+
} else {
30+
val
31+
}
32+
}
33+
34+
fn cycle_initial(_db: &dyn salsa::Database, _id: salsa::Id, _input: Input) -> u32 {
35+
0
36+
}
37+
38+
fn cycle_fn(
39+
db: &dyn salsa::Database,
40+
_id: salsa::Id,
41+
_last_provisional_value: &u32,
42+
_value: &u32,
43+
_count: u32,
44+
input: Input,
45+
) -> salsa::CycleRecoveryAction<u32> {
46+
let _input = input.value(db);
47+
salsa::CycleRecoveryAction::Iterate
48+
}
49+
50+
#[test_log::test]
51+
fn the_test() {
52+
let mut db = common::EventLoggerDatabase::default();
53+
54+
let input = Input::new(&db, 1);
55+
assert_eq!(entry(&db, input), 5);
56+
57+
db.assert_logs_len(15);
58+
59+
input.set_value(&mut db).to(2);
60+
61+
assert_eq!(entry(&db, input), 5);
62+
db.assert_logs(expect![[r#"
63+
[
64+
"DidSetCancellationFlag",
65+
"WillCheckCancellation",
66+
"WillCheckCancellation",
67+
"WillCheckCancellation",
68+
"WillExecute { database_key: query(Id(0)) }",
69+
"WillCheckCancellation",
70+
"WillIterateCycle { database_key: query(Id(0)), iteration_count: IterationCount(1) }",
71+
"WillCheckCancellation",
72+
"WillIterateCycle { database_key: query(Id(0)), iteration_count: IterationCount(2) }",
73+
"WillCheckCancellation",
74+
"WillIterateCycle { database_key: query(Id(0)), iteration_count: IterationCount(3) }",
75+
"WillCheckCancellation",
76+
"WillIterateCycle { database_key: query(Id(0)), iteration_count: IterationCount(4) }",
77+
"WillCheckCancellation",
78+
"WillIterateCycle { database_key: query(Id(0)), iteration_count: IterationCount(5) }",
79+
"WillCheckCancellation",
80+
"DidValidateMemoizedValue { database_key: entry(Id(0)) }",
81+
]"#]]);
82+
}

0 commit comments

Comments
 (0)