From 0b9acc609445d69aeac638218acb045d15305b91 Mon Sep 17 00:00:00 2001 From: Xinyi Gong Date: Sun, 6 Feb 2022 17:39:28 -0800 Subject: [PATCH] io: make duplex stream cooperative (#4470) Add coop checks on pipe poll_read and poll_write. Fixes: #4470 Refs: #4291, #4300 --- tokio/src/io/util/mem.rs | 16 ++++++++++++++++ tokio/tests/io_mem_stream.rs | 19 +++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/tokio/src/io/util/mem.rs b/tokio/src/io/util/mem.rs index 4eefe7b26f5..f95c83838c8 100644 --- a/tokio/src/io/util/mem.rs +++ b/tokio/src/io/util/mem.rs @@ -185,6 +185,7 @@ impl AsyncRead for Pipe { cx: &mut task::Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { + ready!(poll_proceed_and_make_progress(cx)); if self.buffer.has_remaining() { let max = self.buffer.remaining().min(buf.remaining()); buf.put_slice(&self.buffer[..max]); @@ -212,6 +213,7 @@ impl AsyncWrite for Pipe { cx: &mut task::Context<'_>, buf: &[u8], ) -> Poll> { + ready!(poll_proceed_and_make_progress(cx)); if self.is_closed { return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into())); } @@ -241,3 +243,17 @@ impl AsyncWrite for Pipe { Poll::Ready(Ok(())) } } + +cfg_coop! { + fn poll_proceed_and_make_progress(cx: &mut task::Context<'_>) -> Poll<()> { + let coop = ready!(crate::coop::poll_proceed(cx)); + coop.made_progress(); + Poll::Ready(()) + } +} + +cfg_not_coop! { + fn poll_proceed_and_make_progress(_: &mut Context<'_>) -> Poll<()> { + Poll::Ready(()) + } +} diff --git a/tokio/tests/io_mem_stream.rs b/tokio/tests/io_mem_stream.rs index 01baa5369c7..4b5e7b7746f 100644 --- a/tokio/tests/io_mem_stream.rs +++ b/tokio/tests/io_mem_stream.rs @@ -100,3 +100,22 @@ async fn max_write_size() { // drop b only after task t1 finishes writing drop(b); } + +#[tokio::test] +async fn duplex_is_cooperative() { + let (mut tx, mut rx) = tokio::io::duplex(1024 * 8); + + tokio::select! { + biased; + + _ = async { + loop { + let buf = [3u8; 4096]; + let _ = tx.write_all(&buf).await; + let mut buf = [0u8; 4096]; + let _ = rx.read(&mut buf).await; + } + } => {}, + _ = tokio::task::yield_now() => {} + } +}