Skip to content

Commit 11d3655

Browse files
committed
net: add UdpSocket::peek_sender()
closes tokio-rs#5491
1 parent ff2f286 commit 11d3655

File tree

3 files changed

+157
-2
lines changed

3 files changed

+157
-2
lines changed

tokio/Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ num_cpus = { version = "1.8.0", optional = true }
109109
parking_lot = { version = "0.12.0", optional = true }
110110

111111
[target.'cfg(not(any(target_arch = "wasm32", target_arch = "wasm64")))'.dependencies]
112-
socket2 = { version = "0.4.4", optional = true, features = [ "all" ] }
112+
socket2 = { version = "0.4.9", optional = true, features = [ "all" ] }
113113

114114
# Currently unstable. The API exposed by these features may be broken at any time.
115115
# Requires `--cfg tokio_unstable` to enable.
@@ -146,7 +146,7 @@ mockall = "0.11.1"
146146
async-stream = "0.3"
147147

148148
[target.'cfg(not(any(target_arch = "wasm32", target_arch = "wasm64")))'.dev-dependencies]
149-
socket2 = "0.4"
149+
socket2 = "0.4.9"
150150
tempfile = "3.1.0"
151151

152152
[target.'cfg(not(all(any(target_arch = "wasm32", target_arch = "wasm64"), target_os = "unknown")))'.dev-dependencies]

