Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit a43c385

Browse files
committed Jul 23, 2023
Add futures_util::io::{pipe,Pipe{Reader,Writer}}
1 parent 81b4a56 commit a43c385

File tree

3 files changed

+510
-0
lines changed

3 files changed

+510
-0
lines changed
 

‎futures-util/src/io/mod.rs

+3
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ pub use self::into_sink::IntoSink;
9191
mod lines;
9292
pub use self::lines::Lines;
9393

94+
mod pipe;
95+
pub use self::pipe::{pipe, PipeReader, PipeWriter};
96+
9497
mod read;
9598
pub use self::read::Read;
9699

‎futures-util/src/io/pipe.rs

+302
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,302 @@
1+
use core::pin::Pin;
2+
use core::ptr::copy_nonoverlapping;
3+
use core::slice;
4+
use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
5+
6+
use alloc::boxed::Box;
7+
use alloc::sync::Arc;
8+
use futures_core::ready;
9+
use futures_core::task::{Context, Poll, Waker};
10+
use futures_io::{AsyncBufRead, AsyncRead, AsyncWrite, Error, ErrorKind, Result};
11+
12+
use crate::task::AtomicWaker;
13+
14+
/// Create a unidirectional bounded pipe for data transfer between asynchronous tasks.
15+
///
16+
/// The internal buffer size is given by `buffer`, which must be non zero. The [`PipeWriter`]
17+
/// returned implements the [`AsyncWrite`] trait, while [`PipeReader`] implements [`AsyncRead`].
18+
///
19+
/// # Panics
20+
///
21+
/// Panics when `buffer` is zero.
22+
#[must_use]
23+
pub fn pipe(buffer: usize) -> (PipeWriter, PipeReader) {
24+
assert!(buffer != 0, "pipe buffer size must be non zero and not usize::MAX");
25+
// If it is `usize::MAX`, the allocation must fail anyway since Rust forbids allocations larger
26+
// than `isize::MAX as usize`. This counts as OOM thus no need to state explicitly.
27+
let len = buffer.saturating_add(1);
28+
let ptr = Box::into_raw(alloc::vec![0u8; len].into_boxed_slice());
29+
let inner = Arc::new(Shared {
30+
len,
31+
buffer: ptr.cast(),
32+
write_pos: AtomicUsize::new(0),
33+
read_pos: AtomicUsize::new(0),
34+
writer_waker: AtomicWaker::new(),
35+
reader_waker: AtomicWaker::new(),
36+
closed: AtomicBool::new(false),
37+
});
38+
(PipeWriter { inner: inner.clone() }, PipeReader { inner })
39+
}
40+
41+
// `read_pos..write_pos` (loop around, same below) contains the buffered content.
// `write_pos..(read_pos-1)` is the empty space for further data.
// Note that index `read_pos-1` is left vacant so that `read_pos == write_pos` if and only if
// the buffer is empty.
//
// Invariants, at any time:
// 1. `read_pos` and `buffer[read_pos..write_pos]` is owned by the read-end.
//    Read-end can increment `read_pos` in that range to transfer
//    a portion of buffer to the write-end.
// 2. `write_pos` and `buffer[write_pos..(read_pos-1)]` is owned by the write-end.
//    Write-end can increment `write_pos` in that range to transfer
//    a portion of buffer to the read-end.
// 3. Read-end can only park (returning Pending) when it observed `read_pos == write_pos` after
//    registered the waker.
// 4. Write-end can only park when it observed `write_pos == read_pos-1` after
//    registered the waker.
#[derive(Debug)]
struct Shared {
    // Allocated length of `buffer`: requested capacity plus the one vacant slot.
    len: usize,
    // Raw pointer to the leaked `Box<[u8]>` ring storage; freed in `Drop` below.
    buffer: *mut u8,
    // Start of buffered data; advanced only by the read-end.
    read_pos: AtomicUsize,
    // End of buffered data; advanced only by the write-end.
    write_pos: AtomicUsize,
    // Parked read-end waiting for data.
    reader_waker: AtomicWaker,
    // Parked write-end waiting for free space.
    writer_waker: AtomicWaker,
    // Set when either end is dropped; never cleared.
    closed: AtomicBool,
}

// SAFETY: the raw `buffer` pointer is the only non-auto-Send/Sync field. The ownership
// protocol above partitions the buffer between the two ends, and all cross-end access
// goes through the atomics, so sharing `Shared` across threads is sound.
unsafe impl Send for Shared {}
unsafe impl Sync for Shared {}

