Skip to content

Commit adad8fc

Browse files
authored
io: add a copy_bidirectional utility (#3572)
1 parent 08f1b67 commit adad8fc

File tree

5 files changed

+329
-51
lines changed

5 files changed

+329
-51
lines changed

tokio/src/io/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ cfg_io_util! {
246246
pub(crate) mod seek;
247247
pub(crate) mod util;
248248
pub use util::{
249-
copy, copy_buf, duplex, empty, repeat, sink, AsyncBufReadExt, AsyncReadExt, AsyncSeekExt, AsyncWriteExt,
249+
copy, copy_bidirectional, copy_buf, duplex, empty, repeat, sink, AsyncBufReadExt, AsyncReadExt, AsyncSeekExt, AsyncWriteExt,
250250
BufReader, BufStream, BufWriter, DuplexStream, Empty, Lines, Repeat, Sink, Split, Take,
251251
};
252252
}

tokio/src/io/util/copy.rs

+78-50
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,85 @@ use std::io;
55
use std::pin::Pin;
66
use std::task::{Context, Poll};
77

8+
#[derive(Debug)]
9+
pub(super) struct CopyBuffer {
10+
read_done: bool,
11+
pos: usize,
12+
cap: usize,
13+
amt: u64,
14+
buf: Box<[u8]>,
15+
}
16+
17+
impl CopyBuffer {
18+
pub(super) fn new() -> Self {
19+
Self {
20+
read_done: false,
21+
pos: 0,
22+
cap: 0,
23+
amt: 0,
24+
buf: vec![0; 2048].into_boxed_slice(),
25+
}
26+
}
27+
28+
pub(super) fn poll_copy<R, W>(
29+
&mut self,
30+
cx: &mut Context<'_>,
31+
mut reader: Pin<&mut R>,
32+
mut writer: Pin<&mut W>,
33+
) -> Poll<io::Result<u64>>
34+
where
35+
R: AsyncRead + ?Sized,
36+
W: AsyncWrite + ?Sized,
37+
{
38+
loop {
39+
// If our buffer is empty, then we need to read some data to
40+
// continue.
41+
if self.pos == self.cap && !self.read_done {
42+
let me = &mut *self;
43+
let mut buf = ReadBuf::new(&mut me.buf);
44+
ready!(reader.as_mut().poll_read(cx, &mut buf))?;
45+
let n = buf.filled().len();
46+
if n == 0 {
47+
self.read_done = true;
48+
} else {
49+
self.pos = 0;
50+
self.cap = n;
51+
}
52+
}
53+
54+
// If our buffer has some data, let's write it out!
55+
while self.pos < self.cap {
56+
let me = &mut *self;
57+
let i = ready!(writer.as_mut().poll_write(cx, &me.buf[me.pos..me.cap]))?;
58+
if i == 0 {
59+
return Poll::Ready(Err(io::Error::new(
60+
io::ErrorKind::WriteZero,
61+
"write zero byte into writer",
62+
)));
63+
} else {
64+
self.pos += i;
65+
self.amt += i as u64;
66+
}
67+
}
68+
69+
// If we've written all the data and we've seen EOF, flush out the
70+
// data and finish the transfer.
71+
if self.pos == self.cap && self.read_done {
72+
ready!(writer.as_mut().poll_flush(cx))?;
73+
return Poll::Ready(Ok(self.amt));
74+
}
75+
}
76+
}
77+
}
78+
879
/// A future that asynchronously copies the entire contents of a reader into a
980
/// writer.
1081
#[derive(Debug)]
1182
#[must_use = "futures do nothing unless you `.await` or poll them"]
1283
struct Copy<'a, R: ?Sized, W: ?Sized> {
1384
reader: &'a mut R,
14-
read_done: bool,
1585
writer: &'a mut W,
16-
pos: usize,
17-
cap: usize,
18-
amt: u64,
19-
buf: Box<[u8]>,
86+
buf: CopyBuffer,
2087
}
2188

