11//! A scoped tokio spawn implementation that allow a non-'static lifetime for tasks.
22
33use std:: {
4+ any:: Any ,
45 marker:: PhantomData ,
5- mem:: take,
66 panic:: { self , AssertUnwindSafe , catch_unwind} ,
77 pin:: Pin ,
88 sync:: {
99 Arc ,
1010 atomic:: { AtomicUsize , Ordering } ,
1111 } ,
12+ thread:: { self , Thread } ,
1213} ;
1314
15+ use futures:: FutureExt ;
1416use parking_lot:: Mutex ;
15- use tokio:: {
16- runtime:: { Builder , Handle } ,
17- task:: { JoinHandle , block_in_place} ,
18- } ;
19- use tracing:: { Instrument , Span } ;
17+ use tokio:: { runtime:: Handle , task:: block_in_place} ;
18+ use tracing:: { Instrument , Span , info_span} ;
2019
2120use crate :: {
2221 TurboTasksApi ,
2322 manager:: { try_turbo_tasks, turbo_tasks_future_scope} ,
2423} ;
2524
25+ struct ScopeInner {
26+ main_thread : Thread ,
27+ remaining_tasks : AtomicUsize ,
28+ /// The first panic that occurred in the tasks, by task index.
29+ /// The usize value is the index of the task.
30+ panic : Mutex < Option < ( Box < dyn Any + Send + ' static > , usize ) > > ,
31+ }
32+
33+ impl ScopeInner {
34+ fn on_task_finished ( & self , panic : Option < ( Box < dyn Any + Send + ' static > , usize ) > ) {
35+ if let Some ( ( err, index) ) = panic {
36+ let mut old_panic = self . panic . lock ( ) ;
37+ if old_panic. as_ref ( ) . is_none_or ( |& ( _, i) | i > index) {
38+ * old_panic = Some ( ( err, index) ) ;
39+ }
40+ }
41+ if self . remaining_tasks . fetch_sub ( 1 , Ordering :: Release ) == 1 {
42+ self . main_thread . unpark ( ) ;
43+ }
44+ }
45+
46+ fn wait ( & self ) {
47+ let _span = info_span ! ( "blocking" ) . entered ( ) ;
48+ while self . remaining_tasks . load ( Ordering :: Acquire ) != 0 {
49+ thread:: park ( ) ;
50+ }
51+ if let Some ( ( err, _) ) = self . panic . lock ( ) . take ( ) {
52+ panic:: resume_unwind ( err) ;
53+ }
54+ }
55+ }
56+
2657/// Scope to allow spawning tasks with a limited lifetime.
2758///
2859/// Dropping this Scope will wait for all tasks to complete.
2960pub struct Scope < ' scope , ' env : ' scope , R : Send + ' env > {
3061 results : & ' scope [ Mutex < Option < R > > ] ,
3162 index : AtomicUsize ,
32- futures : Mutex < Vec < JoinHandle < ( ) > > > ,
63+ inner : Arc < ScopeInner > ,
3364 handle : Handle ,
3465 turbo_tasks : Option < Arc < dyn TurboTasksApi > > ,
3566 span : Span ,
@@ -50,7 +81,11 @@ impl<'scope, 'env: 'scope, R: Send + 'env> Scope<'scope, 'env, R> {
5081 Self {
5182 results,
5283 index : AtomicUsize :: new ( 0 ) ,
53- futures : Mutex :: new ( Vec :: with_capacity ( results. len ( ) ) ) ,
84+ inner : Arc :: new ( ScopeInner {
85+ main_thread : thread:: current ( ) ,
86+ remaining_tasks : AtomicUsize :: new ( 0 ) ,
87+ panic : Mutex :: new ( None ) ,
88+ } ) ,
5489 handle : Handle :: current ( ) ,
5590 turbo_tasks : try_turbo_tasks ( ) ,
5691 span : Span :: current ( ) ,
@@ -86,57 +121,33 @@ impl<'scope, 'env: 'scope, R: Send + 'env> Scope<'scope, 'env, R> {
86121
87122 let turbo_tasks = self . turbo_tasks . clone ( ) ;
88123 let span = self . span . clone ( ) ;
89- let future = self . handle . spawn (
90- async move {
91- if let Some ( turbo_tasks) = turbo_tasks {
92- // Ensure that the turbo tasks context is maintained across the task.
93- turbo_tasks_future_scope ( turbo_tasks, f) . await ;
94- } else {
95- // If no turbo tasks context is available, just run the future.
96- f. await ;
97- }
98- }
99- . instrument ( span) ,
100- ) ;
101- self . futures . lock ( ) . push ( future) ;
102- }
103124
104- /// Blocks the current thread until all spawned tasks have completed.
105- fn block_until_complete ( & self ) {
106- let futures = take ( & mut * self . futures . lock ( ) ) ;
107- if futures. is_empty ( ) {
108- return ; // No tasks to wait for, return early
109- }
110- // We create a new current thread runtime to be independent of the current tokio runtime.
111- // This makes us not subject to runtime shutdown and we can drive the futures to completion
112- // in all cases.
113- Builder :: new_current_thread ( ) . build ( ) . unwrap ( ) . block_on (
114- async {
115- let mut first_err = None ;
116- for task in futures {
117- match task. await {
118- Ok ( _) => { }
119- Err ( err) if first_err. is_none ( ) => {
120- // SAFETY: We need to finish all futures before panicking.
121- first_err = Some ( err) ;
122- }
123- Err ( _) => {
124- // Ignore subsequent errors
125- }
125+ let inner = self . inner . clone ( ) ;
126+ inner. remaining_tasks . fetch_add ( 1 , Ordering :: Relaxed ) ;
127+ self . handle . spawn ( async move {
128+ let result = AssertUnwindSafe (
129+ async move {
130+ if let Some ( turbo_tasks) = turbo_tasks {
131+ // Ensure that the turbo tasks context is maintained across the task.
132+ turbo_tasks_future_scope ( turbo_tasks, f) . await ;
133+ } else {
134+ // If no turbo tasks context is available, just run the future.
135+ f. await ;
126136 }
127137 }
128- if let Some ( err) = first_err {
129- panic:: resume_unwind ( err. into_panic ( ) ) ;
130- }
131- }
132- . instrument ( self . span . clone ( ) ) ,
133- ) ;
138+ . instrument ( span) ,
139+ )
140+ . catch_unwind ( )
141+ . await ;
142+ let panic = result. err ( ) . map ( |e| ( e, index) ) ;
143+ inner. on_task_finished ( panic) ;
144+ } ) ;
134145 }
135146}
136147
137148impl < ' scope , ' env : ' scope , R : Send + ' env > Drop for Scope < ' scope , ' env , R > {
138149 fn drop ( & mut self ) {
139- self . block_until_complete ( ) ;
150+ self . inner . wait ( ) ;
140151 }
141152}
142153
@@ -192,6 +203,47 @@ mod tests {
192203 } ) ;
193204 }
194205
206+ #[ tokio:: test( flavor = "multi_thread" ) ]
207+ async fn test_empty_scope ( ) {
208+ let results = scope_and_block ( 0 , |scope| {
209+ if false {
210+ scope. spawn ( async move { 42 } ) ;
211+ }
212+ } ) ;
213+ assert_eq ! ( results. count( ) , 0 ) ;
214+ }
215+
216+ #[ tokio:: test( flavor = "multi_thread" ) ]
217+ async fn test_single_task ( ) {
218+ let results = scope_and_block ( 1 , |scope| {
219+ scope. spawn ( async move { 42 } ) ;
220+ } )
221+ . collect :: < Vec < _ > > ( ) ;
222+ assert_eq ! ( results, vec![ 42 ] ) ;
223+ }
224+
225+ #[ tokio:: test( flavor = "multi_thread" ) ]
226+ async fn test_task_finish_before_scope ( ) {
227+ let results = scope_and_block ( 1 , |scope| {
228+ scope. spawn ( async move { 42 } ) ;
229+ thread:: sleep ( std:: time:: Duration :: from_millis ( 100 ) ) ;
230+ } )
231+ . collect :: < Vec < _ > > ( ) ;
232+ assert_eq ! ( results, vec![ 42 ] ) ;
233+ }
234+
235+ #[ tokio:: test( flavor = "multi_thread" ) ]
236+ async fn test_task_finish_after_scope ( ) {
237+ let results = scope_and_block ( 1 , |scope| {
238+ scope. spawn ( async move {
239+ thread:: sleep ( std:: time:: Duration :: from_millis ( 100 ) ) ;
240+ 42
241+ } ) ;
242+ } )
243+ . collect :: < Vec < _ > > ( ) ;
244+ assert_eq ! ( results, vec![ 42 ] ) ;
245+ }
246+
195247 #[ tokio:: test( flavor = "multi_thread" ) ]
196248 async fn test_panic_in_scope_factory ( ) {
197249 let result = catch_unwind ( AssertUnwindSafe ( || {
0 commit comments