impl Drop for Shared {
    fn drop(&mut self) {
        // Reconstruct and drop the boxed slice leaked in `pipe()`.
        // SAFETY: `buffer`/`len` came from `Box::into_raw` of a `len`-byte boxed slice,
        // and `Shared` dropping means no reader/writer can still access it.
        unsafe {
            drop(Box::from_raw(slice::from_raw_parts_mut(self.buffer, self.len)));
        }
    }
}
78+
79+
impl Shared {
    /// Poll for readable data. Returns the readable range `(read_pos, write_pos)`
    /// (which may wrap around), `Ok((0, 0))` — an empty range — when the pipe is
    /// closed and drained, or `Pending` after registering `waker` for wakeup.
    fn poll_read_ready(&self, waker: &Waker) -> Poll<Result<(usize, usize)>> {
        // Only mutable by us reader. No synchronization for load.
        let data_start = self.read_pos.load(Ordering::Relaxed);
        // "Acquire" the bytes for read.
        let mut data_end = self.write_pos.load(Ordering::Acquire);
        // Fast path.
        if data_start == data_end {
            // Implicit "Acquire" `write_pos` below.
            self.reader_waker.register(waker);
            // Double check for readiness: a write may have landed between the first
            // load and the registration (invariant 3 requires this order).
            data_end = self.write_pos.load(Ordering::Acquire);
            if data_start == data_end {
                // Already "acquire"d by `reader_waker`.
                if self.closed.load(Ordering::Relaxed) {
                    // Closed and empty: report an empty range so the caller yields EOF.
                    return Poll::Ready(Ok((0, 0)));
                }
                return Poll::Pending;
            }
        }
        Poll::Ready(Ok((data_start, data_end)))
    }

    /// Publish a new `read_pos` after consuming bytes and wake a parked writer.
    /// Unsafe contract: caller (the read-end) must only advance within the range
    /// returned by `poll_read_ready`.
    unsafe fn commit_read(&self, new_read_pos: usize) {
        // "Release" the bytes just read.
        self.read_pos.store(new_read_pos, Ordering::Release);
        // Implicit "Release" the `read_pos` change.
        self.writer_waker.wake();
    }

    /// Poll for free space. Returns the writable range `(write_pos, read_pos - 1)`
    /// (which may wrap around), `BrokenPipe` when full and closed, or `Pending`
    /// after registering `waker` for wakeup.
    fn poll_write_ready(&self, waker: &Waker) -> Poll<Result<(usize, usize)>> {
        // Only mutable by us writer. No synchronization for load.
        let write_start = self.write_pos.load(Ordering::Relaxed);
        // "Acquire" the bytes for write. `read_pos - 1` modulo `len` is the last
        // writable index (the vacant slot keeps empty/full distinguishable).
        let mut write_end =
            self.read_pos.load(Ordering::Acquire).checked_sub(1).unwrap_or(self.len - 1);
        if write_start == write_end {
            // Implicit "Acquire" `read_pos` below.
            self.writer_waker.register(waker);
            // Double check for write readiness: a read may have freed space between
            // the first load and the registration (invariant 4 requires this order).
            write_end =
                self.read_pos.load(Ordering::Acquire).checked_sub(1).unwrap_or(self.len - 1);
            if write_start == write_end {
                // Already "acquire"d by `writer_waker`.
                if self.closed.load(Ordering::Relaxed) {
                    return Poll::Ready(Err(Error::new(ErrorKind::BrokenPipe, "pipe closed")));
                }
                return Poll::Pending;
            }
        }
        Poll::Ready(Ok((write_start, write_end)))
    }

    /// Publish a new `write_pos` after producing bytes and wake a parked reader.
    /// Unsafe contract: caller (the write-end) must only advance within the range
    /// returned by `poll_write_ready`.
    unsafe fn commit_write(&self, new_write_pos: usize) {
        // "Release" the bytes just written.
        self.write_pos.store(new_write_pos, Ordering::Release);
        // Implicit "Release" the `write_pos` change.
        self.reader_waker.wake();
    }
}
139+
140+
/// The write end of a bounded pipe.
///
/// This value is created by the [`pipe`] function.
#[derive(Debug)]
pub struct PipeWriter {
    inner: Arc<Shared>,
}

