Skip to content

Commit 144a4e5

Browse files
committed
Add a test for a cycle with changing cycle heads
1 parent 5014819 commit 144a4e5

File tree

1 file changed

+110
-0
lines changed

1 file changed

+110
-0
lines changed
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
//! Test a deeply nested-cycle scenario where cycles have changing query dependencies.
2+
//!
3+
//! The trick is that different threads call into the same cycle from different entry queries and
4+
//! the cycle heads change over different iterations
5+
//!
6+
//! * Thread 1: `a` -> b -> c
7+
//! * Thread 2: `b`
8+
//! * Thread 3: `d` -> `c`
9+
//! * Thread 4: `e` -> `c`
10+
//!
11+
//! `c` calls:
12+
//! * `d` and `a` in the first few iterations
13+
//! * `d`, `b` and `e` in the last iterations
14+
use crate::sync::thread;
15+
use crate::{Knobs, KnobsDatabase};
16+
17+
use salsa::CycleRecoveryAction;
18+
19+
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, salsa::Update)]
20+
struct CycleValue(u32);
21+
22+
const MIN: CycleValue = CycleValue(0);
23+
const MAX: CycleValue = CycleValue(3);
24+
25+
#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)]
26+
fn query_a(db: &dyn KnobsDatabase) -> CycleValue {
27+
query_b(db)
28+
}
29+
30+
#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)]
31+
fn query_b(db: &dyn KnobsDatabase) -> CycleValue {
32+
let c_value = query_c(db);
33+
CycleValue(c_value.0 + 1).min(MAX)
34+
}
35+
36+
#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)]
37+
fn query_c(db: &dyn KnobsDatabase) -> CycleValue {
38+
let d_value = query_d(db);
39+
40+
if d_value > CycleValue(0) {
41+
let e_value = query_e(db);
42+
let b_value = query_b(db);
43+
CycleValue(d_value.0.max(e_value.0).max(b_value.0))
44+
} else {
45+
let a_value = query_a(db);
46+
CycleValue(d_value.0.max(a_value.0))
47+
}
48+
}
49+
50+
#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)]
51+
fn query_d(db: &dyn KnobsDatabase) -> CycleValue {
52+
query_c(db)
53+
}
54+
55+
#[salsa::tracked(cycle_fn=cycle_fn, cycle_initial=initial)]
56+
fn query_e(db: &dyn KnobsDatabase) -> CycleValue {
57+
query_c(db)
58+
}
59+
60+
fn cycle_fn(
61+
_db: &dyn KnobsDatabase,
62+
_value: &CycleValue,
63+
_count: u32,
64+
) -> CycleRecoveryAction<CycleValue> {
65+
CycleRecoveryAction::Iterate
66+
}
67+
68+
fn initial(_db: &dyn KnobsDatabase) -> CycleValue {
69+
MIN
70+
}
71+
72+
#[test_log::test]
73+
fn the_test() {
74+
crate::sync::check(|| {
75+
tracing::debug!("New run");
76+
let db_t1 = Knobs::default();
77+
let db_t2 = db_t1.clone();
78+
let db_t3 = db_t1.clone();
79+
let db_t4 = db_t1.clone();
80+
81+
let t1 = thread::spawn(move || {
82+
let _span = tracing::debug_span!("t1", thread_id = ?thread::current().id()).entered();
83+
let result = query_a(&db_t1);
84+
db_t1.signal(1);
85+
result
86+
});
87+
let t2 = thread::spawn(move || {
88+
let _span = tracing::debug_span!("t4", thread_id = ?thread::current().id()).entered();
89+
db_t4.wait_for(1);
90+
query_b(&db_t4)
91+
});
92+
let t3 = thread::spawn(move || {
93+
let _span = tracing::debug_span!("t2", thread_id = ?thread::current().id()).entered();
94+
db_t2.wait_for(1);
95+
query_d(&db_t2)
96+
});
97+
let t4 = thread::spawn(move || {
98+
let _span = tracing::debug_span!("t3", thread_id = ?thread::current().id()).entered();
99+
db_t3.wait_for(1);
100+
query_e(&db_t3)
101+
});
102+
103+
let r_t1 = t1.join().unwrap();
104+
let r_t2 = t2.join().unwrap();
105+
let r_t3 = t3.join().unwrap();
106+
let r_t4 = t4.join().unwrap();
107+
108+
assert_eq!((r_t1, r_t2, r_t3, r_t4), (MAX, MAX, MAX, MAX));
109+
});
110+
}

0 commit comments

Comments
 (0)