feat(maitake): add Clone/Drop for TaskRef (#189)
these increment/decrement the task's reference count and can drop the
task if the dropped `TaskRef` is the last ref. this required some changes
to reference counting via poll state transitions.

will be needed for #184
depends on #188

Signed-off-by: Eliza Weisman <eliza@buoyant.io>
hawkw authored Jun 3, 2022
1 parent 0029a1b commit bde8172
Showing 3 changed files with 127 additions and 48 deletions.
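
Before the diff, a note on the technique: the `Clone` and `Drop` impls added to `TaskRef` follow the standard intrusive reference-counting protocol, in the style of `Arc`. The sketch below is illustrative only, not maitake's actual API: it assumes a plain `AtomicUsize` count, whereas maitake packs the count into the task's `StateCell` state word next to the poll state (which is why the poll-state transition in `state.rs` also changes).

use std::sync::atomic::{fence, AtomicUsize, Ordering};

struct Header {
    // stand-in for maitake's packed `StateCell`
    refs: AtomicUsize,
}

struct TaskRef(*const Header);

impl Clone for TaskRef {
    fn clone(&self) -> Self {
        // the increment can be Relaxed: the new handle is derived from an
        // existing live one, so no happens-before edge is needed yet.
        unsafe { &*self.0 }.refs.fetch_add(1, Ordering::Relaxed);
        Self(self.0)
    }
}

impl Drop for TaskRef {
    fn drop(&mut self) {
        // Release on the decrement, so that all prior uses of the task
        // happen-before the deallocation...
        if unsafe { &*self.0 }.refs.fetch_sub(1, Ordering::Release) != 1 {
            return;
        }
        // ...and Acquire before freeing, mirroring `Arc`'s protocol. the
        // real task instead dispatches through its vtable's `deallocate`.
        fence(Ordering::Acquire);
        unsafe { drop(Box::from_raw(self.0 as *mut Header)) }
    }
}

The wrinkle in the real commit is ownership transfer: enqueuing a task hands its reference to the run queue, so `into_ptr` uses `mem::forget(task)` rather than letting `Drop` decrement the count, and waking only drops the waker's reference after the task has been enqueued.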
maitake/src/task.rs (72 additions & 47 deletions)
@@ -15,6 +15,9 @@ pub use core::task::{Context, Poll, Waker};
 mod state;
 mod storage;
 
+#[cfg(test)]
+mod tests;
+
 use crate::{
     loom::cell::UnsafeCell,
     scheduler::Schedule,
@@ -160,11 +163,9 @@ enum Cell<F: Future> {
 #[derive(Debug)]
 struct Vtable {
     /// Poll the future.
-    poll: unsafe fn(NonNull<Header>) -> Poll<()>,
-    /* // TODO(eliza): this will be needed when tasks can be dropped through `JoinHandle` refs...
+    poll: unsafe fn(TaskRef) -> Poll<()>,
     /// Drops the task and deallocates its memory.
     deallocate: unsafe fn(NonNull<Header>),
-    */
 }
 
 // === impl Task ===
@@ -187,7 +188,7 @@ where
 {
     const TASK_VTABLE: Vtable = Vtable {
         poll: Self::poll,
-        // deallocate: Self::deallocate,
+        deallocate: Self::deallocate,
     };
 
     const WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
@@ -252,7 +253,7 @@ where
                 // transition does *not* decrement the reference count. this is
                 // in order to avoid dropping the task while it is being
                 // scheduled. one reference is consumed by enqueuing the task...
-                Self::schedule(this);
+                Self::schedule(TaskRef(this.cast::<Header>()));
                 // now that the task has been enqueued, decrement the reference
                 // count to drop the waker that performed the `wake_by_val`.
                 Self::drop_ref(this);
@@ -266,15 +267,13 @@ where
 
         let this = non_null(ptr as *mut ()).cast::<Self>();
         if this.as_ref().state().wake_by_ref() == ScheduleAction::Enqueue {
-            Self::schedule(this);
+            Self::schedule(TaskRef(this.cast::<Header>()));
         }
     }
 
     #[inline(always)]
-    unsafe fn schedule(this: NonNull<Self>) {
-        this.as_ref()
-            .scheduler
-            .schedule(TaskRef(this.cast::<Header>()));
+    unsafe fn schedule(this: TaskRef) {
+        this.0.cast::<Self>().as_ref().scheduler.schedule(this);
     }
 
     #[inline]
@@ -287,9 +286,9 @@ where
         drop(STO::from_raw(this))
     }
 
-    unsafe fn poll(ptr: NonNull<Header>) -> Poll<()> {
+    unsafe fn poll(ptr: TaskRef) -> Poll<()> {
         trace_task!(ptr, F, "poll");
-        let mut this = ptr.cast::<Self>();
+        let mut this = ptr.0.cast::<Self>();
         test_trace!(task = ?fmt::alt(this.as_ref()));
         // try to transition the task to the polling state
         let state = &this.as_ref().state();
@@ -317,17 +316,18 @@ where
         // post-poll state transition
         match test_dbg!(state.end_poll(poll.is_ready())) {
             OrDrop::Drop => drop(STO::from_raw(this)),
-            OrDrop::Action(ScheduleAction::Enqueue) => Self::schedule(this),
+            OrDrop::Action(ScheduleAction::Enqueue) => Self::schedule(ptr),
             OrDrop::Action(ScheduleAction::None) => {}
         }
 
         poll
     }
 
-    // unsafe fn deallocate(ptr: NonNull<Header>) {
-    //     trace_task!(ptr, F, "deallocate");
-    //     drop(Box::from_raw(ptr.cast::<Self>().as_ptr()))
-    // }
+    unsafe fn deallocate(ptr: NonNull<Header>) {
+        trace_task!(ptr, F, "deallocate");
+        let this = ptr.cast::<Self>();
+        drop(STO::from_raw(this));
+    }
 
     fn poll_inner(&self, mut cx: Context<'_>) -> Poll<()> {
         self.inner.with_mut(|cell| {
@@ -395,42 +395,42 @@ impl TaskRef {
         Self(ptr)
     }
 
-    pub(crate) fn poll(&self) -> Poll<()> {
+    pub(crate) fn poll(self) -> Poll<()> {
         let poll_fn = self.header().vtable.poll;
-        unsafe { poll_fn(self.0) }
+        unsafe { poll_fn(self) }
     }
 
-    // #[inline]
-    // fn state(&self) -> &StateVar {
-    //     &self.header().state
-    // }
+    #[inline]
+    fn state(&self) -> &StateCell {
+        &self.header().state
+    }
 
     #[inline]
     fn header(&self) -> &Header {
         unsafe { self.0.as_ref() }
     }
 }
 
-// impl Clone for TaskRef {
-//     #[inline]
-//     fn clone(&self) -> Self {
-//         self.state().clone_ref();
-//         Self(self.0)
-//     }
-// }
-
-// impl Drop for TaskRef {
-//     #[inline]
-//     fn drop(&mut self) {
-//         if !self.state().drop_ref() {
-//             return;
-//         }
-
-//         unsafe {
-//             Header::drop_slow(self.0);
-//         }
-//     }
-// }
+impl Clone for TaskRef {
+    #[inline]
+    fn clone(&self) -> Self {
+        self.state().clone_ref();
+        Self(self.0)
+    }
+}
+
+impl Drop for TaskRef {
+    #[inline]
+    fn drop(&mut self) {
+        if !self.state().drop_ref() {
+            return;
+        }
+
+        unsafe {
+            Header::deallocate(self.0);
+        }
+    }
+}
 
 unsafe impl Send for TaskRef {}
 unsafe impl Sync for TaskRef {}
@@ -440,19 +440,39 @@ unsafe impl Sync for TaskRef {}
 impl Header {
     #[cfg(not(loom))]
     pub(crate) const fn new_stub() -> Self {
-        unsafe fn nop(_ptr: NonNull<Header>) -> Poll<()> {
+        unsafe fn nop(_ptr: TaskRef) -> Poll<()> {
             #[cfg(debug_assertions)]
-            unreachable!("stub task ({_ptr:p}) should never be polled!");
+            unreachable!("stub task ({_ptr:?}) should never be polled!");
             #[cfg(not(debug_assertions))]
             Poll::Pending
         }
 
+        unsafe fn nop_deallocate(ptr: NonNull<Header>) {
+            unreachable!("stub task ({ptr:p}) should never be deallocated!");
+        }
+
         Self {
             run_queue: mpsc_queue::Links::new_stub(),
             state: StateCell::new(),
-            vtable: &Vtable { poll: nop },
+            vtable: &Vtable {
+                poll: nop,
+                deallocate: nop_deallocate,
+            },
         }
     }
+
+    unsafe fn deallocate(this: NonNull<Self>) {
+        #[cfg(debug_assertions)]
+        let refs = this
+            .as_ref()
+            .state
+            .load(core::sync::atomic::Ordering::Acquire)
+            .ref_count();
+        debug_assert_eq!(refs, 0, "tried to deallocate a task with references!");
+
+        let deallocate = this.as_ref().vtable.deallocate;
+        deallocate(this)
+    }
 }
 
 /// # Safety
@@ -462,7 +482,12 @@ unsafe impl Linked<mpsc_queue::Links<Header>> for Header {
     type Handle = TaskRef;
 
     fn into_ptr(task: Self::Handle) -> NonNull<Self> {
-        task.0
+        let ptr = task.0;
+        // converting a `TaskRef` into a pointer to enqueue it assigns ownership
+        // of the ref count to the queue, so we don't want to run its `Drop`
+        // impl.
+        mem::forget(task);
+        ptr
     }
 
     /// Convert a raw pointer to a `Handle`.
maitake/src/task/state.rs (0 additions & 1 deletion)
@@ -124,7 +124,6 @@ impl StateCell {
             return OrDrop::Action(ScheduleAction::Enqueue);
         }
 
-        let next_state = test_dbg!(next_state.drop_ref());
         *state = next_state;
 
         if next_state.ref_count() == 0 {
maitake/src/task/tests.rs (55 additions & 0 deletions)
@@ -0,0 +1,55 @@
+#[cfg(loom)]
+mod loom {
+    use crate::loom::{self, alloc::Track};
+    use crate::task::*;
+
+    #[derive(Clone)]
+    struct NopScheduler;
+
+    impl crate::scheduler::Schedule for NopScheduler {
+        fn schedule(&self, task: TaskRef) {
+            unimplemented!(
+                "nop scheduler should not actually schedule tasks (tried to schedule {task:?})"
+            )
+        }
+    }
+
+    #[test]
+    fn taskref_deallocates() {
+        loom::model(|| {
+            let track = Track::new(());
+            let task = TaskRef::new(NopScheduler, async move {
+                drop(track);
+            });
+
+            // if the task is not deallocated by dropping the `TaskRef`, the
+            // `Track` will be leaked.
+            drop(task);
+        });
+    }
+
+    #[test]
+    fn taskref_clones_deallocate() {
+        loom::model(|| {
+            let track = Track::new(());
+            let task = TaskRef::new(NopScheduler, async move {
+                drop(track);
+            });
+
+            let mut threads = (0..2)
+                .map(|_| {
+                    let task = task.clone();
+                    loom::thread::spawn(move || {
+                        drop(task);
+                    })
+                })
+                .collect::<Vec<_>>();
+
+            drop(task);
+
+            for thread in threads.drain(..) {
+                thread.join().unwrap();
+            }
+        });
+    }
+}