2289
cfg_io_util! {
@@ -35,8 +102,8 @@ cfg_io_util! {
35102
///
36103
/// # Errors
37104
///
38-
/// The returned future will finish with an error will return an error
39-
/// immediately if any call to `poll_read` or `poll_write` returns an error.
105+
/// The returned future will return an error immediately if any call to
106+
/// `poll_read` or `poll_write` returns an error.
40107
///
41108
/// # Examples
42109
///
@@ -60,12 +127,8 @@ cfg_io_util! {
60127
{
61128
Copy {
62129
reader,
63-
read_done: false,
64130
writer,
65-
amt: 0,
66-
pos: 0,
67-
cap: 0,
68-
buf: vec![0; 2048].into_boxed_slice(),
131+
buf: CopyBuffer::new()
69132
}.await
70133
}
71134
}
@@ -78,44 +141,9 @@ where
78141
type Output = io::Result<u64>;
79142

80143
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<u64>> {
81-
loop {
82-
// If our buffer is empty, then we need to read some data to
83-
// continue.
84-
if self.pos == self.cap && !self.read_done {
85-
let me = &mut *self;
86-
let mut buf = ReadBuf::new(&mut me.buf);
87-
ready!(Pin::new(&mut *me.reader).poll_read(cx, &mut buf))?;
88-
let n = buf.filled().len();
89-
if n == 0 {
90-
self.read_done = true;
91-
} else {
92-
self.pos = 0;
93-
self.cap = n;
94-
}
95-
}
144+
let me = &mut *self;
96145

97-
// If our buffer has some data, let's write it out!
98-
while self.pos < self.cap {
99-
let me = &mut *self;
100-
let i = ready!(Pin::new(&mut *me.writer).poll_write(cx, &me.buf[me.pos..me.cap]))?;
101-
if i == 0 {
102-
return Poll::Ready(Err(io::Error::new(
103-
io::ErrorKind::WriteZero,
104-
"write zero byte into writer",
105-
)));
106-
} else {
107-
self.pos += i;
108-
self.amt += i as u64;
109-
}
110-
}
111-
112-
// If we've written all the data and we've seen EOF, flush out the
113-
// data and finish the transfer.
114-
if self.pos == self.cap && self.read_done {
115-
let me = &mut *self;
116-
ready!(Pin::new(&mut *me.writer).poll_flush(cx))?;
117-
return Poll::Ready(Ok(self.amt));
118-
}
119-
}
146+
me.buf
147+
.poll_copy(cx, Pin::new(&mut *me.reader), Pin::new(&mut *me.writer))
120148
}
121149
}
+119
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
use super::copy::CopyBuffer;
2+
3+
use crate::io::{AsyncRead, AsyncWrite};
4+
5+
use std::future::Future;
6+
use std::io;
7+
use std::pin::Pin;
8+
use std::task::{Context, Poll};
9+
10+
enum TransferState {
11+
Running(CopyBuffer),
12+
ShuttingDown(u64),
13+
Done(u64),
14+
}
15+
16+
struct CopyBidirectional<'a, A: ?Sized, B: ?Sized> {
17+
a: &'a mut A,
18+
b: &'a mut B,
19+
a_to_b: TransferState,
20+
b_to_a: TransferState,
21+
}
22+
23+
fn transfer_one_direction<A, B>(
24+
cx: &mut Context<'_>,
25+
state: &mut TransferState,
26+
r: &mut A,
27+
w: &mut B,
28+
) -> Poll<io::Result<u64>>
29+
where
30+
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
31+
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
32+
{
33+
let mut r = Pin::new(r);
34+
let mut w = Pin::new(w);
35+
36+
loop {
37+
match state {
38+
TransferState::Running(buf) => {
39+
let count = ready!(buf.poll_copy(cx, r.as_mut(), w.as_mut()))?;
40+
*state = TransferState::ShuttingDown(count);
41+
}
42+
TransferState::ShuttingDown(count) => {
43+
ready!(w.as_mut().poll_shutdown(cx))?;
44+
45+
*state = TransferState::Done(*count);
46+
}
47+
TransferState::Done(count) => return Poll::Ready(Ok(*count)),
48+
}
49+
}
50+
}
51+
52+
impl<'a, A, B> Future for CopyBidirectional<'a, A, B>
53+
where
54+
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
55+
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
56+
{
57+
type Output = io::Result<(u64, u64)>;
58+
59+
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
60+
// Unpack self into mut refs to each field to avoid borrow check issues.
61+
let CopyBidirectional {
62+
a,
63+
b,
64+
a_to_b,
65+
b_to_a,
66+
} = &mut *self;
67+
68+
let a_to_b = transfer_one_direction(cx, a_to_b, &mut *a, &mut *b)?;
69+
let b_to_a = transfer_one_direction(cx, b_to_a, &mut *b, &mut *a)?;
70+
71+
// It is not a problem if ready! returns early because transfer_one_direction for the
72+
// other direction will keep returning TransferState::Done(count) in future calls to poll
73+
let a_to_b = ready!(a_to_b);
74+
let b_to_a = ready!(b_to_a);
75+
76+
Poll::Ready(Ok((a_to_b, b_to_a)))
77+
}
78+
}
79+
80+
/// Copies data in both directions between `a` and `b`.
81+
///
82+
/// This function returns a future that will read from both streams,
83+
/// writing any data read to the opposing stream.
84+
/// This happens in both directions concurrently.
85+
///
86+
/// If an EOF is observed on one stream, [`shutdown()`] will be invoked on
87+
/// the other, and reading from that stream will stop. Copying of data in
88+
/// the other direction will continue.
89+
///
90+
/// The future will complete successfully once both directions of communication has been shut down.
91+
/// A direction is shut down when the reader reports EOF,
92+
/// at which point [`shutdown()`] is called on the corresponding writer. When finished,
93+
/// it will return a tuple of the number of bytes copied from a to b
94+
/// and the number of bytes copied from b to a, in that order.
95+
///
96+
/// [`shutdown()`]: crate::io::AsyncWriteExt::shutdown
97+
///
98+
/// # Errors
99+
///
100+
/// The future will immediately return an error if any IO operation on `a`
101+
/// or `b` returns an error. Some data read from either stream may be lost (not
102+
/// written to the other stream) in this case.
103+
///
104+
/// # Return value
105+
///
106+
/// Returns a tuple of bytes copied `a` to `b` and bytes copied `b` to `a`.
107+
pub async fn copy_bidirectional<A, B>(a: &mut A, b: &mut B) -> Result<(u64, u64), std::io::Error>
108+
where
109+
A: AsyncRead + AsyncWrite + Unpin + ?Sized,
110+
B: AsyncRead + AsyncWrite + Unpin + ?Sized,
111+
{
112+
CopyBidirectional {
113+
a,
114+
b,
115+
a_to_b: TransferState::Running(CopyBuffer::new()),
116+
b_to_a: TransferState::Running(CopyBuffer::new()),
117+
}
118+
.await
119+
}

tokio/src/io/util/mod.rs

+3
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,9 @@ cfg_io_util! {
2727
mod copy;
2828
pub use copy::copy;
2929

30+
mod copy_bidirectional;
31+
pub use copy_bidirectional::copy_bidirectional;
32+
3033
mod copy_buf;
3134
pub use copy_buf::copy_buf;
3235

0 commit comments

Comments
 (0)