@@ -14,14 +14,14 @@ use tracing::log::trace;
14
14
pub struct Walker < N , S > {
15
15
marker : std:: marker:: PhantomData < S > ,
16
16
cancel : watch:: Sender < bool > ,
17
- node_events : Option < mpsc:: Receiver < ( N , oneshot:: Sender < ( ) > ) > > ,
17
+ node_events : Option < mpsc:: Receiver < ( N , oneshot:: Sender < bool > ) > > ,
18
18
join_handles : FuturesUnordered < JoinHandle < ( ) > > ,
19
19
}
20
20
21
21
pub struct Start ;
22
22
pub struct Walking ;
23
23
24
- pub type WalkMessage < N > = ( N , oneshot:: Sender < ( ) > ) ;
24
+ pub type WalkMessage < N > = ( N , oneshot:: Sender < bool > ) ;
25
25
26
26
// These constraint might look very stiff, but since all of the petgraph graph
27
27
// types use integers as node ids and GraphBase already constraints these types
@@ -37,7 +37,7 @@ impl<N: Eq + Hash + Copy + Send + 'static> Walker<N, Start> {
37
37
let mut rxs = HashMap :: new ( ) ;
38
38
for node in graph. node_identifiers ( ) {
39
39
// Each node can finish at most once so we set the capacity to 1
40
- let ( tx, rx) = broadcast:: channel :: < ( ) > ( 1 ) ;
40
+ let ( tx, rx) = broadcast:: channel :: < bool > ( 1 ) ;
41
41
txs. insert ( node, tx) ;
42
42
rxs. insert ( node, rx) ;
43
43
}
@@ -76,8 +76,14 @@ impl<N: Eq + Hash + Copy + Send + 'static> Walker<N, Start> {
76
76
results = deps_fut => {
77
77
for res in results {
78
78
match res {
79
- // No errors from reading dependency channels
80
- Ok ( ( ) ) => ( ) ,
79
+ // Dependency channel signaled this subgraph is terminal;
80
+ // let our dependents know too (if any)
81
+ Ok ( false ) => {
82
+ tx. send( false ) . ok( ) ;
83
+ return ;
84
+ }
85
+ // Otherwise continue
86
+ Ok ( true ) => ( ) ,
81
87
// A dependency finished without sending a finish
82
88
// Could happen if a cancel is sent and is racing with deps
83
89
// so we interpret this as a cancel.
@@ -95,7 +101,7 @@ impl<N: Eq + Hash + Copy + Send + 'static> Walker<N, Start> {
95
101
}
96
102
}
97
103
98
- let ( callback_tx, callback_rx) = oneshot:: channel:: <( ) >( ) ;
104
+ let ( callback_tx, callback_rx) = oneshot:: channel:: <bool >( ) ;
99
105
// do some err handling with the send failure?
100
106
if node_tx. send( ( node, callback_tx) ) . await . is_err( ) {
101
107
// Receiving end of node channel has been closed/dropped
@@ -104,14 +110,15 @@ impl<N: Eq + Hash + Copy + Send + 'static> Walker<N, Start> {
104
110
trace!( "Receiver was dropped before walk finished without calling cancel" ) ;
105
111
return ;
106
112
}
107
- if callback_rx. await . is_err ( ) {
113
+ let Ok ( callback_result ) = callback_rx. await else {
108
114
// If the caller drops the callback sender without signaling
109
115
// that the node processing is finished we assume that it is finished.
110
- trace!( "Callback sender was dropped without sending a finish signal" )
111
- }
116
+ trace!( "Callback sender was dropped without sending a finish signal" ) ;
117
+ return ;
118
+ } ;
112
119
// Send errors indicate that there are no receivers which
113
120
// happens when this node has no dependents
114
- tx. send( ( ) ) . ok( ) ;
121
+ tx. send( callback_result ) . ok( ) ;
115
122
}
116
123
}
117
124
} ) ) ;
@@ -204,7 +211,7 @@ mod test {
204
211
let ( walker, mut node_emitter) = walker. walk ( ) ;
205
212
while let Some ( ( index, done) ) = node_emitter. recv ( ) . await {
206
213
visited. push ( index) ;
207
- done. send ( ( ) ) . unwrap ( ) ;
214
+ done. send ( true ) . unwrap ( ) ;
208
215
}
209
216
walker. wait ( ) . await . unwrap ( ) ;
210
217
assert_eq ! ( visited, vec![ c, b, a] ) ;
@@ -228,7 +235,7 @@ mod test {
228
235
walker. cancel ( ) . unwrap ( ) ;
229
236
230
237
visited. push ( index) ;
231
- done. send ( ( ) ) . unwrap ( ) ;
238
+ done. send ( true ) . unwrap ( ) ;
232
239
}
233
240
assert_eq ! ( visited, vec![ c] ) ;
234
241
let Walker { join_handles, .. } = walker;
@@ -272,16 +279,16 @@ mod test {
272
279
tokio:: spawn ( async move {
273
280
is_b_done. await . unwrap ( ) ;
274
281
visited. lock ( ) . unwrap ( ) . push ( index) ;
275
- done. send ( ( ) ) . unwrap ( ) ;
282
+ done. send ( true ) . unwrap ( ) ;
276
283
} ) ;
277
284
} else if index == b {
278
285
// send the signal that b is finished
279
286
visited. lock ( ) . unwrap ( ) . push ( index) ;
280
- done. send ( ( ) ) . unwrap ( ) ;
287
+ done. send ( true ) . unwrap ( ) ;
281
288
b_done. take ( ) . unwrap ( ) . send ( ( ) ) . unwrap ( ) ;
282
289
} else {
283
290
visited. lock ( ) . unwrap ( ) . push ( index) ;
284
- done. send ( ( ) ) . unwrap ( ) ;
291
+ done. send ( true ) . unwrap ( ) ;
285
292
}
286
293
}
287
294
walker. wait ( ) . await . unwrap ( ) ;
@@ -322,12 +329,12 @@ mod test {
322
329
tokio:: spawn ( async move {
323
330
is_b_done. await . unwrap ( ) ;
324
331
visited. lock ( ) . unwrap ( ) . push ( index) ;
325
- done. send ( ( ) ) . unwrap ( ) ;
332
+ done. send ( true ) . unwrap ( ) ;
326
333
} ) ;
327
334
} else if index == b {
328
335
// send the signal that b is finished
329
336
visited. lock ( ) . unwrap ( ) . push ( index) ;
330
- done. send ( ( ) ) . unwrap ( ) ;
337
+ done. send ( true ) . unwrap ( ) ;
331
338
b_done. take ( ) . unwrap ( ) . send ( ( ) ) . unwrap ( ) ;
332
339
} else if index == a {
333
340
// don't mark as done until d finishes
@@ -336,19 +343,69 @@ mod test {
336
343
tokio:: spawn ( async move {
337
344
is_d_done. await . unwrap ( ) ;
338
345
visited. lock ( ) . unwrap ( ) . push ( index) ;
339
- done. send ( ( ) ) . unwrap ( ) ;
346
+ done. send ( true ) . unwrap ( ) ;
340
347
} ) ;
341
348
} else if index == d {
342
349
// send the signal that b is finished
343
350
visited. lock ( ) . unwrap ( ) . push ( index) ;
344
- done. send ( ( ) ) . unwrap ( ) ;
351
+ done. send ( true ) . unwrap ( ) ;
345
352
d_done. take ( ) . unwrap ( ) . send ( ( ) ) . unwrap ( ) ;
346
353
} else {
347
354
visited. lock ( ) . unwrap ( ) . push ( index) ;
348
- done. send ( ( ) ) . unwrap ( ) ;
355
+ done. send ( true ) . unwrap ( ) ;
349
356
}
350
357
}
351
358
walker. wait ( ) . await . unwrap ( ) ;
352
359
assert_eq ! ( visited. lock( ) . unwrap( ) . as_slice( ) , & [ c, b, e, d, a] ) ;
353
360
}
361
+
362
+ #[ tokio:: test]
363
+ async fn test_dependent_cancellation ( ) {
364
+ // a -- b -- c -- f
365
+ // \ /
366
+ // - d -- e -
367
+ let mut g = Graph :: new ( ) ;
368
+ let a = g. add_node ( "a" ) ;
369
+ let b = g. add_node ( "b" ) ;
370
+ let c = g. add_node ( "c" ) ;
371
+ let d = g. add_node ( "d" ) ;
372
+ let e = g. add_node ( "e" ) ;
373
+ let f = g. add_node ( "f" ) ;
374
+ g. add_edge ( a, b, ( ) ) ;
375
+ g. add_edge ( b, c, ( ) ) ;
376
+ g. add_edge ( a, d, ( ) ) ;
377
+ g. add_edge ( d, e, ( ) ) ;
378
+ g. add_edge ( c, f, ( ) ) ;
379
+ g. add_edge ( e, f, ( ) ) ;
380
+
381
+ // We intentionally wait to mark c as finished until e has been finished
382
+ let walker = Walker :: new ( & g) ;
383
+ let visited = Arc :: new ( Mutex :: new ( Vec :: new ( ) ) ) ;
384
+ let ( walker, mut node_emitter) = walker. walk ( ) ;
385
+ let ( e_done, is_e_done) = oneshot:: channel :: < ( ) > ( ) ;
386
+ let mut e_done = Some ( e_done) ;
387
+ let mut is_e_done = Some ( is_e_done) ;
388
+ while let Some ( ( index, done) ) = node_emitter. recv ( ) . await {
389
+ if index == c {
390
+ // don't mark as done until we get the signal that e is finished
391
+ let is_e_done = is_e_done. take ( ) . unwrap ( ) ;
392
+ let visited = visited. clone ( ) ;
393
+ tokio:: spawn ( async move {
394
+ is_e_done. await . unwrap ( ) ;
395
+ visited. lock ( ) . unwrap ( ) . push ( index) ;
396
+ done. send ( true ) . unwrap ( ) ;
397
+ } ) ;
398
+ } else if index == e {
399
+ // send the signal that e is finished, and cancel its dependents
400
+ visited. lock ( ) . unwrap ( ) . push ( index) ;
401
+ done. send ( false ) . unwrap ( ) ;
402
+ e_done. take ( ) . unwrap ( ) . send ( ( ) ) . unwrap ( ) ;
403
+ } else {
404
+ visited. lock ( ) . unwrap ( ) . push ( index) ;
405
+ done. send ( true ) . unwrap ( ) ;
406
+ }
407
+ }
408
+ walker. wait ( ) . await . unwrap ( ) ;
409
+ assert_eq ! ( visited. lock( ) . unwrap( ) . as_slice( ) , & [ f, e, c, b] ) ;
410
+ }
354
411
}
0 commit comments