diff --git a/Cargo.toml b/Cargo.toml index d64beb4..cef7138 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,7 @@ name = "async-executor" version = "1.8.0" authors = ["Stjepan Glavina "] edition = "2021" -rust-version = "1.60" +rust-version = "1.61" description = "Async executor" license = "Apache-2.0 OR MIT" repository = "https://github.com/smol-rs/async-executor" @@ -17,10 +17,12 @@ exclude = ["/.*"] [dependencies] async-lock = "3.0.0" async-task = "4.4.0" +atomic-waker = "1.0" concurrent-queue = "2.0.0" fastrand = "2.0.0" futures-lite = { version = "2.0.0", default-features = false } slab = "0.4.4" +thread_local = "1.1" [target.'cfg(target_family = "wasm")'.dependencies] futures-lite = { version = "2.0.0", default-features = false, features = ["std"] } diff --git a/benches/executor.rs b/benches/executor.rs index 20d41a1..b6e33c2 100644 --- a/benches/executor.rs +++ b/benches/executor.rs @@ -1,4 +1,3 @@ -use std::future::Future; use std::thread::available_parallelism; use async_executor::Executor; diff --git a/examples/priority.rs b/examples/priority.rs index df77dd1..60d5c9a 100644 --- a/examples/priority.rs +++ b/examples/priority.rs @@ -1,6 +1,5 @@ //! An executor with task priorities. 
-use std::future::Future; use std::thread; use async_executor::{Executor, Task}; diff --git a/src/lib.rs b/src/lib.rs index cafc6e6..05c69aa 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -34,19 +34,20 @@ )] use std::fmt; -use std::future::Future; use std::marker::PhantomData; use std::panic::{RefUnwindSafe, UnwindSafe}; use std::rc::Rc; use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::{Arc, Mutex, RwLock, TryLockError}; +use std::sync::{Arc, Mutex, TryLockError}; use std::task::{Poll, Waker}; use async_lock::OnceCell; use async_task::{Builder, Runnable}; +use atomic_waker::AtomicWaker; use concurrent_queue::ConcurrentQueue; use futures_lite::{future, prelude::*}; use slab::Slab; +use thread_local::ThreadLocal; #[doc(no_inline)] pub use async_task::Task; @@ -266,8 +267,23 @@ impl<'a> Executor<'a> { fn schedule(&self) -> impl Fn(Runnable) + Send + Sync + 'static { let state = self.state().clone(); - // TODO: If possible, push into the current local queue and notify the ticker. - move |runnable| { + move |mut runnable| { + // If possible, push into the current local queue and notify the ticker. + if let Some(local) = state.local_queue.get() { + runnable = if let Err(err) = local.queue.push(runnable) { + err.into_inner() + } else { + // Wake up this thread if it's asleep, otherwise notify another + // thread to try to have the task stolen. + if let Some(waker) = local.waker.take() { + waker.wake(); + } else { + state.notify(); + } + return; + } + } + // If the local queue is full, fall back to pushing onto the global injector queue. state.queue.push(runnable).unwrap(); state.notify(); } @@ -510,7 +526,16 @@ struct State { queue: ConcurrentQueue, /// Local queues created by runners. - local_queues: RwLock>>>, + /// + /// If possible, tasks are scheduled onto the local queue, and will only defer + /// to the global queue when it's full, or the task is being scheduled from + /// a thread without a runner.
+ /// + /// Note: if a runner terminates and drains its local queue, any subsequent + /// spawn calls from the same thread will be added to the same queue, but won't + /// be executed until `Executor::run` is run on the thread again, or another + /// thread steals the task. + local_queue: ThreadLocal, /// Set to `true` when a sleeping ticker is notified or no tickers are sleeping. notified: AtomicBool, @@ -527,7 +552,7 @@ impl State { fn new() -> State { State { queue: ConcurrentQueue::unbounded(), - local_queues: RwLock::new(Vec::new()), + local_queue: ThreadLocal::new(), notified: AtomicBool::new(true), sleepers: Mutex::new(Sleepers { count: 0, @@ -654,6 +679,12 @@ impl Ticker<'_> { /// /// Returns `false` if the ticker was already sleeping and unnotified. fn sleep(&mut self, waker: &Waker) -> bool { + self.state + .local_queue + .get_or_default() + .waker + .register(waker); + let mut sleepers = self.state.sleepers.lock().unwrap(); match self.sleeping { @@ -692,7 +723,14 @@ impl Ticker<'_> { /// Waits for the next runnable task to run. async fn runnable(&mut self) -> Runnable { - self.runnable_with(|| self.state.queue.pop().ok()).await + self.runnable_with(|| { + self.state + .local_queue + .get() + .and_then(|local| local.queue.pop().ok()) + .or_else(|| self.state.queue.pop().ok()) + }) + .await } /// Waits for the next runnable task to run, given a function that searches for a task. @@ -754,9 +792,6 @@ struct Runner<'a> { /// Inner ticker. ticker: Ticker<'a>, - /// The local queue. - local: Arc>, - /// Bumped every time a runnable task is found. ticks: usize, } @@ -767,38 +802,34 @@ impl Runner<'_> { let runner = Runner { state, ticker: Ticker::new(state), - local: Arc::new(ConcurrentQueue::bounded(512)), ticks: 0, }; - state - .local_queues - .write() - .unwrap() - .push(runner.local.clone()); runner } /// Waits for the next runnable task to run. 
async fn runnable(&mut self, rng: &mut fastrand::Rng) -> Runnable { + let local = self.state.local_queue.get_or_default(); + let runnable = self .ticker .runnable_with(|| { // Try the local queue. - if let Ok(r) = self.local.pop() { + if let Ok(r) = local.queue.pop() { return Some(r); } // Try stealing from the global queue. if let Ok(r) = self.state.queue.pop() { - steal(&self.state.queue, &self.local); + steal(&self.state.queue, &local.queue); return Some(r); } // Try stealing from other runners. - let local_queues = self.state.local_queues.read().unwrap(); + let local_queues = &self.state.local_queue; // Pick a random starting point in the iterator list and rotate the list. - let n = local_queues.len(); + let n = local_queues.iter().count(); let start = rng.usize(..n); let iter = local_queues .iter() @@ -807,12 +838,12 @@ impl Runner<'_> { .take(n); // Remove this runner's local queue. - let iter = iter.filter(|local| !Arc::ptr_eq(local, &self.local)); + let iter = iter.filter(|other| !core::ptr::eq(*other, local)); // Try stealing from each local queue in the list. - for local in iter { - steal(local, &self.local); - if let Ok(r) = self.local.pop() { + for other in iter { + steal(&other.queue, &local.queue); + if let Ok(r) = local.queue.pop() { return Some(r); } } @@ -826,7 +857,7 @@ impl Runner<'_> { if self.ticks % 64 == 0 { // Steal tasks from the global queue to ensure fair task scheduling. - steal(&self.state.queue, &self.local); + steal(&self.state.queue, &local.queue); } runnable @@ -836,15 +867,13 @@ impl Runner<'_> { impl Drop for Runner<'_> { fn drop(&mut self) { // Remove the local queue. - self.state - .local_queues - .write() - .unwrap() - .retain(|local| !Arc::ptr_eq(local, &self.local)); - - // Re-schedule remaining tasks in the local queue. - while let Ok(r) = self.local.pop() { - r.schedule(); + if let Some(local) = self.state.local_queue.get() { + // Re-schedule remaining tasks in the local queue. 
+ for r in local.queue.try_iter() { + // Explicitly reschedule the runnable back onto the global + // queue to avoid rescheduling onto the local one. + self.state.queue.push(r).unwrap(); + } } } } @@ -904,18 +933,13 @@ fn debug_executor(executor: &Executor<'_>, name: &str, f: &mut fmt::Formatter<'_ } /// Debug wrapper for the local runners. - struct LocalRunners<'a>(&'a RwLock>>>); + struct LocalRunners<'a>(&'a ThreadLocal); impl fmt::Debug for LocalRunners<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self.0.try_read() { - Ok(lock) => f - .debug_list() - .entries(lock.iter().map(|queue| queue.len())) - .finish(), - Err(TryLockError::WouldBlock) => f.write_str(""), - Err(TryLockError::Poisoned(_)) => f.write_str(""), - } + f.debug_list() + .entries(self.0.iter().map(|local| local.queue.len())) + .finish() } } @@ -935,11 +959,32 @@ fn debug_executor(executor: &Executor<'_>, name: &str, f: &mut fmt::Formatter<'_ f.debug_struct(name) .field("active", &ActiveTasks(&state.active)) .field("global_tasks", &state.queue.len()) - .field("local_runners", &LocalRunners(&state.local_queues)) + .field("local_runners", &LocalRunners(&state.local_queue)) .field("sleepers", &SleepCount(&state.sleepers)) .finish() } +/// A queue local to each thread. +/// +/// Its Default implementation is used for initializing each +/// thread's queue via `ThreadLocal::get_or_default`. +/// +/// The local queue *must* be flushed, and all pending runnables +/// rescheduled onto the global queue when a runner is dropped. +struct LocalQueue { + queue: ConcurrentQueue, + waker: AtomicWaker, +} + +impl Default for LocalQueue { + fn default() -> Self { + Self { + queue: ConcurrentQueue::bounded(512), + waker: AtomicWaker::new(), + } + } +} + /// Runs a closure when dropped. struct CallOnDrop(F);