Commit f7654fe

Fix race with dropping mpsc::Receiver (#2304)
1 parent ddbf522 commit f7654fe

File tree: 2 files changed (+194 -17 lines)

futures-channel/src/mpsc/mod.rs (+59 -16)
@@ -86,6 +86,7 @@ use std::pin::Pin;
 use std::sync::{Arc, Mutex};
 use std::sync::atomic::AtomicUsize;
 use std::sync::atomic::Ordering::SeqCst;
+use std::thread;
 
 use crate::mpsc::queue::Queue;
 
@@ -1047,7 +1048,12 @@ impl<T> Receiver<T> {
             }
             None => {
                 let state = decode_state(inner.state.load(SeqCst));
-                if state.is_open || state.num_messages != 0 {
+                if state.is_closed() {
+                    // If closed flag is set AND there are no pending messages
+                    // it means end of stream
+                    self.inner = None;
+                    Poll::Ready(None)
+                } else {
                     // If queue is open, we need to return Pending
                     // to be woken up when new messages arrive.
                     // If queue is closed but num_messages is non-zero,
@@ -1056,11 +1062,6 @@ impl<T> Receiver<T> {
                     // so we need to park until sender unparks the task
                     // after queueing the message.
                     Poll::Pending
-                } else {
-                    // If closed flag is set AND there are no pending messages
-                    // it means end of stream
-                    self.inner = None;
-                    Poll::Ready(None)
                 }
             }
         }
@@ -1126,8 +1127,26 @@ impl<T> Drop for Receiver<T> {
         // Drain the channel of all pending messages
        self.close();
         if self.inner.is_some() {
-            while let Poll::Ready(Some(..)) = self.next_message() {
-                // ...
+            loop {
+                match self.next_message() {
+                    Poll::Ready(Some(_)) => {}
+                    Poll::Ready(None) => break,
+                    Poll::Pending => {
+                        let state = decode_state(self.inner.as_ref().unwrap().state.load(SeqCst));
+
+                        // If the channel is closed, then there is no need to park.
+                        if state.is_closed() {
+                            break;
+                        }
+
+                        // TODO: Spinning isn't ideal, it might be worth
+                        // investigating using a condvar or some other strategy
+                        // here. That said, if this case is hit, then another thread
+                        // is about to push the value into the queue and this isn't
+                        // the only spinlock in the impl right now.
+                        thread::yield_now();
+                    }
+                }
             }
         }
     }
@@ -1173,7 +1192,12 @@ impl<T> UnboundedReceiver<T> {
             }
             None => {
                 let state = decode_state(inner.state.load(SeqCst));
-                if state.is_open || state.num_messages != 0 {
+                if state.is_closed() {
+                    // If closed flag is set AND there are no pending messages
+                    // it means end of stream
+                    self.inner = None;
+                    Poll::Ready(None)
+                } else {
                     // If queue is open, we need to return Pending
                     // to be woken up when new messages arrive.
                     // If queue is closed but num_messages is non-zero,
@@ -1182,11 +1206,6 @@ impl<T> UnboundedReceiver<T> {
                     // so we need to park until sender unparks the task
                     // after queueing the message.
                     Poll::Pending
-                } else {
-                    // If closed flag is set AND there are no pending messages
-                    // it means end of stream
-                    self.inner = None;
-                    Poll::Ready(None)
                 }
             }
         }
@@ -1240,8 +1259,26 @@ impl<T> Drop for UnboundedReceiver<T> {
         // Drain the channel of all pending messages
         self.close();
         if self.inner.is_some() {
-            while let Poll::Ready(Some(..)) = self.next_message() {
-                // ...
+            loop {
+                match self.next_message() {
+                    Poll::Ready(Some(_)) => {}
+                    Poll::Ready(None) => break,
+                    Poll::Pending => {
+                        let state = decode_state(self.inner.as_ref().unwrap().state.load(SeqCst));
+
+                        // If the channel is closed, then there is no need to park.
+                        if state.is_closed() {
+                            break;
+                        }
+
+                        // TODO: Spinning isn't ideal, it might be worth
+                        // investigating using a condvar or some other strategy
+                        // here. That said, if this case is hit, then another thread
+                        // is about to push the value into the queue and this isn't
+                        // the only spinlock in the impl right now.
+                        thread::yield_now();
+                    }
+                }
             }
         }
     }
@@ -1289,6 +1326,12 @@ unsafe impl<T: Send> Sync for UnboundedInner<T> {}
 unsafe impl<T: Send> Send for BoundedInner<T> {}
 unsafe impl<T: Send> Sync for BoundedInner<T> {}
 
+impl State {
+    fn is_closed(&self) -> bool {
+        !self.is_open && self.num_messages == 0
+    }
+}
+
 /*
  *
  * ===== Helpers =====
futures-channel/tests/mpsc-close.rs (+135 -1)
@@ -1,9 +1,13 @@
 use futures::channel::mpsc;
 use futures::executor::block_on;
+use futures::future::Future;
 use futures::sink::SinkExt;
 use futures::stream::StreamExt;
-use std::sync::Arc;
+use futures::task::{Context, Poll};
+use std::pin::Pin;
+use std::sync::{Arc, Weak};
 use std::thread;
+use std::time::{Duration, Instant};
 
 #[test]
 fn smoke() {
@@ -142,3 +146,133 @@ fn single_receiver_drop_closes_channel_and_drains() {
         assert!(sender.is_closed());
     }
 }
+
+// Stress test that `try_send()`s occurring concurrently with receiver
+// close/drops don't appear as successful sends.
+#[test]
+fn stress_try_send_as_receiver_closes() {
+    const AMT: usize = 10000;
+    // To provide variable timing characteristics (in the hopes of
+    // reproducing the collision that leads to a race), we busy-re-poll
+    // the test MPSC receiver a variable number of times before actually
+    // stopping. We vary this countdown between 1 and the following
+    // value.
+    const MAX_COUNTDOWN: usize = 20;
+    // When we detect that a successfully sent item is still in the
+    // queue after a disconnect, we spin for up to 100ms to confirm that
+    // it is a persistent condition and not a concurrency illusion.
+    const SPIN_TIMEOUT_S: u64 = 10;
+    const SPIN_SLEEP_MS: u64 = 10;
+    struct TestRx {
+        rx: mpsc::Receiver<Arc<()>>,
+        // The number of times to query `rx` before dropping it.
+        poll_count: usize
+    }
+    struct TestTask {
+        command_rx: mpsc::Receiver<TestRx>,
+        test_rx: Option<mpsc::Receiver<Arc<()>>>,
+        countdown: usize,
+    }
+    impl TestTask {
+        /// Create a new TestTask
+        fn new() -> (TestTask, mpsc::Sender<TestRx>) {
+            let (command_tx, command_rx) = mpsc::channel::<TestRx>(0);
+            (
+                TestTask {
+                    command_rx,
+                    test_rx: None,
+                    countdown: 0, // 0 means no countdown is in progress.
+                },
+                command_tx,
+            )
+        }
+    }
+    impl Future for TestTask {
+        type Output = ();
+
+        fn poll(
+            mut self: Pin<&mut Self>,
+            cx: &mut Context<'_>,
+        ) -> Poll<Self::Output> {
+            // Poll the test channel, if one is present.
+            if let Some(rx) = &mut self.test_rx {
+                if let Poll::Ready(v) = rx.poll_next_unpin(cx) {
+                    let _ = v.expect("test finished unexpectedly!");
+                }
+                self.countdown -= 1;
+                // Busy-poll until the countdown is finished.
+                cx.waker().wake_by_ref();
+            }
+            // Accept any newly submitted MPSC channels for testing.
+            match self.command_rx.poll_next_unpin(cx) {
+                Poll::Ready(Some(TestRx { rx, poll_count })) => {
+                    self.test_rx = Some(rx);
+                    self.countdown = poll_count;
+                    cx.waker().wake_by_ref();
+                },
+                Poll::Ready(None) => return Poll::Ready(()),
+                Poll::Pending => {},
+            }
+            if self.countdown == 0 {
+                // Countdown complete -- drop the Receiver.
+                self.test_rx = None;
+            }
+            Poll::Pending
+        }
+    }
+    let (f, mut cmd_tx) = TestTask::new();
+    let bg = thread::spawn(move || block_on(f));
+    for i in 0..AMT {
+        let (mut test_tx, rx) = mpsc::channel(0);
+        let poll_count = i % MAX_COUNTDOWN;
+        cmd_tx.try_send(TestRx { rx, poll_count }).unwrap();
+        let mut prev_weak: Option<Weak<()>> = None;
+        let mut attempted_sends = 0;
+        let mut successful_sends = 0;
+        loop {
+            // Create a test item.
+            let item = Arc::new(());
+            let weak = Arc::downgrade(&item);
+            match test_tx.try_send(item) {
+                Ok(_) => {
+                    prev_weak = Some(weak);
+                    successful_sends += 1;
+                }
+                Err(ref e) if e.is_full() => {}
+                Err(ref e) if e.is_disconnected() => {
+                    // Test for evidence of the race condition.
+                    if let Some(prev_weak) = prev_weak {
+                        if prev_weak.upgrade().is_some() {
+                            // The previously sent item is still allocated.
+                            // However, there appears to be some aspect of the
+                            // concurrency that can legitimately cause the Arc
+                            // to be momentarily valid. Spin for up to 100ms
+                            // waiting for the previously sent item to be
+                            // dropped.
+                            let t0 = Instant::now();
+                            let mut spins = 0;
+                            loop {
+                                if prev_weak.upgrade().is_none() {
+                                    break;
+                                }
+                                assert!(t0.elapsed() < Duration::from_secs(SPIN_TIMEOUT_S),
+                                    "item not dropped on iteration {} after \
+                                     {} sends ({} successful). spin=({})",
+                                    i, attempted_sends, successful_sends, spins
+                                );
+                                spins += 1;
+                                thread::sleep(Duration::from_millis(SPIN_SLEEP_MS));
+                            }
+                        }
+                    }
+                    break;
+                }
+                Err(ref e) => panic!("unexpected error: {}", e),
+            }
+            attempted_sends += 1;
+        }
+    }
+    drop(cmd_tx);
+    bg.join()
+        .expect("background thread join");
+}
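
A side note on how TestTask drives itself: while a countdown is running it never parks, since it calls cx.waker().wake_by_ref() before returning Poll::Pending so the executor polls it again immediately. A minimal sketch of that busy-repoll pattern in isolation (Countdown is an illustrative name, not something from this commit):

    use futures::executor::block_on;
    use futures::task::{Context, Poll};
    use std::future::Future;
    use std::pin::Pin;

    // A future that asks to be re-polled `remaining` times before completing,
    // mirroring how TestTask busy-polls its countdown.
    struct Countdown {
        remaining: usize,
    }

    impl Future for Countdown {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            if self.remaining == 0 {
                Poll::Ready(())
            } else {
                self.remaining -= 1;
                // Request an immediate re-poll instead of waiting for an
                // external wakeup.
                cx.waker().wake_by_ref();
                Poll::Pending
            }
        }
    }

    fn main() {
        block_on(Countdown { remaining: 3 });
    }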