tokio/src/net/udp.rs

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1331,6 +1331,11 @@ impl UdpSocket {
13311331
/// Make sure to always use a sufficiently large buffer to hold the
13321332
/// maximum UDP packet size, which can be up to 65536 bytes in size.
13331333
///
1334+
/// MacOS will return an error if you pass a zero-sized buffer.
1335+
///
1336+
/// If you're merely interested in learning the sender of the data at the head of the queue,
1337+
/// try [`peek_sender`].
1338+
///
13341339
/// # Examples
13351340
///
13361341
/// ```no_run
@@ -1349,6 +1354,8 @@ impl UdpSocket {
13491354
/// Ok(())
13501355
/// }
13511356
/// ```
1357+
///
1358+
/// [`peek_sender`]: method@Self::peek_sender
13521359
pub async fn peek_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
13531360
self.io
13541361
.registration()
@@ -1371,6 +1378,11 @@ impl UdpSocket {
13711378
/// Make sure to always use a sufficiently large buffer to hold the
13721379
/// maximum UDP packet size, which can be up to 65536 bytes in size.
13731380
///
1381+
/// MacOS will return an error if you pass a zero-sized buffer.
1382+
///
1383+
/// If you're merely interested in learning the sender of the data at the head of the queue,
1384+
/// try [`poll_peek_sender`].
1385+
///
13741386
/// # Return value
13751387
///
13761388
/// The function returns:
@@ -1382,6 +1394,8 @@ impl UdpSocket {
13821394
/// # Errors
13831395
///
13841396
/// This function may encounter any standard I/O error except `WouldBlock`.
1397+
///
1398+
/// [`poll_peek_sender`]: method@Self::poll_peek_sender
13851399
pub fn poll_peek_from(
13861400
&self,
13871401
cx: &mut Context<'_>,
@@ -1404,6 +1418,61 @@ impl UdpSocket {
14041418
Poll::Ready(Ok(addr))
14051419
}
14061420

1421+
/// Retrieve the sender of the data at the head of the input queue, waiting if empty.
1422+
///
1423+
/// This is equivalent to calling [`peek_from`] with a zero-sized buffer,
1424+
/// but suppresses the `WSAEMSGSIZE` error on Windows and the "invalid argument" error on macOS.
1425+
///
1426+
/// [`peek_from`]: method@Self::peek_from
1427+
pub async fn peek_sender(&self) -> io::Result<SocketAddr> {
1428+
self.io
1429+
.registration()
1430+
.async_io(Interest::READABLE, || self.peek_sender_inner())
1431+
.await
1432+
}
1433+
1434+
/// Retrieve the sender of the data at the head of the input queue,
1435+
/// scheduling a wakeup if empty.
1436+
///
1437+
/// This is equivalent to calling [`poll_peek_from`] with a zero-sized buffer,
1438+
/// but suppresses the `WSAEMSGSIZE` error on Windows and the "invalid argument" error on macOS.
1439+
///
1440+
/// # Notes
1441+
///
1442+
/// Note that on multiple calls to a `poll_*` method in the recv direction, only the
1443+
/// `Waker` from the `Context` passed to the most recent call will be scheduled to
1444+
/// receive a wakeup.
1445+
///
1446+
/// [`poll_peek_from`]: method@Self::poll_peek_from
1447+
pub fn poll_peek_sender(&self, cx: &mut Context<'_>) -> Poll<io::Result<SocketAddr>> {
1448+
self.io
1449+
.registration()
1450+
.poll_read_io(cx, || self.peek_sender_inner())
1451+
}
1452+
1453+
/// Try to retrieve the sender of the data at the head of the input queue.
1454+
///
1455+
/// When there is no pending data, `Err(io::ErrorKind::WouldBlock)` is
1456+
/// returned. This function is usually paired with `readable()`.
1457+
pub fn try_peek_sender(&self) -> io::Result<SocketAddr> {
1458+
self.io
1459+
.registration()
1460+
.try_io(Interest::READABLE, || self.peek_sender_inner())
1461+
}
1462+
1463+
#[inline]
1464+
fn peek_sender_inner(&self) -> io::Result<SocketAddr> {
1465+
self.io.try_io(|| {
1466+
self.as_socket()
1467+
.peek_sender()?
1468+
// May be `None` if the platform doesn't populate the sender for some reason.
1469+
// In testing, that only occurred on macOS if you pass a zero-sized buffer,
1470+
// but the implementation of `Socket::peek_sender()` covers that.
1471+
.as_socket()
1472+
.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "sender not available"))
1473+
})
1474+
}
1475+
14071476
/// Gets the value of the `SO_BROADCAST` option for this socket.
14081477
///
14091478
/// For more information about this option, see [`set_broadcast`].

tokio/tests/udp.rs

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,92 @@ async fn send_to_peek_from_poll() -> std::io::Result<()> {
134134
Ok(())
135135
}
136136

137+
#[tokio::test]
138+
async fn peek_sender() -> std::io::Result<()> {
139+
let sender = UdpSocket::bind("127.0.0.1:0").await?;
140+
let receiver = UdpSocket::bind("127.0.0.1:0").await?;
141+
142+
let sender_addr = sender.local_addr()?;
143+
let receiver_addr = receiver.local_addr()?;
144+
145+
let msg = b"Hello, world!";
146+
sender.send_to(msg, receiver_addr).await?;
147+
148+
let peeked_sender = receiver.peek_sender().await?;
149+
assert_eq!(peeked_sender, sender_addr);
150+
151+
// Assert that `peek_sender()` returns the right sender but
152+
// doesn't remove from the receive queue.
153+
let mut recv_buf = [0u8; 32];
154+
let (read, received_sender) = receiver.recv_from(&mut recv_buf).await?;
155+
156+
assert_eq!(&recv_buf[..read], msg);
157+
assert_eq!(received_sender, peeked_sender);
158+
159+
Ok(())
160+
}
161+
162+
#[tokio::test]
163+
async fn poll_peek_sender() -> std::io::Result<()> {
164+
let sender = UdpSocket::bind("127.0.0.1:0").await?;
165+
let receiver = UdpSocket::bind("127.0.0.1:0").await?;
166+
167+
let sender_addr = sender.local_addr()?;
168+
let receiver_addr = receiver.local_addr()?;
169+
170+
let msg = b"Hello, world!";
171+
poll_fn(|cx| sender.poll_send_to(cx, msg, receiver_addr)).await?;
172+
173+
let peeked_sender = poll_fn(|cx| receiver.poll_peek_sender(cx)).await?;
174+
assert_eq!(peeked_sender, sender_addr);
175+
176+
// Assert that `poll_peek_sender()` returns the right sender but
177+
// doesn't remove from the receive queue.
178+
let mut recv_buf = [0u8; 32];
179+
let mut read = ReadBuf::new(&mut recv_buf);
180+
let received_sender = poll_fn(|cx| receiver.poll_recv_from(cx, &mut read)).await?;
181+
182+
assert_eq!(read.filled(), msg);
183+
assert_eq!(received_sender, peeked_sender);
184+
185+
Ok(())
186+
}
187+
188+
#[tokio::test]
189+
async fn try_peek_sender() -> std::io::Result<()> {
190+
let sender = UdpSocket::bind("127.0.0.1:0").await?;
191+
let receiver = UdpSocket::bind("127.0.0.1:0").await?;
192+
193+
let sender_addr = sender.local_addr()?;
194+
let receiver_addr = receiver.local_addr()?;
195+
196+
let msg = b"Hello, world!";
197+
sender.send_to(msg, receiver_addr).await?;
198+
199+
let peeked_sender = loop {
200+
match receiver.try_peek_sender() {
201+
Ok(peeked_sender) => break peeked_sender,
202+
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
203+
receiver.readable().await?;
204+
}
205+
Err(e) => return Err(e),
206+
}
207+
};
208+
209+
assert_eq!(peeked_sender, sender_addr);
210+
211+
// Assert that `try_peek_sender()` returns the right sender but
212+
// didn't remove from the receive queue.
213+
let mut recv_buf = [0u8; 32];
214+
// We already peeked the sender so there must be data in the receive queue.
215+
let (read, received_sender) = receiver.try_recv_from(&mut recv_buf).unwrap();
216+
217+
assert_eq!(&recv_buf[..read], msg);
218+
assert_eq!(received_sender, peeked_sender);
219+
220+
Ok(())
221+
}
222+
137223
#[tokio::test]
138224
async fn split() -> std::io::Result<()> {
139225
let socket = UdpSocket::bind("127.0.0.1:0").await?;

0 commit comments

Comments
 (0)