Skip to content

Commit

Permalink
address review
Browse files Browse the repository at this point in the history
  • Loading branch information
b-naber committed Jul 18, 2022
1 parent 5006259 commit 4122c9a
Show file tree
Hide file tree
Showing 4 changed files with 228 additions and 228 deletions.
222 changes: 2 additions & 220 deletions tokio-util/tests/mpsc.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,6 @@
use futures::future::poll_fn;
use std::ops::Drop;
use std::sync::atomic::{
AtomicUsize,
Ordering::{Acquire, Release},
};
use std::time::Duration;
use tokio::join;
use tokio::sync::mpsc::{self, channel};
use tokio::sync::oneshot;
use tokio::time;

use tokio::sync::mpsc::channel;
use tokio_test::task::spawn;
use tokio_test::{assert_pending, assert_ready, assert_ready_err, assert_ready_ok};
use tokio_util::sync::PollSender;
Expand Down Expand Up @@ -246,213 +238,3 @@ fn start_send_panics_when_acquiring() {
assert_pending!(reserve.poll());
send.send_item(2).unwrap();
}

#[tokio::test]
async fn weak_sender() {
let (tx, mut rx) = channel(11);
let tx_weak = tx.clone().downgrade();

let tx_weak = tokio::spawn(async move {
for i in 0..10 {
if tx.send(i).await.is_err() {
return None;
}
}

let tx2 = tx_weak
.upgrade()
.expect("expected to be able to upgrade tx_weak");
let _ = tx2.send(20).await;
let tx_weak = tx2.downgrade();

Some(tx_weak)
})
.await
.unwrap();

for i in 0..12 {
let recvd = rx.recv().await;

match recvd {
Some(msg) => {
if i == 10 {
assert_eq!(msg, 20);
}
}
None => {
assert_eq!(i, 11);
break;
}
}
}

if let Some(tx_weak) = tx_weak {
let upgraded = tx_weak.upgrade();
assert!(upgraded.is_none());
}
}

#[tokio::test]
async fn actor_weak_sender() {
pub struct MyActor {
receiver: mpsc::Receiver<ActorMessage>,
sender: mpsc::WeakSender<ActorMessage>,
next_id: u32,
pub received_self_msg: bool,
}

enum ActorMessage {
GetUniqueId { respond_to: oneshot::Sender<u32> },
SelfMessage {},
}

impl MyActor {
fn new(
receiver: mpsc::Receiver<ActorMessage>,
sender: mpsc::WeakSender<ActorMessage>,
) -> Self {
MyActor {
receiver,
sender,
next_id: 0,
received_self_msg: false,
}
}

fn handle_message(&mut self, msg: ActorMessage) {
match msg {
ActorMessage::GetUniqueId { respond_to } => {
self.next_id += 1;

// The `let _ =` ignores any errors when sending.
//
// This can happen if the `select!` macro is used
// to cancel waiting for the response.
let _ = respond_to.send(self.next_id);
}
ActorMessage::SelfMessage { .. } => {
self.received_self_msg = true;
}
}
}

async fn send_message_to_self(&mut self) {
let msg = ActorMessage::SelfMessage {};

let sender = self.sender.clone();

// cannot move self.sender here
if let Some(sender) = sender.upgrade() {
let _ = sender.send(msg).await;
self.sender = sender.downgrade();
}
}

async fn run(&mut self) {
let mut i = 0;
while let Some(msg) = self.receiver.recv().await {
self.handle_message(msg);

if i == 0 {
self.send_message_to_self().await;
}

i += 1
}

assert!(self.received_self_msg);
}
}

#[derive(Clone)]
pub struct MyActorHandle {
sender: mpsc::Sender<ActorMessage>,
}

impl MyActorHandle {
pub fn new() -> (Self, MyActor) {
let (sender, receiver) = mpsc::channel(8);
let actor = MyActor::new(receiver, sender.clone().downgrade());

(Self { sender }, actor)
}

pub async fn get_unique_id(&self) -> u32 {
let (send, recv) = oneshot::channel();
let msg = ActorMessage::GetUniqueId { respond_to: send };

// Ignore send errors. If this send fails, so does the
// recv.await below. There's no reason to check the
// failure twice.
let _ = self.sender.send(msg).await;
recv.await.expect("Actor task has been killed")
}
}

let (handle, mut actor) = MyActorHandle::new();

let actor_handle = tokio::spawn(async move { actor.run().await });

let _ = tokio::spawn(async move {
let _ = handle.get_unique_id().await;
drop(handle);
})
.await;

let _ = actor_handle.await;
}

static NUM_DROPPED: AtomicUsize = AtomicUsize::new(0);

#[derive(Debug)]
struct Msg;

impl Drop for Msg {
fn drop(&mut self) {
NUM_DROPPED.fetch_add(1, Release);
}
}

// Tests that no pending messages are put onto the channel after `Rx` was
// dropped.
//
// Note: After the introduction of `WeakSender`, which internally
// used `Arc` and doesn't call a drop of the channel after the last strong
// `Sender` was dropped while more than one `WeakSender` remains, we want to
// ensure that no messages are kept in the channel, which were sent after
// the receiver was dropped.
#[tokio::test(start_paused = true)]
async fn test_msgs_dropped_on_rx_drop() {
fn ms(millis: u64) -> Duration {
Duration::from_millis(millis)
}

let (tx, mut rx) = mpsc::channel(3);

let rx_handle = tokio::spawn(async move {
let _ = rx.recv().await.unwrap();
let _ = rx.recv().await.unwrap();
time::sleep(ms(1)).await;
drop(rx);

time::advance(ms(1)).await;
});

let tx_handle = tokio::spawn(async move {
let _ = tx.send(Msg {}).await.unwrap();
let _ = tx.send(Msg {}).await.unwrap();

// This msg will be pending and should be dropped when `rx` is dropped
let _ = tx.send(Msg {}).await.unwrap();
time::advance(ms(1)).await;

// This msg will not be put onto `Tx` list anymore, since `Rx` is closed.
time::sleep(ms(1)).await;
let _ = tx.send(Msg {}).await.unwrap();

// Ensure that third message isn't put onto the channel anymore
assert_eq!(NUM_DROPPED.load(Acquire), 4);
});

let (_, _) = join!(rx_handle, tx_handle);
}
2 changes: 1 addition & 1 deletion tokio/src/sync/mpsc/bounded.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1079,7 +1079,7 @@ impl<T> WeakSender<T> {
/// if there are other `Sender` instances alive and the channel wasn't
/// previously dropped, otherwise `None` is returned.
pub fn upgrade(self) -> Option<Sender<T>> {
self.chan.upgrade().map(|tx| Sender::new(tx))
self.chan.upgrade().map(Sender::new)
}
}

Expand Down
12 changes: 6 additions & 6 deletions tokio/src/sync/mpsc/chan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,12 +146,12 @@ impl<T, S> Tx<T, S> {
pub(super) fn upgrade(self) -> Option<Self> {
let mut tx_count = self.inner.tx_count.load(Acquire);

if tx_count == 0 {
// channel is closed
return None;
}

loop {
if tx_count == 0 {
// channel is closed
return None;
}

match self
.inner
.tx_count
Expand All @@ -164,7 +164,7 @@ impl<T, S> Tx<T, S> {
}
Err(prev_count) => {
if prev_count == 0 {
return None;
unreachable!();
}

tx_count = prev_count;
Expand Down
Loading

0 comments on commit 4122c9a

Please sign in to comment.