Skip to content

Commit

Permalink
sync: implement Clone for watch::Sender (#6388)
Browse files Browse the repository at this point in the history
  • Loading branch information
mox692 authored Mar 12, 2024
1 parent b4ab647 commit a3d2548
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 4 deletions.
20 changes: 20 additions & 0 deletions tokio/src/sync/tests/loom_watch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,23 @@ fn wait_for_returns_correct_value() {
jh.join().unwrap();
});
}

#[test]
fn multiple_sender_drop_concurrently() {
loom::model(move || {
let (tx1, rx) = watch::channel(0);
let tx2 = tx1.clone();

let jh = thread::spawn(move || {
drop(tx2);
});
assert!(rx.has_changed().is_ok());

drop(tx1);

jh.join().unwrap();

// Check if all sender are dropped and closed flag is set.
assert!(rx.has_changed().is_err());
});
}
22 changes: 19 additions & 3 deletions tokio/src/sync/watch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@
use crate::sync::notify::Notify;

use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::atomic::Ordering::Relaxed;
use crate::loom::sync::atomic::Ordering::{AcqRel, Relaxed};
use crate::loom::sync::{Arc, RwLock, RwLockReadGuard};
use std::fmt;
use std::mem;
Expand Down Expand Up @@ -146,6 +146,16 @@ pub struct Sender<T> {
shared: Arc<Shared<T>>,
}

impl<T> Clone for Sender<T> {
fn clone(&self) -> Self {
self.shared.ref_count_tx.fetch_add(1, Relaxed);

Self {
shared: self.shared.clone(),
}
}
}

/// Returns a reference to the inner value.
///
/// Outstanding borrows hold a read lock on the inner value. This means that
Expand Down Expand Up @@ -238,6 +248,9 @@ struct Shared<T> {
/// Tracks the number of `Receiver` instances.
ref_count_rx: AtomicUsize,

/// Tracks the number of `Sender` instances.
ref_count_tx: AtomicUsize,

/// Notifies waiting receivers that the value changed.
notify_rx: big_notify::BigNotify,

Expand Down Expand Up @@ -485,6 +498,7 @@ pub fn channel<T>(init: T) -> (Sender<T>, Receiver<T>) {
value: RwLock::new(init),
state: AtomicState::new(),
ref_count_rx: AtomicUsize::new(1),
ref_count_tx: AtomicUsize::new(1),
notify_rx: big_notify::BigNotify::new(),
notify_tx: Notify::new(),
});
Expand Down Expand Up @@ -1302,8 +1316,10 @@ impl<T> Sender<T> {

impl<T> Drop for Sender<T> {
fn drop(&mut self) {
self.shared.state.set_closed();
self.shared.notify_rx.notify_waiters();
if self.shared.ref_count_tx.fetch_sub(1, AcqRel) == 1 {
self.shared.state.set_closed();
self.shared.notify_rx.notify_waiters();
}
}
}

Expand Down
38 changes: 37 additions & 1 deletion tokio/tests/sync_watch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@ use wasm_bindgen_test::wasm_bindgen_test as test;

use tokio::sync::watch;
use tokio_test::task::spawn;
use tokio_test::{assert_pending, assert_ready, assert_ready_err, assert_ready_ok};
use tokio_test::{
assert_pending, assert_ready, assert_ready_eq, assert_ready_err, assert_ready_ok,
};

#[test]
fn single_rx_recv() {
Expand Down Expand Up @@ -332,3 +334,37 @@ fn send_modify_panic() {
assert_ready_ok!(task.poll());
assert_eq!(*rx.borrow_and_update(), "three");
}

#[tokio::test]
async fn multiple_sender() {
let (tx1, mut rx) = watch::channel(0);
let tx2 = tx1.clone();

let mut t = spawn(async {
rx.changed().await.unwrap();
let v1 = *rx.borrow_and_update();
rx.changed().await.unwrap();
let v2 = *rx.borrow_and_update();
(v1, v2)
});

tx1.send(1).unwrap();
assert_pending!(t.poll());
tx2.send(2).unwrap();
assert_ready_eq!(t.poll(), (1, 2));
}

#[tokio::test]
async fn reciever_is_notified_when_last_sender_is_dropped() {
let (tx1, mut rx) = watch::channel(0);
let tx2 = tx1.clone();

let mut t = spawn(rx.changed());
assert_pending!(t.poll());

drop(tx1);
assert!(!t.is_woken());
drop(tx2);

assert!(t.is_woken());
}

0 comments on commit a3d2548

Please sign in to comment.