diff --git a/tokio/Cargo.toml b/tokio/Cargo.toml index ae64fd5cff7..0b4e011eff0 100644 --- a/tokio/Cargo.toml +++ b/tokio/Cargo.toml @@ -109,7 +109,7 @@ num_cpus = { version = "1.8.0", optional = true } parking_lot = { version = "0.12.0", optional = true } [target.'cfg(not(any(target_arch = "wasm32", target_arch = "wasm64")))'.dependencies] -socket2 = { version = "0.4.4", optional = true, features = [ "all" ] } +socket2 = { version = "0.4.9", optional = true, features = [ "all" ] } # Currently unstable. The API exposed by these features may be broken at any time. # Requires `--cfg tokio_unstable` to enable. @@ -146,7 +146,7 @@ mockall = "0.11.1" async-stream = "0.3" [target.'cfg(not(any(target_arch = "wasm32", target_arch = "wasm64")))'.dev-dependencies] -socket2 = "0.4" +socket2 = "0.4.9" tempfile = "3.1.0" [target.'cfg(not(all(any(target_arch = "wasm32", target_arch = "wasm64"), target_os = "unknown")))'.dev-dependencies] diff --git a/tokio/src/net/udp.rs b/tokio/src/net/udp.rs index 213d9149dad..110406252fd 100644 --- a/tokio/src/net/udp.rs +++ b/tokio/src/net/udp.rs @@ -1331,6 +1331,11 @@ impl UdpSocket { /// Make sure to always use a sufficiently large buffer to hold the /// maximum UDP packet size, which can be up to 65536 bytes in size. /// + /// MacOS will return an error if you pass a zero-sized buffer. + /// + /// If you're merely interested in learning the sender of the data at the head of the queue, + /// try [`peek_sender`]. + /// /// # Examples /// /// ```no_run @@ -1349,6 +1354,8 @@ impl UdpSocket { /// Ok(()) /// } /// ``` + /// + /// [`peek_sender`]: method@Self::peek_sender pub async fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> { self.io .registration() @@ -1371,6 +1378,11 @@ impl UdpSocket { /// Make sure to always use a sufficiently large buffer to hold the /// maximum UDP packet size, which can be up to 65536 bytes in size. /// + /// MacOS will return an error if you pass a zero-sized buffer. + /// + /// If you're merely interested in learning the sender of the data at the head of the queue, + /// try [`poll_peek_sender`]. + /// /// # Return value /// /// The function returns: @@ -1382,6 +1394,8 @@ impl UdpSocket { /// # Errors /// /// This function may encounter any standard I/O error except `WouldBlock`. + /// + /// [`poll_peek_sender`]: method@Self::poll_peek_sender pub fn poll_peek_from( &self, cx: &mut Context<'_>, @@ -1404,6 +1418,61 @@ impl UdpSocket { Poll::Ready(Ok(addr)) } + /// Retrieve the sender of the data at the head of the input queue, waiting if empty. + /// + /// This is equivalent to calling [`peek_from`] with a zero-sized buffer, + /// but suppresses the `WSAEMSGSIZE` error on Windows and the "invalid argument" error on macOS. + /// + /// [`peek_from`]: method@Self::peek_from + pub async fn peek_sender(&self) -> io::Result { + self.io + .registration() + .async_io(Interest::READABLE, || self.peek_sender_inner()) + .await + } + + /// Retrieve the sender of the data at the head of the input queue, + /// scheduling a wakeup if empty. + /// + /// This is equivalent to calling [`poll_peek_from`] with a zero-sized buffer, + /// but suppresses the `WSAEMSGSIZE` error on Windows and the "invalid argument" error on macOS. + /// + /// # Notes + /// + /// Note that on multiple calls to a `poll_*` method in the recv direction, only the + /// `Waker` from the `Context` passed to the most recent call will be scheduled to + /// receive a wakeup. + /// + /// [`poll_peek_from`]: method@Self::poll_peek_from + pub fn poll_peek_sender(&self, cx: &mut Context<'_>) -> Poll> { + self.io + .registration() + .poll_read_io(cx, || self.peek_sender_inner()) + } + + /// Try to retrieve the sender of the data at the head of the input queue. + /// + /// When there is no pending data, `Err(io::ErrorKind::WouldBlock)` is + /// returned. This function is usually paired with `readable()`. + pub fn try_peek_sender(&self) -> io::Result { + self.io + .registration() + .try_io(Interest::READABLE, || self.peek_sender_inner()) + } + + #[inline] + fn peek_sender_inner(&self) -> io::Result { + self.io.try_io(|| { + self.as_socket() + .peek_sender()? + // May be `None` if the platform doesn't populate the sender for some reason. + // In testing, that only occurred on macOS if you pass a zero-sized buffer, + // but the implementation of `Socket::peek_sender()` covers that. + .as_socket() + .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "sender not available")) + }) + } + /// Gets the value of the `SO_BROADCAST` option for this socket. /// /// For more information about this option, see [`set_broadcast`]. diff --git a/tokio/tests/udp.rs b/tokio/tests/udp.rs index 2b6ab4d2ad2..bd98e4840ce 100644 --- a/tokio/tests/udp.rs +++ b/tokio/tests/udp.rs @@ -134,6 +134,92 @@ async fn send_to_peek_from_poll() -> std::io::Result<()> { Ok(()) } +#[tokio::test] +async fn peek_sender() -> std::io::Result<()> { + let sender = UdpSocket::bind("127.0.0.1:0").await?; + let receiver = UdpSocket::bind("127.0.0.1:0").await?; + + let sender_addr = sender.local_addr()?; + let receiver_addr = receiver.local_addr()?; + + let msg = b"Hello, world!"; + sender.send_to(msg, receiver_addr).await?; + + let peeked_sender = receiver.peek_sender().await?; + assert_eq!(peeked_sender, sender_addr); + + // Assert that `peek_sender()` returns the right sender but + // doesn't remove from the receive queue. + let mut recv_buf = [0u8; 32]; + let (read, received_sender) = receiver.recv_from(&mut recv_buf).await?; + + assert_eq!(&recv_buf[..read], msg); + assert_eq!(received_sender, peeked_sender); + + Ok(()) +} + +#[tokio::test] +async fn poll_peek_sender() -> std::io::Result<()> { + let sender = UdpSocket::bind("127.0.0.1:0").await?; + let receiver = UdpSocket::bind("127.0.0.1:0").await?; + + let sender_addr = sender.local_addr()?; + let receiver_addr = receiver.local_addr()?; + + let msg = b"Hello, world!"; + poll_fn(|cx| sender.poll_send_to(cx, msg, receiver_addr)).await?; + + let peeked_sender = poll_fn(|cx| receiver.poll_peek_sender(cx)).await?; + assert_eq!(peeked_sender, sender_addr); + + // Assert that `poll_peek_sender()` returns the right sender but + // doesn't remove from the receive queue. + let mut recv_buf = [0u8; 32]; + let mut read = ReadBuf::new(&mut recv_buf); + let received_sender = poll_fn(|cx| receiver.poll_recv_from(cx, &mut read)).await?; + + assert_eq!(read.filled(), msg); + assert_eq!(received_sender, peeked_sender); + + Ok(()) +} + +#[tokio::test] +async fn try_peek_sender() -> std::io::Result<()> { + let sender = UdpSocket::bind("127.0.0.1:0").await?; + let receiver = UdpSocket::bind("127.0.0.1:0").await?; + + let sender_addr = sender.local_addr()?; + let receiver_addr = receiver.local_addr()?; + + let msg = b"Hello, world!"; + sender.send_to(msg, receiver_addr).await?; + + let peeked_sender = loop { + match receiver.try_peek_sender() { + Ok(peeked_sender) => break peeked_sender, + Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { + receiver.readable().await?; + } + Err(e) => return Err(e), + } + }; + + assert_eq!(peeked_sender, sender_addr); + + // Assert that `try_peek_sender()` returns the right sender but + // didn't remove from the receive queue. + let mut recv_buf = [0u8; 32]; + // We already peeked the sender so there must be data in the receive queue. + let (read, received_sender) = receiver.try_recv_from(&mut recv_buf).unwrap(); + + assert_eq!(&recv_buf[..read], msg); + assert_eq!(received_sender, peeked_sender); + + Ok(()) +} + #[tokio::test] async fn split() -> std::io::Result<()> { let socket = UdpSocket::bind("127.0.0.1:0").await?;