Skip to content

Commit 9207785

Browse files
committed
Add futures_util::io::{pipe,Pipe{Reader,Writer}}
1 parent 81b4a56 commit 9207785

File tree

3 files changed

+510
-0
lines changed

3 files changed

+510
-0
lines changed

futures-util/src/io/mod.rs

+3
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ pub use self::into_sink::IntoSink;
9191
mod lines;
9292
pub use self::lines::Lines;
9393

94+
mod pipe;
95+
pub use self::pipe::{pipe, PipeReader, PipeWriter};
96+
9497
mod read;
9598
pub use self::read::Read;
9699

futures-util/src/io/pipe.rs

+302
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,302 @@
1+
use core::pin::Pin;
2+
use core::ptr::copy_nonoverlapping;
3+
use core::slice;
4+
use core::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
5+
6+
use alloc::boxed::Box;
7+
use alloc::sync::Arc;
8+
use futures_core::ready;
9+
use futures_core::task::{Context, Poll, Waker};
10+
use futures_io::{AsyncBufRead, AsyncRead, AsyncWrite, Error, ErrorKind, Result};
11+
12+
use crate::task::AtomicWaker;
13+
14+
/// Create a unidirectional bounded pipe for data transfer between asynchronous tasks.
15+
///
16+
/// The internal buffer size is given by `buffer`, which must be non zero. The [`PipeWriter`]
17+
/// returned implements the [`AsyncWrite`] trait, while [`PipeReader`] implements [`AsyncRead`].
18+
///
19+
/// # Panics
20+
///
21+
/// Panics when `buffer` is zero.
22+
#[must_use]
23+
pub fn pipe(buffer: usize) -> (PipeWriter, PipeReader) {
24+
assert!(buffer != 0, "pipe buffer size must be non zero and not usize::MAX");
25+
// If it is `usize::MAX`, the allocation must fail anyway since Rust forbids allocations larger
26+
// than `isize::MAX as usize`. This counts as OOM thus no need to state explicitly.
27+
let len = buffer.saturating_add(1);
28+
let ptr = Box::into_raw(alloc::vec![0u8; len].into_boxed_slice());
29+
let inner = Arc::new(Shared {
30+
len,
31+
buffer: ptr.cast(),
32+
write_pos: AtomicUsize::new(0),
33+
read_pos: AtomicUsize::new(0),
34+
writer_waker: AtomicWaker::new(),
35+
reader_waker: AtomicWaker::new(),
36+
closed: AtomicBool::new(false),
37+
});
38+
(PipeWriter { inner: inner.clone() }, PipeReader { inner })
39+
}
40+
41+
// `read_pos..write_pos` (loop around, same below) contains the buffered content.
42+
// `write_pos..(read_pos-1)` is the empty space for further data.
43+
// Note that index `read_pos-1` is left vacant so that `read_pos == write_pos` if and only if
44+
// the buffer is empty.
45+
//
46+
// Invariants, at any time:
47+
// 1. `read_pos` and `buffer[read_pos..write_pos]` is owned by the read-end.
48+
// Read-end can increment `read_pos` in that range to transfer
49+
// a portion of buffer to the write-end.
50+
// 2. `write_pos` and `buffer[writer_pos..(read_pos-1)]` is owned by the write-end.
51+
// Write-end can increment `write_pos` in that range to transfer
52+
// a portion of buffer to the read-end.
53+
// 3. Read-end can only park (returning Pending) when it observed `read_pos == write_pos` after
54+
// registered the waker.
55+
// 4. Write-end can only park when it observed `write_pos == read_pos-1` after
56+
// registered the waker.
57+
#[derive(Debug)]
58+
struct Shared {
59+
len: usize,
60+
buffer: *mut u8,
61+
read_pos: AtomicUsize,
62+
write_pos: AtomicUsize,
63+
reader_waker: AtomicWaker,
64+
writer_waker: AtomicWaker,
65+
closed: AtomicBool,
66+
}
67+
68+
unsafe impl Send for Shared {}
69+
unsafe impl Sync for Shared {}
70+
71+
impl Drop for Shared {
72+
fn drop(&mut self) {
73+
unsafe {
74+
drop(Box::from_raw(slice::from_raw_parts_mut(self.buffer, self.len)));
75+
}
76+
}
77+
}
78+
79+
impl Shared {
80+
fn poll_read_ready(&self, waker: &Waker) -> Poll<Result<(usize, usize)>> {
81+
// Only mutable by us reader. No synchronization for load.
82+
let data_start = self.read_pos.load(Ordering::Relaxed);
83+
// "Acquire" the bytes for read.
84+
let mut data_end = self.write_pos.load(Ordering::Acquire);
85+
// Fast path.
86+
if data_start == data_end {
87+
// Implicit "Acquite" `write_pos` below.
88+
self.reader_waker.register(waker);
89+
// Double check for readiness.
90+
data_end = self.write_pos.load(Ordering::Acquire);
91+
if data_start == data_end {
92+
// Already "acquire"d by `reader_waker`.
93+
if self.closed.load(Ordering::Relaxed) {
94+
return Poll::Ready(Ok((0, 0)));
95+
}
96+
return Poll::Pending;
97+
}
98+
}
99+
Poll::Ready(Ok((data_start, data_end)))
100+
}
101+
102+
unsafe fn commit_read(&self, new_read_pos: usize) {
103+
// "Release" the bytes just read.
104+
self.read_pos.store(new_read_pos, Ordering::Release);
105+
// Implicit "Release" the `read_pos` change.
106+
self.writer_waker.wake();
107+
}
108+
109+
fn poll_write_ready(&self, waker: &Waker) -> Poll<Result<(usize, usize)>> {
110+
// Only mutable by us writer. No synchronization for load.
111+
let write_start = self.write_pos.load(Ordering::Relaxed);
112+
// "Acquire" the bytes for write.
113+
let mut write_end =
114+
self.read_pos.load(Ordering::Acquire).checked_sub(1).unwrap_or(self.len - 1);
115+
if write_start == write_end {
116+
// Implicit "Acquite" `read_pos` below.
117+
self.writer_waker.register(waker);
118+
// Double check for writeness.
119+
write_end =
120+
self.read_pos.load(Ordering::Acquire).checked_sub(1).unwrap_or(self.len - 1);
121+
if write_start == write_end {
122+
// Already "acquire"d by `reader_waker`.
123+
if self.closed.load(Ordering::Relaxed) {
124+
return Poll::Ready(Err(Error::new(ErrorKind::BrokenPipe, "pipe closed")));
125+
}
126+
return Poll::Pending;
127+
}
128+
}
129+
Poll::Ready(Ok((write_start, write_end)))
130+
}
131+
132+
unsafe fn commit_write(&self, new_write_pos: usize) {
133+
// "Release" the bytes just written.
134+
self.write_pos.store(new_write_pos, Ordering::Release);
135+
// Implicit "Release" the `write_pos` change.
136+
self.reader_waker.wake();
137+
}
138+
}
139+
140+
/// The write end of a bounded pipe.
141+
///
142+
/// This value is created by the [`pipe`] function.
143+
#[derive(Debug)]
144+
pub struct PipeWriter {
145+
inner: Arc<Shared>,
146+
}
147+
148+
impl Drop for PipeWriter {
149+
fn drop(&mut self) {
150+
self.inner.closed.store(true, Ordering::Relaxed);
151+
// "Release" `closed`.
152+
self.inner.reader_waker.wake();
153+
}
154+
}
155+
156+
impl AsyncWrite for PipeWriter {
157+
fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize>> {
158+
if buf.is_empty() {
159+
return Poll::Ready(Ok(0));
160+
}
161+
162+
let inner = &*self.inner;
163+
164+
let (write_start, write_end) = ready!(inner.poll_write_ready(cx.waker()))?;
165+
166+
let written = if write_start <= write_end {
167+
let written = buf.len().min(write_end - write_start);
168+
// SAFETY: `buffer[write_pos..read_pos-1]` is owned by us writer.
169+
unsafe {
170+
copy_nonoverlapping(buf.as_ptr(), inner.buffer.add(write_start), written);
171+
}
172+
written
173+
} else {
174+
let written1 = buf.len().min(inner.len - write_start);
175+
let written2 = (buf.len() - written1).min(write_end);
176+
// SAFETY: `buffer[write_pos..]` and `buffer[..read_pos-1]` are owned by us writer.
177+
unsafe {
178+
copy_nonoverlapping(buf.as_ptr(), inner.buffer.add(write_start), written1);
179+
copy_nonoverlapping(buf.as_ptr().add(written1), inner.buffer, written2);
180+
}
181+
written1 + written2
182+
};
183+
184+
let mut new_write_pos = write_start + written;
185+
if new_write_pos >= inner.len {
186+
new_write_pos -= inner.len;
187+
}
188+
189+
unsafe {
190+
inner.commit_write(new_write_pos);
191+
}
192+
193+
Poll::Ready(Ok(written))
194+
}
195+
196+
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
197+
Poll::Ready(Ok(()))
198+
}
199+
200+
fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<()>> {
201+
Poll::Ready(Ok(()))
202+
}
203+
}
204+
205+
/// The read end of a bounded pipe.
206+
///
207+
/// This value is created by the [`pipe`] function.
208+
#[derive(Debug)]
209+
pub struct PipeReader {
210+
inner: Arc<Shared>,
211+
}
212+
213+
impl Drop for PipeReader {
214+
fn drop(&mut self) {
215+
self.inner.closed.store(true, Ordering::Relaxed);
216+
// "Release" `closed`.
217+
self.inner.writer_waker.wake();
218+
}
219+
}
220+
221+
impl AsyncRead for PipeReader {
222+
fn poll_read(
223+
self: Pin<&mut Self>,
224+
cx: &mut Context<'_>,
225+
buf: &mut [u8],
226+
) -> Poll<Result<usize>> {
227+
if buf.is_empty() {
228+
return Poll::Ready(Ok(0));
229+
}
230+
231+
let inner = &*self.inner;
232+
233+
let (data_start, data_end) = ready!(inner.poll_read_ready(cx.waker()))?;
234+
235+
let read = if data_start <= data_end {
236+
let read = buf.len().min(data_end - data_start);
237+
// SAFETY: `buffer[read_pos..write_pos]` are owned by us reader.
238+
unsafe {
239+
copy_nonoverlapping(inner.buffer.add(data_start), buf.as_mut_ptr(), read);
240+
}
241+
read
242+
} else {
243+
let read1 = buf.len().min(inner.len - data_start);
244+
let read2 = (buf.len() - read1).min(data_end);
245+
// SAFETY: `buffer[read_pos..]` and `buffer[..write_pos]` are owned by us reader.
246+
unsafe {
247+
copy_nonoverlapping(inner.buffer.add(data_start), buf.as_mut_ptr(), read1);
248+
copy_nonoverlapping(inner.buffer, buf.as_mut_ptr().add(read1), read2);
249+
}
250+
read1 + read2
251+
};
252+
253+
let mut new_read_pos = data_start + read;
254+
if new_read_pos >= inner.len {
255+
new_read_pos -= inner.len;
256+
}
257+
258+
unsafe {
259+
self.inner.commit_read(new_read_pos);
260+
}
261+
262+
Poll::Ready(Ok(read))
263+
}
264+
}
265+
266+
impl AsyncBufRead for PipeReader {
267+
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<&[u8]>> {
268+
let inner = &*self.inner;
269+
let (data_start, mut data_end) = ready!(inner.poll_read_ready(cx.waker()))?;
270+
if data_end < data_start {
271+
data_end = inner.len;
272+
}
273+
// SAFETY: `buffer[read_pos..]` is owned by us reader.
274+
let data =
275+
unsafe { slice::from_raw_parts(inner.buffer.add(data_start), data_end - data_start) };
276+
Poll::Ready(Ok(data))
277+
}
278+
279+
fn consume(self: Pin<&mut Self>, amt: usize) {
280+
let inner = &*self.inner;
281+
// Only mutable by us reader. No synchronization for load.
282+
let data_start = inner.read_pos.load(Ordering::Relaxed);
283+
// Can only go forward since the last `poll_fill_buf` in the same thread.
284+
// Does not need to be up-to-date.
285+
let data_end = inner.write_pos.load(Ordering::Relaxed);
286+
287+
let len = if data_start <= data_end {
288+
data_end - data_start
289+
} else {
290+
data_end + inner.len - data_start
291+
};
292+
assert!(amt <= len, "invalid advance");
293+
294+
let mut new_read_pos = data_start + amt;
295+
if new_read_pos >= inner.len {
296+
new_read_pos -= inner.len;
297+
}
298+
unsafe {
299+
inner.commit_read(new_read_pos);
300+
}
301+
}
302+
}

0 commit comments

Comments
 (0)