From ddf4f55778f3b37205b864ffe4f1a50d7db88464 Mon Sep 17 00:00:00 2001 From: Taiki Endo Date: Wed, 21 Aug 2019 10:38:06 +0900 Subject: [PATCH] Add some trait/method implementation to AsyncReadExt::{chain, take} --- futures-util/src/io/chain.rs | 22 +++++++-- futures-util/src/io/mod.rs | 2 +- futures-util/src/io/take.rs | 95 ++++++++++++++++++++++++++---------- 3 files changed, 86 insertions(+), 33 deletions(-) diff --git a/futures-util/src/io/chain.rs b/futures-util/src/io/chain.rs index 9c0b5bf1cd..15758e3a07 100644 --- a/futures-util/src/io/chain.rs +++ b/futures-util/src/io/chain.rs @@ -37,11 +37,6 @@ where } } - /// Consumes the `Chain`, returning the wrapped readers. - pub fn into_inner(self) -> (T, U) { - (self.first, self.second) - } - /// Gets references to the underlying readers in this `Chain`. pub fn get_ref(&self) -> (&T, &U) { (&self.first, &self.second) @@ -55,6 +50,23 @@ where pub fn get_mut(&mut self) -> (&mut T, &mut U) { (&mut self.first, &mut self.second) } + + /// Gets pinned mutable references to the underlying readers in this `Chain`. + /// + /// Care should be taken to avoid modifying the internal I/O state of the + /// underlying readers as doing so may corrupt the internal state of this + /// `Chain`. + pub fn get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut T>, Pin<&mut U>) { + unsafe { + let Self { first, second, .. } = self.get_unchecked_mut(); + (Pin::new_unchecked(first), Pin::new_unchecked(second)) + } + } + + /// Consumes the `Chain`, returning the wrapped readers. + pub fn into_inner(self) -> (T, U) { + (self.first, self.second) + } } impl fmt::Debug for Chain diff --git a/futures-util/src/io/mod.rs b/futures-util/src/io/mod.rs index 98ea18ee53..d335e11889 100644 --- a/futures-util/src/io/mod.rs +++ b/futures-util/src/io/mod.rs @@ -367,7 +367,7 @@ pub trait AsyncReadExt: AsyncRead { /// # Ok::<(), Box>(()) }).unwrap(); /// ``` fn take(self, limit: u64) -> Take - where Self: Sized + Unpin + where Self: Sized { Take::new(self, limit) } diff --git a/futures-util/src/io/take.rs b/futures-util/src/io/take.rs index c5f0617733..9736250d74 100644 --- a/futures-util/src/io/take.rs +++ b/futures-util/src/io/take.rs @@ -1,21 +1,26 @@ use futures_core::task::{Context, Poll}; -use futures_io::AsyncRead; -use std::io; +use futures_io::{AsyncRead, AsyncBufRead, Initializer}; +use pin_utils::{unsafe_pinned, unsafe_unpinned}; +use std::{cmp, io}; use std::pin::Pin; /// Future for the [`take`](super::AsyncReadExt::take) method. #[derive(Debug)] #[must_use = "futures do nothing unless you `.await` or poll them"] -pub struct Take { +pub struct Take { inner: R, - limit: u64, + // Add '_' to avoid conflicts with `limit` method. + limit_: u64, } impl Unpin for Take { } -impl Take { +impl Take { + unsafe_pinned!(inner: R); + unsafe_unpinned!(limit_: u64); + pub(super) fn new(inner: R, limit: u64) -> Self { - Take { inner, limit } + Self { inner, limit_: limit } } /// Returns the remaining number of bytes that can be @@ -43,7 +48,7 @@ impl Take { /// # Ok::<(), Box>(()) }).unwrap(); /// ``` pub fn limit(&self) -> u64 { - self.limit + self.limit_ } /// Sets the number of bytes that can be read before this instance will @@ -74,10 +79,10 @@ impl Take { /// # Ok::<(), Box>(()) }).unwrap(); /// ``` pub fn set_limit(&mut self, limit: u64) { - self.limit = limit + self.limit_ = limit } - /// Consumes the `Take`, returning the wrapped reader. + /// Gets a reference to the underlying reader. /// /// # Examples /// @@ -92,16 +97,20 @@ impl Take { /// let mut take = reader.take(4); /// let n = take.read(&mut buffer).await?; /// - /// let cursor = take.into_inner(); - /// assert_eq!(cursor.position(), 4); + /// let cursor_ref = take.get_ref(); + /// assert_eq!(cursor_ref.position(), 4); /// /// # Ok::<(), Box>(()) }).unwrap(); /// ``` - pub fn into_inner(self) -> R { - self.inner + pub fn get_ref(&self) -> &R { + &self.inner } - /// Gets a reference to the underlying reader. + /// Gets a mutable reference to the underlying reader. + /// + /// Care should be taken to avoid modifying the internal I/O state of the + /// underlying reader as doing so may corrupt the internal limit of this + /// `Take`. /// /// # Examples /// @@ -116,20 +125,24 @@ impl Take { /// let mut take = reader.take(4); /// let n = take.read(&mut buffer).await?; /// - /// let cursor_ref = take.get_ref(); - /// assert_eq!(cursor_ref.position(), 4); + /// let cursor_mut = take.get_mut(); /// /// # Ok::<(), Box>(()) }).unwrap(); /// ``` - pub fn get_ref(&self) -> &R { - &self.inner + pub fn get_mut(&mut self) -> &mut R { + &mut self.inner } - /// Gets a mutable reference to the underlying reader. + /// Gets a pinned mutable reference to the underlying reader. /// /// Care should be taken to avoid modifying the internal I/O state of the /// underlying reader as doing so may corrupt the internal limit of this /// `Take`. + pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut R> { + self.inner() + } + + /// Consumes the `Take`, returning the wrapped reader. /// /// # Examples /// @@ -144,28 +157,56 @@ impl Take { /// let mut take = reader.take(4); /// let n = take.read(&mut buffer).await?; /// - /// let cursor_mut = take.get_mut(); + /// let cursor = take.into_inner(); + /// assert_eq!(cursor.position(), 4); /// /// # Ok::<(), Box>(()) }).unwrap(); /// ``` - pub fn get_mut(&mut self) -> &mut R { - &mut self.inner + pub fn into_inner(self) -> R { + self.inner } } -impl AsyncRead for Take { +impl AsyncRead for Take { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { - if self.limit == 0 { + if self.limit_ == 0 { return Poll::Ready(Ok(0)); } - let max = std::cmp::min(buf.len() as u64, self.limit) as usize; - let n = ready!(Pin::new(&mut self.inner).poll_read(cx, &mut buf[..max]))?; - self.limit -= n as u64; + let max = std::cmp::min(buf.len() as u64, self.limit_) as usize; + let n = ready!(self.as_mut().inner().poll_read(cx, &mut buf[..max]))?; + *self.as_mut().limit_() -= n as u64; Poll::Ready(Ok(n)) } + + unsafe fn initializer(&self) -> Initializer { + self.inner.initializer() + } +} + +impl AsyncBufRead for Take { + fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let Self { inner, limit_ } = unsafe { self.get_unchecked_mut() }; + let inner = unsafe { Pin::new_unchecked(inner) }; + + // Don't call into inner reader at all at EOF because it may still block + if *limit_ == 0 { + return Poll::Ready(Ok(&[])); + } + + let buf = ready!(inner.poll_fill_buf(cx)?); + let cap = cmp::min(buf.len() as u64, *limit_) as usize; + Poll::Ready(Ok(&buf[..cap])) + } + + fn consume(mut self: Pin<&mut Self>, amt: usize) { + // Don't let callers reset the limit by passing an overlarge value + let amt = cmp::min(amt as u64, self.limit_) as usize; + *self.as_mut().limit_() -= amt as u64; + self.inner().consume(amt); + } }