impl Drop for PipeWriter {
    fn drop(&mut self) {
        // Mark the pipe closed so the reader sees EOF once the buffer drains.
        self.inner.closed.store(true, Ordering::Relaxed);
        // "Release" `closed`.
        // Wake a reader parked on an empty buffer so it can observe the close.
        self.inner.reader_waker.wake();
    }
}
155+
156+
impl AsyncWrite for PipeWriter {
157+
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
158+
if buf.is_empty() {
159+
return Poll::Ready(Ok(0));
160+
}
161+
162+
let inner = &*self.inner;
163+
164+
let (write_start, write_end) = ready!(inner.poll_write_ready(cx.waker()))?;
165+
166+
let written = if write_start <= write_end {
167+
let written = buf.len().min(write_end - write_start);
168+
// SAFETY: `buffer[write_pos..read_pos-1]` is owned by us writer.
169+
unsafe {
170+
copy_nonoverlapping(buf.as_ptr(), inner.buffer.add(write_start), written);
171+
}
172+
written
173+
} else {
174+
let written1 = buf.len().min(inner.len - write_start);
175+
let written2 = (buf.len() - written1).min(write_end);
176+
// SAFETY: `buffer[write_pos..]` and `buffer[..read_pos-1]` are owned by us writer.
177+
unsafe {
178+
copy_nonoverlapping(buf.as_ptr(), inner.buffer.add(write_start), written1);
179+
copy_nonoverlapping(buf.as_ptr().add(written1), inner.buffer, written2);
180+
}
181+
written1 + written2
182+
};
183+
184+
let mut new_write_pos = write_start + written;
185+
if new_write_pos >= inner.len {
186+
new_write_pos -= inner.len;
187+
}
188+
189+
unsafe {
190+
inner.commit_write(new_write_pos);
191+
}
192+
193+
Poll::Ready(Ok(written))
194+
}
195+
196+
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
197+
Poll::Ready(Ok(()))
198+
}
199+
200+
fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
201+
Poll::Ready(Ok(()))
202+
}
203+
}
204+
205+
/// The read end of a bounded pipe.
///
/// This value is created by the [`pipe`] function.
#[derive(Debug)]
pub struct PipeReader {
    inner: Arc<Shared>,
}

impl Drop for PipeReader {
    fn drop(&mut self) {
        // Mark the pipe closed so a blocked writer fails with `BrokenPipe`.
        self.inner.closed.store(true, Ordering::Relaxed);
        // "Release" `closed`.
        // Wake a writer parked on a full buffer so it can observe the close.
        self.inner.writer_waker.wake();
    }
}
220+
221+
impl AsyncRead for PipeReader {
    /// Copies up to `buf.len()` buffered bytes out of the ring buffer, possibly
    /// wrapping around the end of the allocation. Returns `Ok(0)` at EOF (writer
    /// closed and buffer drained) and `Pending` when the buffer is empty.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<Result<usize>> {
        // Zero-length reads complete immediately without touching shared state.
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }

        let inner = &*self.inner;

        let (data_start, data_end) = ready!(inner.poll_read_ready(cx.waker()))?;

        let read = if data_start <= data_end {
            // Contiguous data region `data_start..data_end`. Note this also covers
            // the closed-and-drained `(0, 0)` case, yielding a 0-byte read (EOF).
            let read = buf.len().min(data_end - data_start);
            // SAFETY: `buffer[read_pos..write_pos]` are owned by us reader.
            unsafe {
                copy_nonoverlapping(inner.buffer.add(data_start), buf.as_mut_ptr(), read);
            }
            read
        } else {
            // Data wraps: copy `data_start..len`, then `0..data_end`.
            let read1 = buf.len().min(inner.len - data_start);
            let read2 = (buf.len() - read1).min(data_end);
            // SAFETY: `buffer[read_pos..]` and `buffer[..write_pos]` are owned by us reader.
            unsafe {
                copy_nonoverlapping(inner.buffer.add(data_start), buf.as_mut_ptr(), read1);
                copy_nonoverlapping(inner.buffer, buf.as_mut_ptr().add(read1), read2);
            }
            read1 + read2
        };

        // Advance `read_pos` modulo the ring length.
        let mut new_read_pos = data_start + read;
        if new_read_pos >= inner.len {
            new_read_pos -= inner.len;
        }

        // SAFETY: `new_read_pos` stays within the readable range returned above.
        unsafe {
            self.inner.commit_read(new_read_pos);
        }

        Poll::Ready(Ok(read))
    }
}
265+
266+
impl AsyncBufRead for PipeReader {
    /// Exposes the buffered bytes in place. When the data wraps around the ring,
    /// only the contiguous prefix up to the end of the allocation is returned;
    /// the remainder becomes visible after `consume`.
    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<&[u8]>> {
        let inner = &*self.inner;
        let (data_start, mut data_end) = ready!(inner.poll_read_ready(cx.waker()))?;
        if data_end < data_start {
            // Wrapped: clamp to the end of the allocation to keep the slice contiguous.
            data_end = inner.len;
        }
        // SAFETY: `buffer[read_pos..]` is owned by us reader.
        let data =
            unsafe { slice::from_raw_parts(inner.buffer.add(data_start), data_end - data_start) };
        Poll::Ready(Ok(data))
    }

