Skip to content

Commit 0cc827d

Browse files
committed
Fix await for queries depending on initial value
1 parent 2a09313 commit 0cc827d

File tree

6 files changed

+79
-83
lines changed

6 files changed

+79
-83
lines changed

src/function.rs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@ use std::ptr::NonNull;
55
use std::sync::atomic::Ordering;
66

77
use crate::accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues};
8-
use crate::cycle::{CycleHeadKind, CycleHeads, CycleRecoveryAction, CycleRecoveryStrategy};
8+
use crate::cycle::{
9+
empty_cycle_heads, CycleHeadKind, CycleHeads, CycleRecoveryAction, CycleRecoveryStrategy,
10+
};
911
use crate::function::delete::DeletedEntries;
1012
use crate::function::sync::{ClaimResult, SyncTable};
1113
use crate::ingredient::Ingredient;
@@ -256,6 +258,12 @@ where
256258
}
257259
}
258260

261+
fn cycle_heads<'db>(&self, zalsa: &'db Zalsa, input: Id) -> &'db CycleHeads {
262+
self.get_memo_from_table_for(zalsa, input, self.memo_ingredient_index(zalsa, input))
263+
.map(|memo| memo.cycle_heads())
264+
.unwrap_or(empty_cycle_heads())
265+
}
266+
259267
/// Attempts to claim `key_index`, returning `false` if a cycle occurs.
260268
fn wait_for(&self, zalsa: &Zalsa, key_index: Id) -> bool {
261269
match self.sync_table.try_claim(zalsa, key_index) {

src/function/fetch.rs

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ where
202202
&& old_memo.may_be_provisional()
203203
&& old_memo.verified_at.load() == zalsa.current_revision()
204204
{
205-
old_memo.await_heads(zalsa, database_key_index);
205+
old_memo.await_heads(zalsa);
206206

207207
// It's possible that one of the cycle heads replaced the memo for this ingredient
208208
// with fixpoint initial. We ignore that memo because we know it's only a temporary memo
@@ -213,13 +213,9 @@ where
213213
if let Some(old_memo) = opt_old_memo {
214214
if old_memo.value.is_some() {
215215
let mut cycle_heads = CycleHeads::default();
216-
if let VerifyResult::Unchanged(_) = self.deep_verify_memo(
217-
db,
218-
zalsa,
219-
old_memo,
220-
database_key_index,
221-
&mut cycle_heads,
222-
) {
216+
if let VerifyResult::Unchanged(_) =
217+
self.deep_verify_memo(db, zalsa, old_memo, database_key_index, &mut cycle_heads)
218+
{
223219
if cycle_heads.is_empty() {
224220
// SAFETY: memo is present in memo_map and we have verified that it is
225221
// still valid for the current revision.

src/function/memo.rs

Lines changed: 43 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use std::ptr::NonNull;
55

66
use crate::cycle::{empty_cycle_heads, CycleHeadKind, CycleHeads};
77
use crate::function::{Configuration, IngredientImpl};
8+
use crate::hash::FxHashSet;
89
use crate::key::DatabaseKeyIndex;
910
use crate::revision::AtomicRevision;
1011
use crate::sync::atomic::Ordering;
@@ -99,6 +100,7 @@ pub struct Memo<V> {
99100

100101
// Memo's are stored a lot, make sure their size is doesn't randomly increase.
101102
#[cfg(not(feature = "shuttle"))]
103+
#[cfg(target_pointer_width = "64")]
102104
const _: [(); std::mem::size_of::<Memo<std::num::NonZeroUsize>>()] =
103105
[(); std::mem::size_of::<[usize; 13]>()];
104106

@@ -141,73 +143,27 @@ impl<V> Memo<V> {
141143
return false;
142144
};
143145

144-
return provisional_retry_cold(zalsa, database_key_index, &self.revisions.cycle_heads);
145-
146-
#[inline(never)]
147-
fn provisional_retry_cold(
148-
zalsa: &Zalsa,
149-
database_key_index: DatabaseKeyIndex,
150-
cycle_heads: &CycleHeads,
151-
) -> bool {
152-
let mut retry = false;
153-
let mut hit_cycle = false;
154-
155-
for head in cycle_heads {
156-
let head_index = head.database_key_index;
157-
158-
let ingredient = zalsa.lookup_ingredient(head_index.ingredient_index());
159-
let cycle_head_kind = ingredient.cycle_head_kind(zalsa, head_index.key_index());
160-
if matches!(
161-
cycle_head_kind,
162-
CycleHeadKind::NotProvisional | CycleHeadKind::FallbackImmediate
163-
) {
164-
// This cycle is already finalized, so we don't need to wait on it;
165-
// keep looping through cycle heads.
166-
retry = true;
167-
tracing::trace!("Dependent cycle head {head_index:?} has been finalized.");
168-
} else if ingredient.wait_for(zalsa, head_index.key_index()) {
169-
tracing::trace!("Dependent cycle head {head_index:?} has been released (there's a new memo)");
170-
// There's a new memo available for the cycle head; fetch our own
171-
// updated memo and see if it's still provisional or if the cycle
172-
// has resolved.
173-
tracing::trace!("Dependent cycle head {head_index:?} has been released (there's a new memo)");
174-
retry = true;
175-
} else {
176-
// We hit a cycle blocking on the cycle head; this means it's in
177-
// our own active query stack and we are responsible to resolve the
178-
// cycle, so go ahead and return the provisional memo.
179-
tracing::debug!(
180-
"Waiting for {head_index:?} results in a cycle, return {database_key_index:?} once all other cycle heads completed to allow the outer cycle to make progress."
181-
);
182-
hit_cycle = true;
183-
}
184-
}
185-
186-
// If `retry` is `true`, all our cycle heads (barring ourself) are complete; re-fetch
187-
// and we should get a non-provisional memo. If we get here and `retry` is still
188-
// `false`, we have no cycle heads other than ourself, so we are a provisional value of
189-
// the cycle head (either initial value, or from a later iteration) and should be
190-
// returned to caller to allow fixpoint iteration to proceed. (All cases in the loop
191-
// above other than "cycle head is self" are either terminal or set `retry`.)
192-
if hit_cycle {
193-
false
194-
} else if retry {
195-
tracing::debug!("Retrying {database_key_index:?}");
196-
true
197-
} else {
198-
false
199-
}
146+
if self.await_heads(zalsa) {
147+
false
148+
} else {
149+
tracing::debug!(
150+
"Retrying provisional memo {database_key_index:?} after awaiting cycle heads."
151+
);
152+
true
200153
}
201154
}
202155

203-
#[inline(always)]
204-
pub(super) fn await_heads(&self, zalsa: &Zalsa, database_key_index: DatabaseKeyIndex) {
205-
for head in &self.revisions.cycle_heads {
206-
let head_index = head.database_key_index;
156+
/// Awaits all cycle heads (recursively) that this memo depends on.
157+
///
158+
/// Returns `true` if awaiting the cycle heads resulted in a cycle.
159+
pub(super) fn await_heads(&self, zalsa: &Zalsa) -> bool {
160+
let mut hit_cycle = false;
207161

208-
if database_key_index == head_index {
209-
continue;
210-
}
162+
let mut visited = FxHashSet::default();
163+
let mut queue: Vec<_> = self.revisions.cycle_heads.iter().collect();
164+
165+
while let Some(head) = queue.pop() {
166+
let head_index = head.database_key_index;
211167

212168
let ingredient = zalsa.lookup_ingredient(head_index.ingredient_index());
213169
let cycle_head_kind = ingredient.cycle_head_kind(zalsa, head_index.key_index());
@@ -220,13 +176,35 @@ impl<V> Memo<V> {
220176
// keep looping through cycle heads.
221177
tracing::trace!("Dependent cycle head {head_index:?} has been finalized.");
222178
} else if ingredient.wait_for(zalsa, head_index.key_index()) {
223-
tracing::trace!("Dependent cycle head {head_index:?} has been released");
179+
// There's a new memo available for the cycle head; fetch our own
180+
// updated memo and see if it's still provisional or if the cycle
181+
// has resolved.
182+
tracing::trace!(
183+
"Dependent cycle head {head_index:?} has been released (there's a new memo)"
184+
);
185+
// Recursively wait for all cycle heads that this head depends on.
186+
// This is normally not necessary, because cycle heads are transitively added
187+
// as query dependencies (they aggregate). The exception to this are queries
188+
// that depend on a fixpoint initial value. We don't know all the dependencies of
189+
// the query yet, so they can't be carried over. We only know them once the cycle
190+
// completes but the cycle heads of the queries don't get updated.
191+
// Because of that, recurse here to collect all cycle heads.
192+
queue.extend(
193+
ingredient
194+
.cycle_heads(zalsa, head_index.key_index())
195+
.iter()
196+
.filter(|head| visited.insert(head.database_key_index)),
197+
);
224198
} else {
225199
// We hit a cycle blocking on the cycle head; this means it's in
226200
// our own active query stack and we are responsible to resolve the
227-
// cycle
201+
// cycle, so go ahead and return the provisional memo.
202+
tracing::debug!("Waiting for {head_index:?} results in a cycle");
203+
hit_cycle = true;
228204
}
229205
}
206+
207+
hit_cycle
230208
}
231209

232210
/// Cycle heads that should be propagated to dependent queries.

src/ingredient.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::any::{Any, TypeId};
22
use std::fmt;
33

44
use crate::accumulator::accumulated_map::{AccumulatedMap, InputAccumulatedValues};
5-
use crate::cycle::{CycleHeadKind, CycleHeads, CycleRecoveryStrategy};
5+
use crate::cycle::{empty_cycle_heads, CycleHeadKind, CycleHeads, CycleRecoveryStrategy};
66
use crate::function::VerifyResult;
77
use crate::plumbing::IngredientIndices;
88
use crate::sync::Arc;
@@ -67,14 +67,17 @@ pub trait Ingredient: Any + std::fmt::Debug + Send + Sync {
6767
) -> VerifyResult;
6868

6969
/// Is the value for `input` in this ingredient a cycle head that is still provisional?
70-
///
71-
/// In the case of nested cycles, we are not asking here whether the value is provisional due
72-
/// to the outer cycle being unresolved, only whether its own cycle remains provisional.
7370
fn cycle_head_kind(&self, zalsa: &Zalsa, input: Id) -> CycleHeadKind {
7471
_ = (zalsa, input);
7572
CycleHeadKind::NotProvisional
7673
}
7774

75+
/// Returns the cycle heads for this ingredient.
76+
fn cycle_heads<'db>(&self, zalsa: &'db Zalsa, input: Id) -> &'db CycleHeads {
77+
_ = (zalsa, input);
78+
empty_cycle_heads()
79+
}
80+
7881
/// Invoked when the current thread needs to wait for a result for the given `key_index`.
7982
///
8083
/// A return value of `true` indicates that a result is now available. A return value of

tests/parallel/cycle_a_t1_b_t2_fallback.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
//! +--------------------+
1313
//! ```
1414
use crate::sync::thread;
15-
use crate::{Knobs, KnobsDatabase};
15+
use crate::KnobsDatabase;
1616

1717
const FALLBACK_A: u32 = 0b01;
1818
const FALLBACK_B: u32 = 0b10;
@@ -53,6 +53,9 @@ fn cycle_result_b(_db: &dyn KnobsDatabase) -> u32 {
5353
#[test_log::test]
5454
#[cfg(not(feature = "shuttle"))] // This test is currently failing.
5555
fn the_test() {
56+
use crate::sync::thread;
57+
use crate::Knobs;
58+
5659
crate::sync::check(|| {
5760
let db_t1 = Knobs::default();
5861
let db_t2 = db_t1.clone();

tests/parallel/cycle_nested_deep.rs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
//! The trick is that different threads call into the same cycle from different entry queries.
44
//!
55
//! * Thread 1: `a` -> b -> c (which calls back into d, e, b, a)
6-
//! * Thread 2: `d` -> `c`
7-
//! * Thread 3: `e` -> `c`
6+
//! * Thread 2: `b`
7+
//! * Thread 3: `d` -> `c`
8+
//! * Thread 4: `e` -> `c`
89
use crate::sync::thread;
910
use crate::{Knobs, KnobsDatabase};
1011

@@ -65,6 +66,7 @@ fn the_test() {
6566
let db_t1 = Knobs::default();
6667
let db_t2 = db_t1.clone();
6768
let db_t3 = db_t1.clone();
69+
let db_t4 = db_t1.clone();
6870

6971
let t1 = thread::spawn(move || {
7072
let _span = tracing::debug_span!("t1", thread_id = ?thread::current().id()).entered();
@@ -73,11 +75,16 @@ fn the_test() {
7375
result
7476
});
7577
let t2 = thread::spawn(move || {
78+
let _span = tracing::debug_span!("t4", thread_id = ?thread::current().id()).entered();
79+
db_t4.wait_for(1);
80+
query_b(&db_t4)
81+
});
82+
let t3 = thread::spawn(move || {
7683
let _span = tracing::debug_span!("t2", thread_id = ?thread::current().id()).entered();
7784
db_t2.wait_for(1);
7885
query_d(&db_t2)
7986
});
80-
let t3 = thread::spawn(move || {
87+
let t4 = thread::spawn(move || {
8188
let _span = tracing::debug_span!("t3", thread_id = ?thread::current().id()).entered();
8289
db_t3.wait_for(1);
8390
query_e(&db_t3)
@@ -86,7 +93,8 @@ fn the_test() {
8693
let r_t1 = t1.join().unwrap();
8794
let r_t2 = t2.join().unwrap();
8895
let r_t3 = t3.join().unwrap();
96+
let r_t4 = t4.join().unwrap();
8997

90-
assert_eq!((r_t1, r_t2, r_t3), (MAX, MAX, MAX));
98+
assert_eq!((r_t1, r_t2, r_t3, r_t4), (MAX, MAX, MAX, MAX));
9199
});
92100
}

0 commit comments

Comments
 (0)