Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ReaderStream #2714

Merged
merged 7 commits into from
Aug 23, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tokio/src/io/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ cfg_io_util! {

cfg_stream! {
pub use util::{stream_reader, StreamReader};
pub use util::{reader_stream, ReaderStream};
}
}

Expand Down
3 changes: 3 additions & 0 deletions tokio/src/io/util/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ cfg_io_util! {
cfg_stream! {
mod stream_reader;
pub use stream_reader::{stream_reader, StreamReader};

mod reader_stream;
pub use reader_stream::{reader_stream, ReaderStream};
}

mod take;
Expand Down
94 changes: 94 additions & 0 deletions tokio/src/io/util/reader_stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
use crate::io::AsyncRead;
use crate::stream::Stream;
use bytes::{Bytes, BytesMut};
use pin_project_lite::pin_project;
use std::pin::Pin;
use std::task::{Context, Poll};

pin_project! {
/// Convert an [`AsyncRead`](crate::io::AsyncRead) implementor into a
/// [`Stream`](crate::stream::Stream) of Result<[`Bytes`](bytes::Bytes), io::Error>. After first error it will
/// stop.
///
/// This type can be created using the [`reader_stream`](crate::io::reader_stream) function
#[derive(Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "stream")))]
#[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
pub struct ReaderStream<R> {
// reader itself.
// None if we had error reading from the `reader` in the past.
#[pin]
reader: Option<R>,
// Working buffer, used to optimize allocations.
// # Capacity behavior
// Initially `buf` is empty. Also it's getting smaller and smaller
// during polls (because it's chunks are returned to stream user).
// But when it's capacity reaches 0, it is growed.
buf: BytesMut,
}
}

/// Convert an [`AsyncRead`] implementor into a
/// [`Stream`] of Result<[`Bytes`], std::io::Error>.
///
/// # Example
///
/// ```
/// # #[tokio::main]
/// # async fn main() -> std::io::Result<()> {
/// use tokio::stream::StreamExt;
///
/// let data: &[u8] = b"hello, world!";
/// let mut stream = tokio::io::reader_stream(data);
/// let mut stream_contents = Vec::new();
/// while let Some(chunk) = stream.next().await {
/// stream_contents.extend_from_slice(chunk?.as_ref());
/// }
/// assert_eq!(stream_contents, data);
/// # Ok(())
/// # }
///
/// ```
/// [`AsyncRead`]: crate::io::AsyncRead
/// [`Stream`]: crate::stream::Stream
/// [`Bytes`]: bytes::Bytes
pub fn reader_stream<R>(reader: R) -> ReaderStream<R>
where
R: AsyncRead,
{
ReaderStream {
reader: Some(reader),
buf: BytesMut::new(),
}
}

const CAPACITY: usize = 4096;

impl<R> Stream for ReaderStream<R>
where
R: AsyncRead,
{
type Item = std::io::Result<Bytes>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut this = self.as_mut().project();
let reader = match this.reader.as_pin_mut() {
Some(r) => r,
None => return Poll::Ready(None),
};
if this.buf.capacity() == 0 {
this.buf.reserve(CAPACITY);
}
match reader.poll_read_buf(cx, &mut this.buf) {
Poll::Pending => Poll::Pending,
Poll::Ready(Err(err)) => {
self.project().reader.set(None);
Poll::Ready(Some(Err(err)))
}
Poll::Ready(Ok(0)) => Poll::Ready(None),
Poll::Ready(Ok(_)) => {
let chunk = this.buf.split();
Poll::Ready(Some(Ok(chunk.freeze())))
}
}
}
}
62 changes: 62 additions & 0 deletions tokio/tests/io_reader_stream.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
#![warn(rust_2018_idioms)]
#![cfg(feature = "full")]

use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::AsyncRead;
use tokio::stream::StreamExt;

/// produces at most `remaining` zeros, that returns error
struct Reader {
remaining: usize,
}

impl AsyncRead for Reader {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<std::io::Result<usize>> {
let this = Pin::into_inner(self);
assert_ne!(buf.len(), 0);
if this.remaining > 0 {
let n = std::cmp::min(this.remaining, buf.len());
for x in &mut buf[..n] {
*x = 0;
}
this.remaining -= n;
Poll::Ready(Ok(n))
} else {
Poll::Ready(Err(std::io::Error::from_raw_os_error(22)))
}
}
}

#[tokio::test]
async fn correct_behavior_on_errors() {
let reader = Reader { remaining: 100 };
let mut stream = tokio::io::reader_stream(reader);
let mut zeros_received = 0;
let mut had_error = false;
loop {
let item = stream.next().await.unwrap();
match item {
Ok(bytes) => {
let bytes = &*bytes;
for byte in bytes {
assert_eq!(*byte, 0);
zeros_received += 1;
}
}
Err(_) => {
assert!(!had_error);
had_error = true;
break;
}
}
}

assert!(had_error);
assert_eq!(zeros_received, 100);
assert!(stream.next().await.is_none());
}