diff --git a/Cargo.toml b/Cargo.toml index 1bf97de..f3fb95e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [dependencies] async-task = "4.7" +pin-project = "1" [dev-dependencies] tokio = { version = "1", features = ["full"] } diff --git a/src/droppable_future.rs b/src/droppable_future.rs new file mode 100644 index 0000000..0ab6d57 --- /dev/null +++ b/src/droppable_future.rs @@ -0,0 +1,51 @@ +use std::{future::Future, pin::Pin}; + +use pin_project::{pin_project, pinned_drop}; + +#[pin_project(PinnedDrop)] +pub struct DroppableFuture +where + F: Future, + D: Fn(), +{ + #[pin] + future: F, + on_drop: D, +} + +impl DroppableFuture +where + F: Future, + D: Fn(), +{ + pub fn new(future: F, on_drop: D) -> Self { + Self { future, on_drop } + } +} + +impl Future for DroppableFuture +where + F: Future, + D: Fn(), +{ + type Output = F::Output; + + fn poll( + self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + let this = self.project(); + this.future.poll(cx) + } +} + +#[pinned_drop] +impl PinnedDrop for DroppableFuture +where + F: Future, + D: Fn(), +{ + fn drop(self: Pin<&mut Self>) { + (self.on_drop)(); + } +} diff --git a/src/lib.rs b/src/lib.rs index 86ac703..1e5a011 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,6 @@ +mod droppable_future; +use droppable_future::*; + mod task_identifier; pub use task_identifier::*; diff --git a/src/ticked_async_executor.rs b/src/ticked_async_executor.rs index 1ade368..70c3037 100644 --- a/src/ticked_async_executor.rs +++ b/src/ticked_async_executor.rs @@ -6,7 +6,7 @@ use std::{ }, }; -use crate::TaskIdentifier; +use crate::{DroppableFuture, TaskIdentifier}; #[derive(Debug)] pub enum TaskState { @@ -16,37 +16,11 @@ pub enum TaskState { Drop(TaskIdentifier), } -pub type Task = async_task::Task>; -type TaskRunnable = async_task::Runnable>; -type Payload = (TaskIdentifier, TaskRunnable); +pub type Task = async_task::Task; +type Payload = (TaskIdentifier, async_task::Runnable); -/// Task Metadata associated with TickedAsyncExecutor -/// -/// Primarily used to track when the Task is completed/cancelled -pub struct TaskMetadata -where - O: Fn(TaskState) + Send + Sync + 'static, -{ - num_spawned_tasks: Arc, - identifier: TaskIdentifier, - observer: O, -} - -impl Drop for TaskMetadata -where - O: Fn(TaskState) + Send + Sync + 'static, -{ - fn drop(&mut self) { - self.num_spawned_tasks.fetch_sub(1, Ordering::Relaxed); - (self.observer)(TaskState::Drop(self.identifier.clone())); - } -} - -pub struct TickedAsyncExecutor -where - O: Fn(TaskState) + Send + Sync + 'static, -{ - channel: (mpsc::Sender>, mpsc::Receiver>), +pub struct TickedAsyncExecutor { + channel: (mpsc::Sender, mpsc::Receiver), num_woken_tasks: Arc, num_spawned_tasks: Arc, @@ -79,22 +53,14 @@ where &self, identifier: impl Into, future: impl Future + Send + 'static, - ) -> Task + ) -> Task where T: Send + 'static, { let identifier = identifier.into(); - self.num_spawned_tasks.fetch_add(1, Ordering::Relaxed); - (self.observer)(TaskState::Spawn(identifier.clone())); - - let schedule = self.runnable_schedule_cb(identifier.clone()); - let (runnable, task) = async_task::Builder::new() - .metadata(TaskMetadata { - num_spawned_tasks: self.num_spawned_tasks.clone(), - identifier, - observer: self.observer.clone(), - }) - .spawn(|_m| future, schedule); + let future = self.droppable_future(identifier.clone(), future); + let schedule = self.runnable_schedule_cb(identifier); + let (runnable, task) = async_task::spawn(future, schedule); runnable.schedule(); task } @@ -103,22 +69,14 @@ where &self, identifier: impl Into, future: impl Future + 'static, - ) -> Task + ) -> Task where T: 'static, { let identifier = identifier.into(); - self.num_spawned_tasks.fetch_add(1, Ordering::Relaxed); - (self.observer)(TaskState::Spawn(identifier.clone())); - - let schedule = self.runnable_schedule_cb(identifier.clone()); - let (runnable, task) = async_task::Builder::new() - .metadata(TaskMetadata { - num_spawned_tasks: self.num_spawned_tasks.clone(), - identifier, - observer: self.observer.clone(), - }) - .spawn_local(move |_m| future, schedule); + let future = self.droppable_future(identifier.clone(), future); + let schedule = self.runnable_schedule_cb(identifier); + let (runnable, task) = async_task::spawn_local(future, schedule); runnable.schedule(); task } @@ -146,7 +104,29 @@ where .fetch_sub(num_woken_tasks, Ordering::Relaxed); } - fn runnable_schedule_cb(&self, identifier: TaskIdentifier) -> impl Fn(TaskRunnable) { + fn droppable_future( + &self, + identifier: TaskIdentifier, + future: F, + ) -> DroppableFuture + where + F: Future, + { + let observer = self.observer.clone(); + + // Spawn Task + self.num_spawned_tasks.fetch_add(1, Ordering::Relaxed); + observer(TaskState::Spawn(identifier.clone())); + + // Droppable Future registering on_drop callback + let num_spawned_tasks = self.num_spawned_tasks.clone(); + DroppableFuture::new(future, move || { + num_spawned_tasks.fetch_sub(1, Ordering::Relaxed); + observer(TaskState::Drop(identifier.clone())); + }) + } + + fn runnable_schedule_cb(&self, identifier: TaskIdentifier) -> impl Fn(async_task::Runnable) { let sender = self.channel.0.clone(); let num_woken_tasks = self.num_woken_tasks.clone(); let observer = self.observer.clone(); @@ -160,15 +140,13 @@ where #[cfg(test)] mod tests { - use tokio::join; - use super::*; #[test] - fn test_multiple_local_tasks() { + fn test_multiple_tasks() { let executor = TickedAsyncExecutor::default(); executor - .spawn_local("A", async move { + .spawn("A", async move { tokio::task::yield_now().await; }) .detach(); @@ -187,7 +165,7 @@ mod tests { } #[test] - fn test_local_tasks_cancellation() { + fn test_task_cancellation() { let executor = TickedAsyncExecutor::new(|_state| println!("{_state:?}")); let task1 = executor.spawn_local("A", async move { loop { @@ -205,39 +183,7 @@ mod tests { executor .spawn_local("CancelTasks", async move { - let (t1, t2) = join!(task1.cancel(), task2.cancel()); - assert_eq!(t1, None); - assert_eq!(t2, None); - }) - .detach(); - assert_eq!(executor.num_tasks(), 3); - - // Since we have cancelled the tasks above, the loops should eventually end - while executor.num_tasks() != 0 { - executor.tick(); - } - } - - #[test] - fn test_tasks_cancellation() { - let executor = TickedAsyncExecutor::new(|_state| println!("{_state:?}")); - let task1 = executor.spawn("A", async move { - loop { - tokio::task::yield_now().await; - } - }); - - let task2 = executor.spawn(format!("B"), async move { - loop { - tokio::task::yield_now().await; - } - }); - assert_eq!(executor.num_tasks(), 2); - executor.tick(); - - executor - .spawn_local("CancelTasks", async move { - let (t1, t2) = join!(task1.cancel(), task2.cancel()); + let (t1, t2) = tokio::join!(task1.cancel(), task2.cancel()); assert_eq!(t1, None); assert_eq!(t2, None); }) diff --git a/tests/tokio_tests.rs b/tests/tokio_tests.rs new file mode 100644 index 0000000..6b1db77 --- /dev/null +++ b/tests/tokio_tests.rs @@ -0,0 +1,79 @@ +use ticked_async_executor::TickedAsyncExecutor; + +#[test] +fn test_tokio_join() { + let executor = TickedAsyncExecutor::default(); + + let (tx1, mut rx1) = tokio::sync::mpsc::channel::(1); + let (tx2, mut rx2) = tokio::sync::mpsc::channel::(1); + executor + .spawn("ThreadedFuture", async move { + let (a, b) = tokio::join!(rx1.recv(), rx2.recv()); + assert_eq!(a.unwrap(), 10); + assert_eq!(b.unwrap(), 20); + }) + .detach(); + + let (tx3, mut rx3) = tokio::sync::mpsc::channel::(1); + let (tx4, mut rx4) = tokio::sync::mpsc::channel::(1); + executor + .spawn("LocalFuture", async move { + let (a, b) = tokio::join!(rx3.recv(), rx4.recv()); + assert_eq!(a.unwrap(), 10); + assert_eq!(b.unwrap(), 20); + }) + .detach(); + + tx1.try_send(10).unwrap(); + tx3.try_send(10).unwrap(); + for _ in 0..10 { + executor.tick(); + } + tx2.try_send(20).unwrap(); + tx4.try_send(20).unwrap(); + + while executor.num_tasks() != 0 { + executor.tick(); + } +} + +#[test] +fn test_tokio_select() { + let executor = TickedAsyncExecutor::default(); + + let (tx1, mut rx1) = tokio::sync::mpsc::channel::(1); + let (_tx2, mut rx2) = tokio::sync::mpsc::channel::(1); + executor + .spawn("ThreadedFuture", async move { + tokio::select! { + data = rx1.recv() => { + assert_eq!(data.unwrap(), 10); + } + _ = rx2.recv() => {} + } + }) + .detach(); + + let (tx3, mut rx3) = tokio::sync::mpsc::channel::(1); + let (_tx4, mut rx4) = tokio::sync::mpsc::channel::(1); + executor + .spawn("LocalFuture", async move { + tokio::select! { + data = rx3.recv() => { + assert_eq!(data.unwrap(), 10); + } + _ = rx4.recv() => {} + } + }) + .detach(); + + for _ in 0..10 { + executor.tick(); + } + + tx1.try_send(10).unwrap(); + tx3.try_send(10).unwrap(); + while executor.num_tasks() != 0 { + executor.tick(); + } +}