diff --git a/tokio/src/io/async_fd.rs b/tokio/src/io/async_fd.rs index e5ad2ab4314..27e3c026196 100644 --- a/tokio/src/io/async_fd.rs +++ b/tokio/src/io/async_fd.rs @@ -335,3 +335,64 @@ impl AsyncFd { self.readiness(mio::Interest::WRITABLE).await } } + +#[derive(Clone, Copy, Debug)] +/// Builder for `AsyncFd` allowing to choose interest on creation. +/// +/// By default unless user chooses anything, all interest is assumed. +pub struct AsyncFdBuilder { + is_read: bool, + is_write: bool, +} + +impl AsyncFdBuilder { + /// Creates new instance. + pub const fn new() -> Self { + Self { + is_read: false, + is_write: false, + } + } + + /// Sets `read` interest + pub const fn read(mut self) -> Self { + self.is_read = true; + self + } + + /// Sets `write` interest + pub const fn write(mut self) -> Self { + self.is_write = true; + self + } + + #[inline] + /// Constructs new `AsyncFd` instance + pub fn build(self, fd: T) -> io::Result> { + Self::build_with_handle(self, fd, Handle::current()) + } + + fn build_with_handle(self, inner: T, handle: Handle) -> io::Result> { + let fd = inner.as_raw_fd(); + let interest = match (self.is_read, self.is_write) { + (true, true) | (false, false) => ALL_INTEREST, + (true, false) => mio::Interest::READABLE, + (false, true) => mio::Interest::WRITABLE, + }; + + let shared = if let Some(inner) = handle.inner() { + inner.add_source(&mut SourceFd(&fd), interest)? + } else { + return Err(io::Error::new( + io::ErrorKind::Other, + "failed to find event loop", + )); + }; + + Ok(AsyncFd { + handle, + shared, + inner: Some(inner), + }) + } +} diff --git a/tokio/src/io/mod.rs b/tokio/src/io/mod.rs index 20d92233c73..56176b1814d 100644 --- a/tokio/src/io/mod.rs +++ b/tokio/src/io/mod.rs @@ -220,7 +220,7 @@ cfg_net_unix! { pub mod unix { //! Asynchronous IO structures specific to Unix-like operating systems. - pub use super::async_fd::{AsyncFd, AsyncFdReadyGuard}; + pub use super::async_fd::{AsyncFd, AsyncFdBuilder, AsyncFdReadyGuard}; } } diff --git a/tokio/tests/io_async_fd.rs b/tokio/tests/io_async_fd.rs index 0303eff6612..7023949dc4a 100644 --- a/tokio/tests/io_async_fd.rs +++ b/tokio/tests/io_async_fd.rs @@ -18,7 +18,7 @@ use nix::unistd::{close, read, write}; use futures::{poll, FutureExt}; -use tokio::io::unix::{AsyncFd, AsyncFdReadyGuard}; +use tokio::io::unix::{AsyncFd, AsyncFdBuilder, AsyncFdReadyGuard}; use tokio_test::{assert_err, assert_pending}; struct TestWaker { @@ -183,6 +183,23 @@ async fn initially_writable() { } } +#[tokio::test] +async fn initially_not_writable() { + let (a, b) = socketpair(); + + let afd_a = AsyncFdBuilder::new().read().build(a).unwrap(); + let afd_b = AsyncFdBuilder::new().read().build(b).unwrap(); + + afd_a.writable().await.unwrap_err(); + afd_b.writable().await.unwrap_err(); + + futures::select_biased! { + _ = tokio::time::sleep(Duration::from_millis(10)).fuse() => {}, + _ = afd_a.readable().fuse() => panic!("Unexpected readable state"), + _ = afd_b.readable().fuse() => panic!("Unexpected readable state"), + } +} + #[tokio::test] async fn reset_readable() { let (a, mut b) = socketpair();