Skip to content

Commit

Permalink
Use Droppable Future instead of TaskMetadata (#4)
Browse files Browse the repository at this point in the history
- TaskMetadata does not Drop when using tokio::select
- Added tokio integration tests for join! and select!
  • Loading branch information
coder137 authored Jul 20, 2024
1 parent 07eda0e commit f22d942
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 94 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ edition = "2021"

[dependencies]
async-task = "4.7"
pin-project = "1"

[dev-dependencies]
tokio = { version = "1", features = ["full"] }
51 changes: 51 additions & 0 deletions src/droppable_future.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
use std::{future::Future, pin::Pin};

use pin_project::{pin_project, pinned_drop};

#[pin_project(PinnedDrop)]
pub struct DroppableFuture<F, D>
where
F: Future,
D: Fn(),
{
#[pin]
future: F,
on_drop: D,
}

impl<F, D> DroppableFuture<F, D>
where
F: Future,
D: Fn(),
{
pub fn new(future: F, on_drop: D) -> Self {
Self { future, on_drop }
}
}

impl<F, D> Future for DroppableFuture<F, D>
where
F: Future,
D: Fn(),
{
type Output = F::Output;

fn poll(
self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
let this = self.project();
this.future.poll(cx)
}
}

#[pinned_drop]
impl<F, D> PinnedDrop for DroppableFuture<F, D>
where
F: Future,
D: Fn(),
{
fn drop(self: Pin<&mut Self>) {
(self.on_drop)();
}
}
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
mod droppable_future;
use droppable_future::*;

mod task_identifier;
pub use task_identifier::*;

Expand Down
134 changes: 40 additions & 94 deletions src/ticked_async_executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use std::{
},
};

use crate::TaskIdentifier;
use crate::{DroppableFuture, TaskIdentifier};

#[derive(Debug)]
pub enum TaskState {
Expand All @@ -16,37 +16,11 @@ pub enum TaskState {
Drop(TaskIdentifier),
}

pub type Task<T, O> = async_task::Task<T, TaskMetadata<O>>;
type TaskRunnable<O> = async_task::Runnable<TaskMetadata<O>>;
type Payload<O> = (TaskIdentifier, TaskRunnable<O>);
pub type Task<T> = async_task::Task<T>;
type Payload = (TaskIdentifier, async_task::Runnable);

/// Task Metadata associated with TickedAsyncExecutor
///
/// Primarily used to track when the Task is completed/cancelled
pub struct TaskMetadata<O>
where
O: Fn(TaskState) + Send + Sync + 'static,
{
num_spawned_tasks: Arc<AtomicUsize>,
identifier: TaskIdentifier,
observer: O,
}

impl<O> Drop for TaskMetadata<O>
where
O: Fn(TaskState) + Send + Sync + 'static,
{
fn drop(&mut self) {
self.num_spawned_tasks.fetch_sub(1, Ordering::Relaxed);
(self.observer)(TaskState::Drop(self.identifier.clone()));
}
}

pub struct TickedAsyncExecutor<O>
where
O: Fn(TaskState) + Send + Sync + 'static,
{
channel: (mpsc::Sender<Payload<O>>, mpsc::Receiver<Payload<O>>),
pub struct TickedAsyncExecutor<O> {
channel: (mpsc::Sender<Payload>, mpsc::Receiver<Payload>),
num_woken_tasks: Arc<AtomicUsize>,
num_spawned_tasks: Arc<AtomicUsize>,

Expand Down Expand Up @@ -79,22 +53,14 @@ where
&self,
identifier: impl Into<TaskIdentifier>,
future: impl Future<Output = T> + Send + 'static,
) -> Task<T, O>
) -> Task<T>
where
T: Send + 'static,
{
let identifier = identifier.into();
self.num_spawned_tasks.fetch_add(1, Ordering::Relaxed);
(self.observer)(TaskState::Spawn(identifier.clone()));

let schedule = self.runnable_schedule_cb(identifier.clone());
let (runnable, task) = async_task::Builder::new()
.metadata(TaskMetadata {
num_spawned_tasks: self.num_spawned_tasks.clone(),
identifier,
observer: self.observer.clone(),
})
.spawn(|_m| future, schedule);
let future = self.droppable_future(identifier.clone(), future);
let schedule = self.runnable_schedule_cb(identifier);
let (runnable, task) = async_task::spawn(future, schedule);
runnable.schedule();
task
}
Expand All @@ -103,22 +69,14 @@ where
&self,
identifier: impl Into<TaskIdentifier>,
future: impl Future<Output = T> + 'static,
) -> Task<T, O>
) -> Task<T>
where
T: 'static,
{
let identifier = identifier.into();
self.num_spawned_tasks.fetch_add(1, Ordering::Relaxed);
(self.observer)(TaskState::Spawn(identifier.clone()));

let schedule = self.runnable_schedule_cb(identifier.clone());
let (runnable, task) = async_task::Builder::new()
.metadata(TaskMetadata {
num_spawned_tasks: self.num_spawned_tasks.clone(),
identifier,
observer: self.observer.clone(),
})
.spawn_local(move |_m| future, schedule);
let future = self.droppable_future(identifier.clone(), future);
let schedule = self.runnable_schedule_cb(identifier);
let (runnable, task) = async_task::spawn_local(future, schedule);
runnable.schedule();
task
}
Expand Down Expand Up @@ -146,7 +104,29 @@ where
.fetch_sub(num_woken_tasks, Ordering::Relaxed);
}

