Skip to content

Commit

Permalink
sync: expose strong and weak counts of mpsc sender handles (#6405)
Browse files Browse the repository at this point in the history
  • Loading branch information
maminrayej authored Mar 22, 2024
1 parent baad270 commit 1846483
Show file tree
Hide file tree
Showing 4 changed files with 228 additions and 0 deletions.
28 changes: 28 additions & 0 deletions tokio/src/sync/mpsc/bounded.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1409,6 +1409,16 @@ impl<T> Sender<T> {
pub fn max_capacity(&self) -> usize {
self.chan.semaphore().bound
}

/// Returns the number of [`Sender`] handles.
pub fn strong_count(&self) -> usize {
self.chan.strong_count()
}

/// Returns the number of [`WeakSender`] handles.
pub fn weak_count(&self) -> usize {
self.chan.weak_count()
}
}

impl<T> Clone for Sender<T> {
Expand All @@ -1429,19 +1439,37 @@ impl<T> fmt::Debug for Sender<T> {

impl<T> Clone for WeakSender<T> {
fn clone(&self) -> Self {
self.chan.increment_weak_count();

WeakSender {
chan: self.chan.clone(),
}
}
}

impl<T> Drop for WeakSender<T> {
fn drop(&mut self) {
self.chan.decrement_weak_count();
}
}

impl<T> WeakSender<T> {
/// Tries to convert a `WeakSender` into a [`Sender`]. This will return `Some`
/// if there are other `Sender` instances alive and the channel wasn't
/// previously dropped, otherwise `None` is returned.
pub fn upgrade(&self) -> Option<Sender<T>> {
chan::Tx::upgrade(self.chan.clone()).map(Sender::new)
}

/// Returns the number of [`Sender`] handles.
pub fn strong_count(&self) -> usize {
self.chan.strong_count()
}

/// Returns the number of [`WeakSender`] handles.
pub fn weak_count(&self) -> usize {
self.chan.weak_count()
}
}

impl<T> fmt::Debug for WeakSender<T> {
Expand Down
30 changes: 30 additions & 0 deletions tokio/src/sync/mpsc/chan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ pub(super) struct Chan<T, S> {
/// When this drops to zero, the send half of the channel is closed.
tx_count: AtomicUsize,

/// Tracks the number of outstanding weak sender handles.
tx_weak_count: AtomicUsize,

/// Only accessed by `Rx` handle.
rx_fields: UnsafeCell<RxFields<T>>,
}
Expand Down Expand Up @@ -115,6 +118,7 @@ pub(crate) fn channel<T, S: Semaphore>(semaphore: S) -> (Tx<T, S>, Rx<T, S>) {
semaphore,
rx_waker: CachePadded::new(AtomicWaker::new()),
tx_count: AtomicUsize::new(1),
tx_weak_count: AtomicUsize::new(0),
rx_fields: UnsafeCell::new(RxFields {
list: rx,
rx_closed: false,
Expand All @@ -131,7 +135,17 @@ impl<T, S> Tx<T, S> {
Tx { inner: chan }
}

pub(super) fn strong_count(&self) -> usize {
self.inner.tx_count.load(Acquire)
}

pub(super) fn weak_count(&self) -> usize {
self.inner.tx_weak_count.load(Relaxed)
}

pub(super) fn downgrade(&self) -> Arc<Chan<T, S>> {
self.inner.increment_weak_count();

self.inner.clone()
}

Expand Down Expand Up @@ -452,6 +466,22 @@ impl<T, S> Chan<T, S> {
// Notify the rx task
self.rx_waker.wake();
}

pub(super) fn decrement_weak_count(&self) {
self.tx_weak_count.fetch_sub(1, Relaxed);
}

pub(super) fn increment_weak_count(&self) {
self.tx_weak_count.fetch_add(1, Relaxed);
}

pub(super) fn strong_count(&self) -> usize {
self.tx_count.load(Acquire)
}

pub(super) fn weak_count(&self) -> usize {
self.tx_weak_count.load(Relaxed)
}
}

impl<T, S> Drop for Chan<T, S> {
Expand Down
28 changes: 28 additions & 0 deletions tokio/src/sync/mpsc/unbounded.rs
Original file line number Diff line number Diff line change
Expand Up @@ -578,23 +578,51 @@ impl<T> UnboundedSender<T> {
chan: self.chan.downgrade(),
}
}

/// Returns the number of [`UnboundedSender`] handles.
pub fn strong_count(&self) -> usize {
self.chan.strong_count()
}

/// Returns the number of [`WeakUnboundedSender`] handles.
pub fn weak_count(&self) -> usize {
self.chan.weak_count()
}
}

impl<T> Clone for WeakUnboundedSender<T> {
fn clone(&self) -> Self {
self.chan.increment_weak_count();

WeakUnboundedSender {
chan: self.chan.clone(),
}
}
}

impl<T> Drop for WeakUnboundedSender<T> {
fn drop(&mut self) {
self.chan.decrement_weak_count();
}
}

impl<T> WeakUnboundedSender<T> {
/// Tries to convert a `WeakUnboundedSender` into an [`UnboundedSender`].
/// This will return `Some` if there are other `Sender` instances alive and
/// the channel wasn't previously dropped, otherwise `None` is returned.
pub fn upgrade(&self) -> Option<UnboundedSender<T>> {
chan::Tx::upgrade(self.chan.clone()).map(UnboundedSender::new)
}

/// Returns the number of [`UnboundedSender`] handles.
pub fn strong_count(&self) -> usize {
self.chan.strong_count()
}

/// Returns the number of [`WeakUnboundedSender`] handles.
pub fn weak_count(&self) -> usize {
self.chan.weak_count()
}
}

impl<T> fmt::Debug for WeakUnboundedSender<T> {
Expand Down
142 changes: 142 additions & 0 deletions tokio/tests/sync_mpsc_weak.rs
Original file line number Diff line number Diff line change
Expand Up @@ -511,3 +511,145 @@ fn test_tx_count_weak_unbounded_sender() {

assert!(tx_weak.upgrade().is_none() && tx_weak2.upgrade().is_none());
}

#[tokio::test]
async fn sender_strong_count_when_cloned() {
let (tx, _rx) = mpsc::channel::<()>(1);

let tx2 = tx.clone();

assert_eq!(tx.strong_count(), 2);
assert_eq!(tx2.strong_count(), 2);
}

#[tokio::test]
async fn sender_weak_count_when_downgraded() {
let (tx, _rx) = mpsc::channel::<()>(1);

let weak = tx.downgrade();

assert_eq!(tx.weak_count(), 1);
assert_eq!(weak.weak_count(), 1);
}

#[tokio::test]
async fn sender_strong_count_when_dropped() {
let (tx, _rx) = mpsc::channel::<()>(1);

let tx2 = tx.clone();

drop(tx2);

assert_eq!(tx.strong_count(), 1);
}

#[tokio::test]
async fn sender_weak_count_when_dropped() {
let (tx, _rx) = mpsc::channel::<()>(1);

let weak = tx.downgrade();

drop(weak);

assert_eq!(tx.weak_count(), 0);
}

#[tokio::test]
async fn sender_strong_and_weak_conut() {
let (tx, _rx) = mpsc::channel::<()>(1);

let tx2 = tx.clone();

let weak = tx.downgrade();
let weak2 = tx2.downgrade();

assert_eq!(tx.strong_count(), 2);
assert_eq!(tx2.strong_count(), 2);
assert_eq!(weak.strong_count(), 2);
assert_eq!(weak2.strong_count(), 2);

assert_eq!(tx.weak_count(), 2);
assert_eq!(tx2.weak_count(), 2);
assert_eq!(weak.weak_count(), 2);
assert_eq!(weak2.weak_count(), 2);

drop(tx2);
drop(weak2);

assert_eq!(tx.strong_count(), 1);
assert_eq!(weak.strong_count(), 1);

assert_eq!(tx.weak_count(), 1);
assert_eq!(weak.weak_count(), 1);
}

#[tokio::test]
async fn unbounded_sender_strong_count_when_cloned() {
let (tx, _rx) = mpsc::unbounded_channel::<()>();

let tx2 = tx.clone();

assert_eq!(tx.strong_count(), 2);
assert_eq!(tx2.strong_count(), 2);
}

#[tokio::test]
async fn unbounded_sender_weak_count_when_downgraded() {
let (tx, _rx) = mpsc::unbounded_channel::<()>();

let weak = tx.downgrade();

assert_eq!(tx.weak_count(), 1);
assert_eq!(weak.weak_count(), 1);
}

#[tokio::test]
async fn unbounded_sender_strong_count_when_dropped() {
let (tx, _rx) = mpsc::unbounded_channel::<()>();

let tx2 = tx.clone();

drop(tx2);

assert_eq!(tx.strong_count(), 1);
}

#[tokio::test]
async fn unbounded_sender_weak_count_when_dropped() {
let (tx, _rx) = mpsc::unbounded_channel::<()>();

let weak = tx.downgrade();

drop(weak);

assert_eq!(tx.weak_count(), 0);
}

#[tokio::test]
async fn unbounded_sender_strong_and_weak_conut() {
let (tx, _rx) = mpsc::unbounded_channel::<()>();

let tx2 = tx.clone();

let weak = tx.downgrade();
let weak2 = tx2.downgrade();

assert_eq!(tx.strong_count(), 2);
assert_eq!(tx2.strong_count(), 2);
assert_eq!(weak.strong_count(), 2);
assert_eq!(weak2.strong_count(), 2);

assert_eq!(tx.weak_count(), 2);
assert_eq!(tx2.weak_count(), 2);
assert_eq!(weak.weak_count(), 2);
assert_eq!(weak2.weak_count(), 2);

drop(tx2);
drop(weak2);

assert_eq!(tx.strong_count(), 1);
assert_eq!(weak.strong_count(), 1);

assert_eq!(tx.weak_count(), 1);
assert_eq!(weak.weak_count(), 1);
}

0 comments on commit 1846483

Please sign in to comment.