    /// Advances the read position by `amt` bytes, waking a parked writer.
    ///
    /// Panics when `amt` exceeds the currently buffered length.
    fn consume(self: Pin<&mut Self>, amt: usize) {
        let inner = &*self.inner;
        // Only mutable by us reader. No synchronization for load.
        let data_start = inner.read_pos.load(Ordering::Relaxed);
        // Can only go forward since the last `poll_fill_buf` in the same thread.
        // Does not need to be up-to-date.
        let data_end = inner.write_pos.load(Ordering::Relaxed);

        // Total buffered length, accounting for wrap-around.
        let len = if data_start <= data_end {
            data_end - data_start
        } else {
            data_end + inner.len - data_start
        };
        assert!(amt <= len, "invalid advance");

        // Advance `read_pos` modulo the ring length.
        let mut new_read_pos = data_start + amt;
        if new_read_pos >= inner.len {
            new_read_pos -= inner.len;
        }
        // SAFETY: the assert above keeps `new_read_pos` within the readable range.
        unsafe {
            inner.commit_read(new_read_pos);
        }
    }
}

‎futures/tests/io_pipe.rs

+205
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
use futures::future::FutureExt;
2+
use futures::task::Poll;
3+
use futures_core::task::Context;
4+
use futures_executor::block_on;
5+
use futures_io::ErrorKind;
6+
use futures_test::future::FutureTestExt;
7+
use futures_test::task::{new_count_waker, panic_context};
8+
use futures_util::io::{pipe, PipeReader, PipeWriter};
9+
use futures_util::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt};
10+
use static_assertions::assert_impl_all;
11+
12+
/// Test-only convenience for unwrapping the readiness state of a `Poll`.
trait PollExt<T> {
    fn expect_pending(self);
    fn expect_ready(self) -> T;
}

impl<T> PollExt<T> for Poll<T> {
    /// Panics (reporting the caller's location) unless the poll is `Pending`.
    #[track_caller]
    fn expect_pending(self) {
        assert!(self.is_pending());
    }

