Skip to content

Commit

Permalink
net: add UdpSocket::peek_sender()
Browse files Browse the repository at this point in the history
  • Loading branch information
abonander committed Mar 7, 2023
1 parent abd92fb commit 287b0ea
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tokio/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ num_cpus = { version = "1.8.0", optional = true }
parking_lot = { version = "0.12.0", optional = true }

[target.'cfg(not(any(target_arch = "wasm32", target_arch = "wasm64")))'.dependencies]
socket2 = { version = "0.4.4", optional = true, features = [ "all" ] }
socket2 = { version = "0.4.9", optional = true, features = [ "all" ] }

# Currently unstable. The API exposed by these features may be broken at any time.
# Requires `--cfg tokio_unstable` to enable.
Expand Down Expand Up @@ -146,7 +146,7 @@ mockall = "0.11.1"
async-stream = "0.3"

[target.'cfg(not(any(target_arch = "wasm32", target_arch = "wasm64")))'.dev-dependencies]
socket2 = "0.4"
socket2 = "0.4.9"
tempfile = "3.1.0"

[target.'cfg(not(all(any(target_arch = "wasm32", target_arch = "wasm64"), target_os = "unknown")))'.dev-dependencies]
Expand Down
69 changes: 69 additions & 0 deletions tokio/src/net/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1331,6 +1331,11 @@ impl UdpSocket {
/// Make sure to always use a sufficiently large buffer to hold the
/// maximum UDP packet size, which can be up to 65536 bytes in size.
///
/// MacOS will return an error if you pass a zero-sized buffer.
///
/// If you're merely interested in learning the sender of the data at the head of the queue,
/// try [`peek_sender`].
///
/// # Examples
///
/// ```no_run
Expand All @@ -1349,6 +1354,8 @@ impl UdpSocket {
/// Ok(())
/// }
/// ```
///
/// [`peek_sender`]: method@Self::peek_sender
pub async fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
self.io
.registration()
Expand All @@ -1371,6 +1378,11 @@ impl UdpSocket {
/// Make sure to always use a sufficiently large buffer to hold the
/// maximum UDP packet size, which can be up to 65536 bytes in size.
///
/// MacOS will return an error if you pass a zero-sized buffer.
///
/// If you're merely interested in learning the sender of the data at the head of the queue,
/// try [`poll_peek_sender`].
///
/// # Return value
///
/// The function returns:
Expand All @@ -1382,6 +1394,8 @@ impl UdpSocket {
/// # Errors
///
/// This function may encounter any standard I/O error except `WouldBlock`.
///
/// [`poll_peek_sender`]: method@Self::poll_peek_sender
pub fn poll_peek_from(
&self,
cx: &mut Context<'_>,
Expand All @@ -1404,6 +1418,61 @@ impl UdpSocket {
Poll::Ready(Ok(addr))
}

/// Retrieve the sender of the data at the head of the input queue, waiting if empty.
///
/// This is equivalent to calling [`peek_from`] with a zero-sized buffer,
/// but suppresses the `WSAEMSGSIZE` error on Windows and the "invalid argument" error on macOS.
///
/// [`peek_from`]: method@Self::peek_from
pub async fn peek_sender(&self) -> io::Result<SocketAddr> {
self.io
.registration()
.async_io(Interest::READABLE, || self.peek_sender_inner())
.await
}

/// Retrieve the sender of the data at the head of the input queue,
/// scheduling a wakeup if empty.
///
/// This is equivalent to calling [`poll_peek_from`] with a zero-sized buffer,
/// but suppresses the `WSAEMSGSIZE` error on Windows and the "invalid argument" error on macOS.
///
/// # Notes
///
/// Note that on multiple calls to a `poll_*` method in the recv direction, only the
/// `Waker` from the `Context` passed to the most recent call will be scheduled to
/// receive a wakeup.
///
/// [`poll_peek_from`]: method@Self::poll_peek_from
pub fn poll_peek_sender(&self, cx: &mut Context<'_>) -> Poll<io::Result<SocketAddr>> {
self.io
.registration()
.poll_read_io(cx, || self.peek_sender_inner())
}

/// Try to retrieve the sender of the data at the head of the input queue.
///
/// When there is no pending data, `Err(io::ErrorKind::WouldBlock)` is
/// returned. This function is usually paired with `readable()`.
pub fn try_peek_sender(&self) -> io::Result<SocketAddr> {
self.io
.registration()
.try_io(Interest::READABLE, || self.peek_sender_inner())
}

#[inline]
fn peek_sender_inner(&self) -> io::Result<SocketAddr> {
self.io.try_io(|| {
self.as_socket()
.peek_sender()?
// May be `None` if the platform doesn't populate the sender for some reason.
// In testing, that only occurred on macOS if you pass a zero-sized buffer,
// but the implementation of `Socket::peek_sender()` covers that.
.as_socket()
.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "sender not available"))
})
}

/// Gets the value of the `SO_BROADCAST` option for this socket.
///
/// For more information about this option, see [`set_broadcast`].
Expand Down
86 changes: 86 additions & 0 deletions tokio/tests/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,92 @@ async fn send_to_peek_from_poll() -> std::io::Result<()> {
Ok(())
}

#[tokio::test]
async fn peek_sender() -> std::io::Result<()> {
let sender = UdpSocket::bind("127.0.0.1:0").await?;
let receiver = UdpSocket::bind("127.0.0.1:0").await?;

let sender_addr = sender.local_addr()?;
let receiver_addr = receiver.local_addr()?;

let msg = b"Hello, world!";
sender.send_to(msg, receiver_addr).await?;

let peeked_sender = receiver.peek_sender().await?;
assert_eq!(peeked_sender, sender_addr);

// Assert that `peek_sender()` returns the right sender but
// doesn't remove from the receive queue.
let mut recv_buf = [0u8; 32];
let (read, received_sender) = receiver.recv_from(&mut recv_buf).await?;

assert_eq!(&recv_buf[..read], msg);
assert_eq!(received_sender, peeked_sender);

Ok(())
}

#[tokio::test]
async fn poll_peek_sender() -> std::io::Result<()> {
let sender = UdpSocket::bind("127.0.0.1:0").await?;
let receiver = UdpSocket::bind("127.0.0.1:0").await?;

let sender_addr = sender.local_addr()?;
let receiver_addr = receiver.local_addr()?;

let msg = b"Hello, world!";
poll_fn(|cx| sender.poll_send_to(cx, msg, receiver_addr)).await?;

let peeked_sender = poll_fn(|cx| receiver.poll_peek_sender(cx)).await?;
assert_eq!(peeked_sender, sender_addr);

// Assert that `poll_peek_sender()` returns the right sender but
// doesn't remove from the receive queue.
let mut recv_buf = [0u8; 32];
let mut read = ReadBuf::new(&mut recv_buf);
let received_sender = poll_fn(|cx| receiver.poll_recv_from(cx, &mut read)).await?;

assert_eq!(read.filled(), msg);
assert_eq!(received_sender, peeked_sender);

Ok(())
}

#[tokio::test]
async fn try_peek_sender() -> std::io::Result<()> {
let sender = UdpSocket::bind("127.0.0.1:0").await?;
let receiver = UdpSocket::bind("127.0.0.1:0").await?;

let sender_addr = sender.local_addr()?;
let receiver_addr = receiver.local_addr()?;

let msg = b"Hello, world!";
sender.send_to(msg, receiver_addr).await?;

let peeked_sender = loop {
match receiver.try_peek_sender() {
Ok(peeked_sender) => break peeked_sender,
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
receiver.readable().await?;
}
Err(e) => return Err(e),
}
};

assert_eq!(peeked_sender, sender_addr);

// Assert that `try_peek_sender()` returns the right sender but
// didn't remove from the receive queue.
let mut recv_buf = [0u8; 32];
// We already peeked the sender so there must be data in the receive queue.
let (read, received_sender) = receiver.try_recv_from(&mut recv_buf).unwrap();

assert_eq!(&recv_buf[..read], msg);
assert_eq!(received_sender, peeked_sender);

Ok(())
}

#[tokio::test]
async fn split() -> std::io::Result<()> {
let socket = UdpSocket::bind("127.0.0.1:0").await?;
Expand Down

0 comments on commit 287b0ea

Please sign in to comment.