fn runnable_schedule_cb(&self, identifier: TaskIdentifier) -> impl Fn(TaskRunnable<O>) {
fn droppable_future<F>(
&self,
identifier: TaskIdentifier,
future: F,
) -> DroppableFuture<F, impl Fn()>
where
F: Future,
{
let observer = self.observer.clone();

// Spawn Task
self.num_spawned_tasks.fetch_add(1, Ordering::Relaxed);
observer(TaskState::Spawn(identifier.clone()));

// Droppable Future registering on_drop callback
let num_spawned_tasks = self.num_spawned_tasks.clone();
DroppableFuture::new(future, move || {
num_spawned_tasks.fetch_sub(1, Ordering::Relaxed);
observer(TaskState::Drop(identifier.clone()));
})
}

fn runnable_schedule_cb(&self, identifier: TaskIdentifier) -> impl Fn(async_task::Runnable) {
let sender = self.channel.0.clone();
let num_woken_tasks = self.num_woken_tasks.clone();
let observer = self.observer.clone();
Expand All @@ -160,15 +140,13 @@ where

#[cfg(test)]
mod tests {
use tokio::join;

use super::*;

#[test]
fn test_multiple_local_tasks() {
fn test_multiple_tasks() {
let executor = TickedAsyncExecutor::default();
executor
.spawn_local("A", async move {
.spawn("A", async move {
tokio::task::yield_now().await;
})
.detach();
Expand All @@ -187,7 +165,7 @@ mod tests {
}

#[test]
fn test_local_tasks_cancellation() {
fn test_task_cancellation() {
let executor = TickedAsyncExecutor::new(|_state| println!("{_state:?}"));
let task1 = executor.spawn_local("A", async move {
loop {
Expand All @@ -205,39 +183,7 @@ mod tests {

executor
.spawn_local("CancelTasks", async move {
let (t1, t2) = join!(task1.cancel(), task2.cancel());
assert_eq!(t1, None);
assert_eq!(t2, None);
})
.detach();
assert_eq!(executor.num_tasks(), 3);

// Since we have cancelled the tasks above, the loops should eventually end
while executor.num_tasks() != 0 {
executor.tick();
}
}

#[test]
fn test_tasks_cancellation() {
let executor = TickedAsyncExecutor::new(|_state| println!("{_state:?}"));
let task1 = executor.spawn("A", async move {
loop {
tokio::task::yield_now().await;
}
});

let task2 = executor.spawn(format!("B"), async move {
loop {
tokio::task::yield_now().await;
}
});
assert_eq!(executor.num_tasks(), 2);
executor.tick();

executor
.spawn_local("CancelTasks", async move {
let (t1, t2) = join!(task1.cancel(), task2.cancel());
let (t1, t2) = tokio::join!(task1.cancel(), task2.cancel());
assert_eq!(t1, None);
assert_eq!(t2, None);
})
Expand Down
79 changes: 79 additions & 0 deletions tests/tokio_tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
use ticked_async_executor::TickedAsyncExecutor;

#[test]
fn test_tokio_join() {
let executor = TickedAsyncExecutor::default();

let (tx1, mut rx1) = tokio::sync::mpsc::channel::<usize>(1);
let (tx2, mut rx2) = tokio::sync::mpsc::channel::<usize>(1);
executor
.spawn("ThreadedFuture", async move {
let (a, b) = tokio::join!(rx1.recv(), rx2.recv());
assert_eq!(a.unwrap(), 10);
assert_eq!(b.unwrap(), 20);
})
.detach();

let (tx3, mut rx3) = tokio::sync::mpsc::channel::<usize>(1);
let (tx4, mut rx4) = tokio::sync::mpsc::channel::<usize>(1);
executor
.spawn("LocalFuture", async move {
let (a, b) = tokio::join!(rx3.recv(), rx4.recv());
assert_eq!(a.unwrap(), 10);
assert_eq!(b.unwrap(), 20);
})
.detach();

tx1.try_send(10).unwrap();
tx3.try_send(10).unwrap();
for _ in 0..10 {
executor.tick();
}
tx2.try_send(20).unwrap();
tx4.try_send(20).unwrap();

while executor.num_tasks() != 0 {
executor.tick();
}
}

#[test]
fn test_tokio_select() {
let executor = TickedAsyncExecutor::default();

let (tx1, mut rx1) = tokio::sync::mpsc::channel::<usize>(1);
let (_tx2, mut rx2) = tokio::sync::mpsc::channel::<usize>(1);
executor
.spawn("ThreadedFuture", async move {
tokio::select! {
data = rx1.recv() => {
assert_eq!(data.unwrap(), 10);
}
_ = rx2.recv() => {}
}
})
.detach();

let (tx3, mut rx3) = tokio::sync::mpsc::channel::<usize>(1);
let (_tx4, mut rx4) = tokio::sync::mpsc::channel::<usize>(1);
executor
.spawn("LocalFuture", async move {
tokio::select! {
data = rx3.recv() => {
assert_eq!(data.unwrap(), 10);
}
_ = rx4.recv() => {}
}
})
.detach();

for _ in 0..10 {
executor.tick();
}

tx1.try_send(10).unwrap();
tx3.try_send(10).unwrap();
while executor.num_tasks() != 0 {
executor.tick();
}
}

0 comments on commit f22d942

Please sign in to comment.