    /// Panics (reporting the caller's location) unless the poll is `Ready`,
    /// yielding the contained value.
    #[track_caller]
    fn expect_ready(self) -> T {
        if let Poll::Ready(value) = self {
            value
        } else {
            panic!("should be ready")
        }
    }
}
31+
32+
// They have only `Pin<&mut Self>` methods. `&Self` can do nothing. Thus Sync.
// Compile-time checks (via the `static_assertions` crate) that both pipe ends
// can be sent and shared across threads and moved freely.
assert_impl_all!(PipeReader: Send, Sync, Unpin);
assert_impl_all!(PipeWriter: Send, Sync, Unpin);
35+
36+
// A write smaller than the pipe capacity completes immediately and the reader
// drains it right away; the round trip repeats without ever returning Pending.
#[test]
fn small_write_nonblocking() {
    let (mut writer, mut reader) = pipe(8);
    let mut cx = panic_context();
    for _round in 0..10 {
        let mut out = [0u8; 10];
        let wrote = writer.write(b"12345").poll_unpin(&mut cx).expect_ready().unwrap();
        assert_eq!(wrote, 5);
        let got = reader.read(&mut out).poll_unpin(&mut cx).expect_ready().unwrap();
        assert_eq!(got, 5);
        assert_eq!(&out[..5], b"12345");
    }
}
47+
48+
// A write larger than the pipe capacity is truncated to the free space
// (7 of the 10 bytes here) instead of blocking; the reader then drains
// exactly that prefix.
#[test]
fn big_write_nonblocking() {
    let (mut writer, mut reader) = pipe(7);
    let mut cx = panic_context();
    for _round in 0..10 {
        let mut out = [0u8; 10];
        let wrote = writer.write(b"1234567890").poll_unpin(&mut cx).expect_ready().unwrap();
        assert_eq!(wrote, 7);
        let got = reader.read(&mut out).poll_unpin(&mut cx).expect_ready().unwrap();
        assert_eq!(got, 7);
        assert_eq!(&out[..7], b"1234567");
    }
}
59+
60+
// A read on an empty pipe parks the reader; the next write wakes it exactly
// once, and subsequent reads that find data do not trigger further wakeups.
// The count assertions depend on this exact operation order.
#[test]
fn reader_blocked() {
    let (mut w, mut r) = pipe(8);
    let (waker, cnt) = new_count_waker();
    let mut cx = Context::from_waker(&waker);

    let mut buf = [0u8; 5];
    // Empty pipe: the reader parks and registers its waker.
    r.read(&mut buf).poll_unpin(&mut cx).expect_pending();
    assert_eq!(cnt.get(), 0);
    // The write wakes the parked reader exactly once.
    assert_eq!(w.write(b"12345").poll_unpin(&mut cx).expect_ready().unwrap(), 5);
    assert_eq!(cnt.get(), 1);
    // Partial read of the 5 buffered bytes; no additional wakeup.
    assert_eq!(r.read(&mut buf[..3]).poll_unpin(&mut cx).expect_ready().unwrap(), 3);
    assert_eq!(&buf[..3], b"123");
    assert_eq!(cnt.get(), 1);
    // Drain the remaining 2 bytes.
    assert_eq!(r.read(&mut buf[..3]).poll_unpin(&mut cx).expect_ready().unwrap(), 2);
    assert_eq!(&buf[..2], b"45");
    assert_eq!(cnt.get(), 1);
    // Empty again: the reader parks once more; still only the single wakeup.
    r.read(&mut buf[..3]).poll_unpin(&mut cx).expect_pending();
    assert_eq!(cnt.get(), 1);
}
80+
81+
// A write on a full pipe parks the writer; each read that frees space wakes it
// exactly once. The count assertions depend on this exact operation order.
#[test]
fn writer_blocked() {
    let (mut w, mut r) = pipe(7);
    let (waker, cnt) = new_count_waker();
    let mut cx = Context::from_waker(&waker);
    let mut buf = [0u8; 10];

    // Fill the 7-byte pipe: 5 bytes, then only 2 of the next 5 fit.
    assert_eq!(w.write(b"12345").poll_unpin(&mut cx).expect_ready().unwrap(), 5);
    assert_eq!(w.write(b"67890").poll_unpin(&mut cx).expect_ready().unwrap(), 2);
    assert_eq!(cnt.get(), 0);
    // Full pipe: the writer parks and registers its waker.
    w.write(b"xxx").poll_unpin(&mut cx).expect_pending();
    assert_eq!(cnt.get(), 0);
    // Reading 3 bytes frees space and wakes the parked writer exactly once.
    assert_eq!(r.read(&mut buf[..3]).poll_unpin(&mut cx).expect_ready().unwrap(), 3);
    assert_eq!(&buf[..3], b"123");
    assert_eq!(cnt.get(), 1);
    // Only the 3 freed bytes of "abcde" fit; no additional wakeup.
    assert_eq!(w.write(b"abcde").poll_unpin(&mut cx).expect_ready().unwrap(), 3);
    assert_eq!(cnt.get(), 1);
    // Full again: the writer parks once more.
    w.write(b"xxx").poll_unpin(&mut cx).expect_pending();
    assert_eq!(cnt.get(), 1);
    // Draining all 7 bytes wakes the writer a second time.
    assert_eq!(r.read(&mut buf).poll_unpin(&mut cx).expect_ready().unwrap(), 7);
    assert_eq!(&buf[..7], b"4567abc");
    assert_eq!(cnt.get(), 2);
}
104+
105+
// Dropping the read end wakes a writer parked on a full pipe, and further
// writes fail with `BrokenPipe`.
#[test]
fn reader_closed_notify_writer() {
    let (mut w, r) = pipe(4);
    let (waker, cnt) = new_count_waker();
    let mut cx = Context::from_waker(&waker);

    assert_eq!(cnt.get(), 0);
    // Fill the pipe, then park the writer on the full buffer.
    assert_eq!(w.write(b"1234").poll_unpin(&mut cx).expect_ready().unwrap(), 4);
    w.write(b"xxx").poll_unpin(&mut cx).expect_pending();
    assert_eq!(cnt.get(), 0);
    // Dropping the reader closes the pipe and wakes the parked writer.
    drop(r);
    assert_eq!(cnt.get(), 1);

    // The re-polled write now observes the close and errors.
    assert_eq!(
        w.write(b"xxx").poll_unpin(&mut cx).expect_ready().unwrap_err().kind(),
        ErrorKind::BrokenPipe
    );
}
123+
124+
// Dropping the write end wakes a reader parked on an empty pipe, and the
// re-polled read yields EOF (0 bytes).
#[test]
fn writer_closed_notify_reader() {
    let (w, mut r) = pipe(4);
    let (waker, cnt) = new_count_waker();
    let mut cx = Context::from_waker(&waker);
    let mut buf = [0u8; 10];

    assert_eq!(cnt.get(), 0);
    // Empty pipe: the reader parks and registers its waker.
    r.read(&mut buf).poll_unpin(&mut cx).expect_pending();
    assert_eq!(cnt.get(), 0);
    // Dropping the writer closes the pipe and wakes the parked reader.
    drop(w);
    assert_eq!(cnt.get(), 1);

    // Closed and empty: the read resolves to EOF.
    assert_eq!(r.read(&mut [0u8; 10]).poll_unpin(&mut cx).expect_ready().unwrap(), 0);
}
139+
140+
// Bytes buffered before the writer is dropped are still delivered; EOF only
// follows once the buffer is drained.
#[test]
fn writer_closed_with_data() {
    let (mut w, mut r) = pipe(4);
    let mut cx = panic_context();
    let mut buf = [0u8; 10];

    assert_eq!(w.write(b"1234").poll_unpin(&mut cx).expect_ready().unwrap(), 4);
    drop(w);
    // The buffered bytes survive the close...
    assert_eq!(r.read(&mut buf).poll_unpin(&mut cx).expect_ready().unwrap(), 4);
    assert_eq!(&buf[..4], b"1234");
    // ...and only then does the reader see EOF.
    assert_eq!(r.read(&mut [0u8; 10]).poll_unpin(&mut cx).expect_ready().unwrap(), 0);
}
152+
153+
// End-to-end: stream ~11 KiB through a 128-byte pipe between two real threads,
// with `interleave_pending` injecting extra Pending/wakeup cycles on both sides.
#[test]
fn smoke() {
    let (mut tx, mut rx) = pipe(128);
    let payload = "hello world".repeat(1024);
    let expected = payload.clone();

    // Reader thread: collect everything until EOF.
    let read_side = std::thread::spawn(move || {
        block_on(async move {
            let mut received = String::new();
            rx.read_to_string(&mut received).interleave_pending().await.unwrap();
            received
        })
    });

    // Writer thread: push the whole payload, then drop the writer (EOF).
    let write_side = std::thread::spawn(move || {
        block_on(async move {
            tx.write_all(payload.as_bytes()).interleave_pending().await.unwrap();
        });
    });

    write_side.join().unwrap();
    assert_eq!(read_side.join().unwrap(), expected);
}
179+
180+
// Same cross-thread stress as `smoke`, but the reader drains the pipe through
// the `AsyncBufRead` path (`read_line`) instead of plain `read`.
#[test]
fn smoke_bufread() {
    let (mut w, mut r) = pipe(128);
    let data = "hello world\n".repeat(1024);

    // Reader thread: accumulate lines until `read_line` reports EOF (0 bytes).
    let reader = std::thread::spawn(|| {
        block_on(async move {
            let mut buf = String::new();
            while r.read_line(&mut buf).await.unwrap() != 0 {}
            buf
        })
    });

    // Writer thread: push the whole payload with extra Pending cycles injected,
    // then drop the writer so the reader sees EOF.
    let writer = std::thread::spawn({
        let data = data.clone();
        || {
            block_on(async move {
                w.write_all(data.as_bytes()).interleave_pending().await.unwrap();
            });
        }
    });

    writer.join().unwrap();
    let ret = reader.join().unwrap();
    assert_eq!(ret, data);
}

0 commit comments

Comments
 (0)
Please sign in to comment.