Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix memory leak if tasks contain wakers #31

Merged
merged 1 commit into from
Aug 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 33 additions & 14 deletions src/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ pub(crate) struct Header {

impl Header {
/// Construct a new waker.
pub(crate) fn new(shared: Arc<Shared>, index: usize) -> Self {
fn new(shared: Arc<Shared>, index: usize) -> Self {
Self {
shared,
index,
Expand Down Expand Up @@ -200,15 +200,9 @@ impl<T> Storage<T> {
None => return false,
};

// SAFETY: We have mutable access to the given entry, but we are careful
// not to dereference the header mutably, since that might be shared.
unsafe {
let value = match *ptr::addr_of_mut!((*task.as_ptr()).entry) {
ref mut value @ Entry::Some(..) => value,
_ => return false,
};

*value = Entry::Vacant(self.next);
// SAFETY: The `task` pointer is valid, since we got it from the slab.
if !unsafe { make_slot_vacant(task, self.next) } {
return false;
}

self.len -= 1;
Expand All @@ -221,11 +215,18 @@ impl<T> Storage<T> {
// SAFETY: We're just decrementing the reference count of each entry
// before dropping the storage of the slab.
unsafe {
for &entry in &self.tasks {
if entry.as_ref().header.decrement_ref() {
for &task in &self.tasks {
// We must drop a task's entry _before_ decrementing the reference counter
// because the task can be accessed by wakers in parallel now.
//
// Also, we violate the linked list of vacant slots by passing `0` here
// because the whole `tasks` vector will be cleared below anyway.
make_slot_vacant(task, 0);

if task.as_ref().header.decrement_ref() {
// SAFETY: We're the only ones holding a reference to the
// entry, so it's safe to drop it.
_ = Box::from_raw(entry.as_ptr());
// task, so it's safe to drop it.
_ = Box::from_raw(task.as_ptr());
}
}

Expand Down Expand Up @@ -263,6 +264,24 @@ impl<T> Storage<T> {
}
}

/// Returns `true` if the entry was removed, `false` otherwise.
///
/// # Safety
/// * The `task` pointer must point to a valid entry.
/// * A task's entry must be accessed only by one thread.
unsafe fn make_slot_vacant<T>(task: NonNull<Task<T>>, next: usize) -> bool {
// SAFETY: We have mutable access to the given entry, but we are careful
// not to dereference the header mutably, since that might be shared.
let entry = unsafe { &mut *ptr::addr_of_mut!((*task.as_ptr()).entry) };

if !matches!(entry, Entry::Some(_)) {
return false;
}

*entry = Entry::Vacant(next);
true
}

impl<T> Default for Storage<T> {
fn default() -> Self {
Self::new()
Expand Down
49 changes: 49 additions & 0 deletions tests/stream_test.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
#![cfg(feature = "futures-rs")]

use std::{
pin::Pin,
sync::{atomic, Arc},
task,
};

use tokio_stream::iter;
use unicycle::StreamsUnordered;

Expand All @@ -19,3 +25,46 @@ async fn test_unicycle_streams() {

assert_eq!(vec![5, 1, 6, 2, 7, 3, 8, 4], received);
}

// See #30 for details.
#[tokio::test]
async fn test_drop_with_stored_waker() {
struct Testee {
waker: Option<task::Waker>,
dropped: Arc<atomic::AtomicBool>,
}

impl futures::Stream for Testee {
type Item = u32;

fn poll_next(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Option<u32>> {
println!("testee polled");
unsafe { self.get_unchecked_mut() }.waker = Some(cx.waker().clone());
task::Poll::Pending
}
}

impl Drop for Testee {
fn drop(&mut self) {
println!("testee dropped");
self.dropped.store(true, atomic::Ordering::SeqCst);
}
}

let mut streams = StreamsUnordered::new();

let dropped = Arc::new(atomic::AtomicBool::new(false));
streams.push(Testee {
waker: None,
dropped: dropped.clone(),
});

{
let fut = streams.next();
let res = futures::future::poll_immediate(fut).await;
assert!(res.is_none());
}

drop(streams);
assert!(dropped.load(atomic::Ordering::SeqCst));
}