Skip to content

Drop all messages in bounded channel when destroying the last receiver #108164

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 83 additions & 24 deletions library/std/src/sync/mpmc/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ struct Slot<T> {
/// The current stamp.
stamp: AtomicUsize,

/// The message in this slot.
/// The message in this slot. Either read out in `read` or dropped through
/// `discard_all_messages`.
msg: UnsafeCell<MaybeUninit<T>>,
}

Expand Down Expand Up @@ -439,21 +440,99 @@ impl<T> Channel<T> {
Some(self.cap)
}

/// Disconnects the channel and wakes up all blocked senders and receivers.
/// Disconnects senders and wakes up all blocked receivers.
///
/// Returns `true` if this call disconnected the channel.
pub(crate) fn disconnect(&self) -> bool {
pub(crate) fn disconnect_senders(&self) -> bool {
let tail = self.tail.fetch_or(self.mark_bit, Ordering::SeqCst);

if tail & self.mark_bit == 0 {
self.senders.disconnect();
self.receivers.disconnect();
true
} else {
false
}
}

/// Disconnects receivers and wakes up all blocked senders.
///
/// Returns `true` if this call disconnected the channel.
///
/// # Safety
/// May only be called once upon dropping the last receiver. The
/// destruction of all other receivers must have been observed with acquire
/// ordering or stronger.
pub(crate) unsafe fn disconnect_receivers(&self) -> bool {
let tail = self.tail.fetch_or(self.mark_bit, Ordering::SeqCst);
let disconnected = if tail & self.mark_bit == 0 {
self.senders.disconnect();
true
} else {
false
};

self.discard_all_messages(tail);
disconnected
}

/// Discards all messages.
///
/// `tail` should be the current (and therefore last) value of `tail`.
///
/// # Panicking
/// If a destructor panics, the remaining messages are leaked, matching the
/// behaviour of the unbounded channel.
///
/// # Safety
/// This method must only be called when dropping the last receiver. The
/// destruction of all other receivers must have been observed with acquire
/// ordering or stronger.
unsafe fn discard_all_messages(&self, tail: usize) {
debug_assert!(self.is_disconnected());

// Only receivers modify `head`, so since we are the last one,
// this value will not change and will not be observed (since
// no new messages can be sent after disconnection).
let mut head = self.head.load(Ordering::Relaxed);
let tail = tail & !self.mark_bit;

let backoff = Backoff::new();
loop {
// Deconstruct the head.
let index = head & (self.mark_bit - 1);
let lap = head & !(self.one_lap - 1);

// Inspect the corresponding slot.
debug_assert!(index < self.buffer.len());
let slot = unsafe { self.buffer.get_unchecked(index) };
let stamp = slot.stamp.load(Ordering::Acquire);

// If the stamp is ahead of the head by 1, we may drop the message.
if head + 1 == stamp {
head = if index + 1 < self.cap {
// Same lap, incremented index.
// Set to `{ lap: lap, mark: 0, index: index + 1 }`.
head + 1
} else {
// One lap forward, index wraps around to zero.
// Set to `{ lap: lap.wrapping_add(1), mark: 0, index: 0 }`.
lap.wrapping_add(self.one_lap)
};

unsafe {
(*slot.msg.get()).assume_init_drop();
}
// If the tail equals the head, that means the channel is empty.
} else if tail == head {
return;
// Otherwise, a sender is about to write into the slot, so we need
// to wait for it to update the stamp.
} else {
backoff.spin_heavy();
}
}
}

/// Returns `true` if the channel is disconnected.
pub(crate) fn is_disconnected(&self) -> bool {
self.tail.load(Ordering::SeqCst) & self.mark_bit != 0
Expand Down Expand Up @@ -483,23 +562,3 @@ impl<T> Channel<T> {
head.wrapping_add(self.one_lap) == tail & !self.mark_bit
}
}

impl<T> Drop for Channel<T> {
fn drop(&mut self) {
// Get the index of the head.
let hix = self.head.load(Ordering::Relaxed) & (self.mark_bit - 1);

// Loop over all slots that hold a message and drop them.
for i in 0..self.len() {
// Compute the index of the next slot holding a message.
let index = if hix + i < self.cap { hix + i } else { hix + i - self.cap };

unsafe {
debug_assert!(index < self.buffer.len());
let slot = self.buffer.get_unchecked_mut(index);
let msg = &mut *slot.msg.get();
msg.as_mut_ptr().drop_in_place();
}
}
}
}
4 changes: 2 additions & 2 deletions library/std/src/sync/mpmc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ impl<T> Drop for Sender<T> {
fn drop(&mut self) {
unsafe {
match &self.flavor {
SenderFlavor::Array(chan) => chan.release(|c| c.disconnect()),
SenderFlavor::Array(chan) => chan.release(|c| c.disconnect_senders()),
SenderFlavor::List(chan) => chan.release(|c| c.disconnect_senders()),
SenderFlavor::Zero(chan) => chan.release(|c| c.disconnect()),
}
Expand Down Expand Up @@ -403,7 +403,7 @@ impl<T> Drop for Receiver<T> {
fn drop(&mut self) {
unsafe {
match &self.flavor {
ReceiverFlavor::Array(chan) => chan.release(|c| c.disconnect()),
ReceiverFlavor::Array(chan) => chan.release(|c| c.disconnect_receivers()),
ReceiverFlavor::List(chan) => chan.release(|c| c.disconnect_receivers()),
ReceiverFlavor::Zero(chan) => chan.release(|c| c.disconnect()),
}
Expand Down
13 changes: 13 additions & 0 deletions library/std/src/sync/mpsc/sync_tests.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use super::*;
use crate::env;
use crate::rc::Rc;
use crate::sync::mpmc::SendTimeoutError;
use crate::thread;
use crate::time::Duration;
Expand Down Expand Up @@ -656,3 +657,15 @@ fn issue_15761() {
repro()
}
}

#[test]
fn drop_unreceived() {
let (tx, rx) = sync_channel::<Rc<()>>(1);
let msg = Rc::new(());
let weak = Rc::downgrade(&msg);
assert!(tx.send(msg).is_ok());
drop(rx);
// Messages should be dropped immediately when the last receiver is destroyed.
assert!(weak.upgrade().is_none());
drop(tx);
}