Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add some trait/method implementation to AsyncReadExt::{chain, take} #1821

Merged
merged 1 commit into from
Aug 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions futures-util/src/io/chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,6 @@ where
}
}

/// Consumes the `Chain`, returning the wrapped readers.
pub fn into_inner(self) -> (T, U) {
(self.first, self.second)
}

/// Gets references to the underlying readers in this `Chain`.
pub fn get_ref(&self) -> (&T, &U) {
(&self.first, &self.second)
Expand All @@ -55,6 +50,23 @@ where
pub fn get_mut(&mut self) -> (&mut T, &mut U) {
(&mut self.first, &mut self.second)
}

/// Gets pinned mutable references to the underlying readers in this `Chain`.
///
/// Care should be taken to avoid modifying the internal I/O state of the
/// underlying readers as doing so may corrupt the internal state of this
/// `Chain`.
pub fn get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut T>, Pin<&mut U>) {
unsafe {
let Self { first, second, .. } = self.get_unchecked_mut();
(Pin::new_unchecked(first), Pin::new_unchecked(second))
}
}

/// Consumes the `Chain`, returning the wrapped readers.
pub fn into_inner(self) -> (T, U) {
(self.first, self.second)
}
}

impl<T, U> fmt::Debug for Chain<T, U>
Expand Down
2 changes: 1 addition & 1 deletion futures-util/src/io/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,7 @@ pub trait AsyncReadExt: AsyncRead {
/// # Ok::<(), Box<dyn std::error::Error>>(()) }).unwrap();
/// ```
fn take(self, limit: u64) -> Take<Self>
where Self: Sized + Unpin
where Self: Sized
{
Take::new(self, limit)
}
Expand Down
95 changes: 68 additions & 27 deletions futures-util/src/io/take.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,26 @@
use futures_core::task::{Context, Poll};
use futures_io::AsyncRead;
use std::io;
use futures_io::{AsyncRead, AsyncBufRead, Initializer};
use pin_utils::{unsafe_pinned, unsafe_unpinned};
use std::{cmp, io};
use std::pin::Pin;

/// Future for the [`take`](super::AsyncReadExt::take) method.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Take<R: Unpin> {
pub struct Take<R> {
inner: R,
limit: u64,
// Add '_' to avoid conflicts with `limit` method.
limit_: u64,
}

impl<R: Unpin> Unpin for Take<R> { }

impl<R: AsyncRead + Unpin> Take<R> {
impl<R: AsyncRead> Take<R> {
unsafe_pinned!(inner: R);
unsafe_unpinned!(limit_: u64);

pub(super) fn new(inner: R, limit: u64) -> Self {
Take { inner, limit }
Self { inner, limit_: limit }
}

/// Returns the remaining number of bytes that can be
Expand Down Expand Up @@ -43,7 +48,7 @@ impl<R: AsyncRead + Unpin> Take<R> {
/// # Ok::<(), Box<dyn std::error::Error>>(()) }).unwrap();
/// ```
pub fn limit(&self) -> u64 {
self.limit
self.limit_
}

/// Sets the number of bytes that can be read before this instance will
Expand Down Expand Up @@ -74,10 +79,10 @@ impl<R: AsyncRead + Unpin> Take<R> {
/// # Ok::<(), Box<dyn std::error::Error>>(()) }).unwrap();
/// ```
pub fn set_limit(&mut self, limit: u64) {
self.limit = limit
self.limit_ = limit
}

/// Consumes the `Take`, returning the wrapped reader.
/// Gets a reference to the underlying reader.
///
/// # Examples
///
Expand All @@ -92,16 +97,20 @@ impl<R: AsyncRead + Unpin> Take<R> {
/// let mut take = reader.take(4);
/// let n = take.read(&mut buffer).await?;
///
/// let cursor = take.into_inner();
/// assert_eq!(cursor.position(), 4);
/// let cursor_ref = take.get_ref();
/// assert_eq!(cursor_ref.position(), 4);
///
/// # Ok::<(), Box<dyn std::error::Error>>(()) }).unwrap();
/// ```
pub fn into_inner(self) -> R {
self.inner
pub fn get_ref(&self) -> &R {
&self.inner
}

/// Gets a reference to the underlying reader.
/// Gets a mutable reference to the underlying reader.
///
/// Care should be taken to avoid modifying the internal I/O state of the
/// underlying reader as doing so may corrupt the internal limit of this
/// `Take`.
///
/// # Examples
///
Expand All @@ -116,20 +125,24 @@ impl<R: AsyncRead + Unpin> Take<R> {
/// let mut take = reader.take(4);
/// let n = take.read(&mut buffer).await?;
///
/// let cursor_ref = take.get_ref();
/// assert_eq!(cursor_ref.position(), 4);
/// let cursor_mut = take.get_mut();
///
/// # Ok::<(), Box<dyn std::error::Error>>(()) }).unwrap();
/// ```
pub fn get_ref(&self) -> &R {
&self.inner
pub fn get_mut(&mut self) -> &mut R {
&mut self.inner
}

/// Gets a mutable reference to the underlying reader.
/// Gets a pinned mutable reference to the underlying reader.
///
/// Care should be taken to avoid modifying the internal I/O state of the
/// underlying reader as doing so may corrupt the internal limit of this
/// `Take`.
pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> {
self.inner()
}

/// Consumes the `Take`, returning the wrapped reader.
///
/// # Examples
///
Expand All @@ -144,28 +157,56 @@ impl<R: AsyncRead + Unpin> Take<R> {
/// let mut take = reader.take(4);
/// let n = take.read(&mut buffer).await?;
///
/// let cursor_mut = take.get_mut();
/// let cursor = take.into_inner();
/// assert_eq!(cursor.position(), 4);
///
/// # Ok::<(), Box<dyn std::error::Error>>(()) }).unwrap();
/// ```
pub fn get_mut(&mut self) -> &mut R {
&mut self.inner
pub fn into_inner(self) -> R {
self.inner
}
}

impl<R: AsyncRead + Unpin> AsyncRead for Take<R> {
impl<R: AsyncRead> AsyncRead for Take<R> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<Result<usize, io::Error>> {
if self.limit == 0 {
if self.limit_ == 0 {
return Poll::Ready(Ok(0));
}

let max = std::cmp::min(buf.len() as u64, self.limit) as usize;
let n = ready!(Pin::new(&mut self.inner).poll_read(cx, &mut buf[..max]))?;
self.limit -= n as u64;
let max = std::cmp::min(buf.len() as u64, self.limit_) as usize;
let n = ready!(self.as_mut().inner().poll_read(cx, &mut buf[..max]))?;
*self.as_mut().limit_() -= n as u64;
Poll::Ready(Ok(n))
}

unsafe fn initializer(&self) -> Initializer {
self.inner.initializer()
}
}

impl<R: AsyncBufRead> AsyncBufRead for Take<R> {
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
let Self { inner, limit_ } = unsafe { self.get_unchecked_mut() };
let inner = unsafe { Pin::new_unchecked(inner) };

// Don't call into inner reader at all at EOF because it may still block
if *limit_ == 0 {
return Poll::Ready(Ok(&[]));
}

let buf = ready!(inner.poll_fill_buf(cx)?);
let cap = cmp::min(buf.len() as u64, *limit_) as usize;
Poll::Ready(Ok(&buf[..cap]))
}

fn consume(mut self: Pin<&mut Self>, amt: usize) {
// Don't let callers reset the limit by passing an overlarge value
let amt = cmp::min(amt as u64, self.limit_) as usize;
*self.as_mut().limit_() -= amt as u64;
self.inner().consume(amt);
}
}