diff --git a/tokio/src/net/tcp/stream.rs b/tokio/src/net/tcp/stream.rs index 2ac37a2bc85..8a157e1c29d 100644 --- a/tokio/src/net/tcp/stream.rs +++ b/tokio/src/net/tcp/stream.rs @@ -356,6 +356,17 @@ impl TcpStream { Ok(()) } + /// Polls for read readiness. + /// + /// This function is intended for cases where creating and pinning a future + /// via [`readable`] is not feasible. Where possible, using [`readable`] is + /// preferred, as this supports polling from multiple tasks at once. + /// + /// [`readable`]: method@Self::readable + pub fn poll_read_ready(&self, cx: &mut Context<'_>) -> Poll> { + self.io.registration().poll_read_ready(cx).map_ok(|_| ()) + } + /// Try to read data from the stream into the provided buffer, returning how /// many bytes were read. /// @@ -467,6 +478,17 @@ impl TcpStream { Ok(()) } + /// Polls for write readiness. + /// + /// This function is intended for cases where creating and pinning a future + /// via [`writable`] is not feasible. Where possible, using [`writable`] is + /// preferred, as this supports polling from multiple tasks at once. + /// + /// [`writable`]: method@Self::writable + pub fn poll_write_ready(&self, cx: &mut Context<'_>) -> Poll> { + self.io.registration().poll_write_ready(cx).map_ok(|_| ()) + } + /// Try to write a buffer to the stream, returning how many bytes were /// written. /// diff --git a/tokio/tests/tcp_stream.rs b/tokio/tests/tcp_stream.rs index 784ade8a411..84d58dc511b 100644 --- a/tokio/tests/tcp_stream.rs +++ b/tokio/tests/tcp_stream.rs @@ -1,12 +1,16 @@ #![warn(rust_2018_idioms)] #![cfg(feature = "full")] -use tokio::io::Interest; +use tokio::io::{AsyncReadExt, AsyncWriteExt, Interest}; use tokio::net::{TcpListener, TcpStream}; +use tokio::try_join; use tokio_test::task; -use tokio_test::{assert_pending, assert_ready_ok}; +use tokio_test::{assert_ok, assert_pending, assert_ready_ok}; use std::io; +use std::task::Poll; + +use futures::future::poll_fn; #[tokio::test] async fn try_read_write() { @@ -110,3 +114,107 @@ fn buffer_not_included_in_future() { let n = mem::size_of_val(&fut); assert!(n < 1000); } + +macro_rules! assert_readable_by_polling { + ($stream:expr) => { + assert_ok!(poll_fn(|cx| $stream.poll_read_ready(cx)).await); + }; +} + +macro_rules! assert_not_readable_by_polling { + ($stream:expr) => { + poll_fn(|cx| { + assert_pending!($stream.poll_read_ready(cx)); + Poll::Ready(()) + }) + .await; + }; +} + +macro_rules! assert_writable_by_polling { + ($stream:expr) => { + assert_ok!(poll_fn(|cx| $stream.poll_write_ready(cx)).await); + }; +} + +macro_rules! assert_not_writable_by_polling { + ($stream:expr) => { + poll_fn(|cx| { + assert_pending!($stream.poll_write_ready(cx)); + Poll::Ready(()) + }) + .await; + }; +} + +#[tokio::test] +async fn poll_read_ready() { + let (mut client, mut server) = create_pair().await; + + // Initial state - not readable. + assert_not_readable_by_polling!(server); + + // There is data in the buffer - readable. + assert_ok!(client.write_all(b"ping").await); + assert_readable_by_polling!(server); + + // Readable until calls to `poll_read` return `Poll::Pending`. + let mut buf = [0u8; 4]; + assert_ok!(server.read_exact(&mut buf).await); + assert_readable_by_polling!(server); + read_until_pending(&mut server); + assert_not_readable_by_polling!(server); + + // Detect the client disconnect. + drop(client); + assert_readable_by_polling!(server); +} + +#[tokio::test] +async fn poll_write_ready() { + let (mut client, server) = create_pair().await; + + // Initial state - writable. + assert_writable_by_polling!(client); + + // No space to write - not writable. + write_until_pending(&mut client); + assert_not_writable_by_polling!(client); + + // Detect the server disconnect. + drop(server); + assert_writable_by_polling!(client); +} + +async fn create_pair() -> (TcpStream, TcpStream) { + let listener = assert_ok!(TcpListener::bind("127.0.0.1:0").await); + let addr = assert_ok!(listener.local_addr()); + let (client, (server, _)) = assert_ok!(try_join!(TcpStream::connect(&addr), listener.accept())); + (client, server) +} + +fn read_until_pending(stream: &mut TcpStream) { + let mut buf = vec![0u8; 1024 * 1024]; + loop { + match stream.try_read(&mut buf) { + Ok(_) => (), + Err(err) => { + assert_eq!(err.kind(), io::ErrorKind::WouldBlock); + break; + } + } + } +} + +fn write_until_pending(stream: &mut TcpStream) { + let buf = vec![0u8; 1024 * 1024]; + loop { + match stream.try_write(&buf) { + Ok(_) => (), + Err(err) => { + assert_eq!(err.kind(), io::ErrorKind::WouldBlock); + break; + } + } + } +}