Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Seal AsyncBufRead and make implementation methods private #2896

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 107 additions & 60 deletions tokio/src/io/async_buf_read.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use crate::io::AsyncRead;

use std::io;
use std::ops::DerefMut;
use std::pin::Pin;
Expand All @@ -20,98 +18,147 @@ use std::task::{Context, Poll};
/// [`poll_fill_buf`]: AsyncBufRead::poll_fill_buf
/// [`BufRead::fill_buf`]: std::io::BufRead::fill_buf
/// [`AsyncBufReadExt`]: crate::io::AsyncBufReadExt
pub trait AsyncBufRead: AsyncRead {
/// Attempts to return the contents of the internal buffer, filling it with more data
/// from the inner reader if it is empty.
///
/// On success, returns `Poll::Ready(Ok(buf))`.
///
/// If no data is available for reading, the method returns
/// `Poll::Pending` and arranges for the current task (via
/// `cx.waker().wake_by_ref()`) to receive a notification when the object becomes
/// readable or is closed.
///
/// This function is a lower-level call. It needs to be paired with the
/// [`consume`] method to function properly. When calling this
/// method, none of the contents will be "read" in the sense that later
/// calling [`poll_read`] may return the same contents. As such, [`consume`] must
/// be called with the number of bytes that are consumed from this buffer to
/// ensure that the bytes are never returned twice.
///
/// An empty buffer returned indicates that the stream has reached EOF.
///
/// [`poll_read`]: AsyncRead::poll_read
/// [`consume`]: AsyncBufRead::consume
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>>;

/// Tells this buffer that `amt` bytes have been consumed from the buffer,
/// so they should no longer be returned in calls to [`poll_read`].
///
/// This function is a lower-level call. It needs to be paired with the
/// [`poll_fill_buf`] method to function properly. This function does
/// not perform any I/O, it simply informs this object that some amount of
/// its buffer, returned from [`poll_fill_buf`], has been consumed and should
/// no longer be returned. As such, this function may do odd things if
/// [`poll_fill_buf`] isn't called before calling it.
///
/// The `amt` must be `<=` the number of bytes in the buffer returned by
/// [`poll_fill_buf`].
///
/// [`poll_read`]: AsyncRead::poll_read
/// [`poll_fill_buf`]: AsyncBufRead::poll_fill_buf
fn consume(self: Pin<&mut Self>, amt: usize);
}
pub trait AsyncBufRead: sealed::AsyncBufReadPriv {}

macro_rules! deref_async_buf_read {
() => {
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
Pin::new(&mut **self.get_mut()).poll_fill_buf(cx)
fn poll_fill_buf(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
_: sealed::Internal,
) -> Poll<io::Result<&[u8]>> {
Pin::new(&mut **self.get_mut()).poll_fill_buf(cx, sealed::Internal)
}

fn consume(mut self: Pin<&mut Self>, amt: usize) {
Pin::new(&mut **self).consume(amt)
fn consume(mut self: Pin<&mut Self>, _: sealed::Internal, amt: usize) {
Pin::new(&mut **self).consume(sealed::Internal, amt)
}
};
}

impl<T: ?Sized + AsyncBufRead + Unpin> AsyncBufRead for Box<T> {
impl<T: ?Sized + AsyncBufRead + Unpin> sealed::AsyncBufReadPriv for Box<T> {
deref_async_buf_read!();
}

impl<T: ?Sized + AsyncBufRead + Unpin> AsyncBufRead for &mut T {
impl<T: ?Sized + AsyncBufRead + Unpin> AsyncBufRead for Box<T> {}

impl<T: ?Sized + AsyncBufRead + Unpin> sealed::AsyncBufReadPriv for &mut T {
deref_async_buf_read!();
}

impl<P> AsyncBufRead for Pin<P>
impl<T: ?Sized + AsyncBufRead + Unpin> AsyncBufRead for &mut T {}

impl<P> sealed::AsyncBufReadPriv for Pin<P>
where
P: DerefMut + Unpin,
P::Target: AsyncBufRead,
{
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
self.get_mut().as_mut().poll_fill_buf(cx)
fn poll_fill_buf(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
_: sealed::Internal,
) -> Poll<io::Result<&[u8]>> {
self.get_mut().as_mut().poll_fill_buf(cx, sealed::Internal)
}

fn consume(self: Pin<&mut Self>, amt: usize) {
self.get_mut().as_mut().consume(amt)
fn consume(self: Pin<&mut Self>, _: sealed::Internal, amt: usize) {
self.get_mut().as_mut().consume(sealed::Internal, amt)
}
}

impl AsyncBufRead for &[u8] {
fn poll_fill_buf(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
impl<P> AsyncBufRead for Pin<P>
where
P: DerefMut + Unpin,
P::Target: AsyncBufRead,
{
}

impl sealed::AsyncBufReadPriv for &[u8] {
fn poll_fill_buf(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_: sealed::Internal,
) -> Poll<io::Result<&[u8]>> {
Poll::Ready(Ok(*self))
}

fn consume(mut self: Pin<&mut Self>, amt: usize) {
fn consume(mut self: Pin<&mut Self>, _: sealed::Internal, amt: usize) {
*self = &self[amt..];
}
}

impl<T: AsRef<[u8]> + Unpin> AsyncBufRead for io::Cursor<T> {
fn poll_fill_buf(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
impl AsyncBufRead for &[u8] {}

impl<T: AsRef<[u8]> + Unpin> sealed::AsyncBufReadPriv for io::Cursor<T> {
fn poll_fill_buf(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_: sealed::Internal,
) -> Poll<io::Result<&[u8]>> {
Poll::Ready(io::BufRead::fill_buf(self.get_mut()))
}

fn consume(self: Pin<&mut Self>, amt: usize) {
fn consume(self: Pin<&mut Self>, _: sealed::Internal, amt: usize) {
io::BufRead::consume(self.get_mut(), amt)
}
}

impl<T: AsRef<[u8]> + Unpin> AsyncBufRead for io::Cursor<T> {}

pub(super) mod sealed {
use crate::io::AsyncRead;

use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};

#[doc(hidden)]
pub trait AsyncBufReadPriv: AsyncRead {
/// Attempts to return the contents of the internal buffer, filling it with more data
/// from the inner reader if it is empty.
///
/// On success, returns `Poll::Ready(Ok(buf))`.
///
/// If no data is available for reading, the method returns
/// `Poll::Pending` and arranges for the current task (via
/// `cx.waker().wake_by_ref()`) to receive a notification when the object becomes
/// readable or is closed.
///
/// This function is a lower-level call. It needs to be paired with the
/// [`consume`] method to function properly. When calling this
/// method, none of the contents will be "read" in the sense that later
/// calling [`poll_read`] may return the same contents. As such, [`consume`] must
/// be called with the number of bytes that are consumed from this buffer to
/// ensure that the bytes are never returned twice.
///
/// An empty buffer returned indicates that the stream has reached EOF.
///
/// [`poll_read`]: AsyncRead::poll_read
/// [`consume`]: AsyncBufRead::consume
fn poll_fill_buf(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
internal: Internal,
) -> Poll<io::Result<&[u8]>>;

/// Tells this buffer that `amt` bytes have been consumed from the buffer,
/// so they should no longer be returned in calls to [`poll_read`].
///
/// This function is a lower-level call. It needs to be paired with the
/// [`poll_fill_buf`] method to function properly. This function does
/// not perform any I/O, it simply informs this object that some amount of
/// its buffer, returned from [`poll_fill_buf`], has been consumed and should
/// no longer be returned. As such, this function may do odd things if
/// [`poll_fill_buf`] isn't called before calling it.
///
/// The `amt` must be `<=` the number of bytes in the buffer returned by
/// [`poll_fill_buf`].
///
/// [`poll_read`]: AsyncRead::poll_read
/// [`poll_fill_buf`]: AsyncBufRead::poll_fill_buf
fn consume(self: Pin<&mut Self>, internal: Internal, amt: usize);
}

#[allow(missing_debug_implementations)]
pub struct Internal;
}
22 changes: 16 additions & 6 deletions tokio/src/io/util/buf_reader.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::io::util::DEFAULT_BUF_SIZE;
use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
use crate::io::{async_buf_read, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};

use pin_project_lite::pin_project;
use std::io;
Expand Down Expand Up @@ -109,16 +109,24 @@ impl<R: AsyncRead> AsyncRead for BufReader<R> {
self.discard_buffer();
return Poll::Ready(res);
}
let rem = ready!(self.as_mut().poll_fill_buf(cx))?;
let rem = ready!(AsyncBufRead::poll_fill_buf(
self.as_mut(),
cx,
async_buf_read::sealed::Internal
))?;
let amt = std::cmp::min(rem.len(), buf.remaining());
buf.append(&rem[..amt]);
self.consume(amt);
AsyncBufRead::consume(self, async_buf_read::sealed::Internal, amt);
Poll::Ready(Ok(()))
}
}

impl<R: AsyncRead> AsyncBufRead for BufReader<R> {
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
impl<R: AsyncRead> async_buf_read::sealed::AsyncBufReadPriv for BufReader<R> {
fn poll_fill_buf(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
_: async_buf_read::sealed::Internal,
) -> Poll<io::Result<&[u8]>> {
let me = self.project();

// If we've reached the end of our internal buffer then we need to fetch
Expand All @@ -135,12 +143,14 @@ impl<R: AsyncRead> AsyncBufRead for BufReader<R> {
Poll::Ready(Ok(&me.buf[*me.pos..*me.cap]))
}

fn consume(self: Pin<&mut Self>, amt: usize) {
fn consume(self: Pin<&mut Self>, _: async_buf_read::sealed::Internal, amt: usize) {
let me = self.project();
*me.pos = cmp::min(*me.pos + amt, *me.cap);
}
}

impl<R: AsyncRead> AsyncBufRead for BufReader<R> {}

impl<R: AsyncRead + AsyncWrite> AsyncWrite for BufReader<R> {
fn poll_write(
self: Pin<&mut Self>,
Expand Down
18 changes: 12 additions & 6 deletions tokio/src/io/util/buf_stream.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::io::util::{BufReader, BufWriter};
use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
use crate::io::{async_buf_read, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};

use pin_project_lite::pin_project;
use std::io;
Expand Down Expand Up @@ -142,16 +142,22 @@ impl<RW: AsyncRead + AsyncWrite> AsyncRead for BufStream<RW> {
}
}

impl<RW: AsyncRead + AsyncWrite> AsyncBufRead for BufStream<RW> {
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
self.project().inner.poll_fill_buf(cx)
impl<RW: AsyncRead + AsyncWrite> async_buf_read::sealed::AsyncBufReadPriv for BufStream<RW> {
fn poll_fill_buf(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
internal: async_buf_read::sealed::Internal,
) -> Poll<io::Result<&[u8]>> {
self.project().inner.poll_fill_buf(cx, internal)
}

fn consume(self: Pin<&mut Self>, amt: usize) {
self.project().inner.consume(amt)
fn consume(self: Pin<&mut Self>, internal: async_buf_read::sealed::Internal, amt: usize) {
self.project().inner.consume(internal, amt)
}
}

impl<RW: AsyncRead + AsyncWrite> AsyncBufRead for BufStream<RW> {}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
18 changes: 12 additions & 6 deletions tokio/src/io/util/buf_writer.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::io::util::DEFAULT_BUF_SIZE;
use crate::io::{AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
use crate::io::{async_buf_read, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};

use pin_project_lite::pin_project;
use std::fmt;
Expand Down Expand Up @@ -152,16 +152,22 @@ impl<W: AsyncWrite + AsyncRead> AsyncRead for BufWriter<W> {
}
}

impl<W: AsyncWrite + AsyncBufRead> AsyncBufRead for BufWriter<W> {
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
self.get_pin_mut().poll_fill_buf(cx)
impl<W: AsyncWrite + AsyncBufRead> async_buf_read::sealed::AsyncBufReadPriv for BufWriter<W> {
fn poll_fill_buf(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
internal: async_buf_read::sealed::Internal,
) -> Poll<io::Result<&[u8]>> {
self.get_pin_mut().poll_fill_buf(cx, internal)
}

fn consume(self: Pin<&mut Self>, amt: usize) {
self.get_pin_mut().consume(amt)
fn consume(self: Pin<&mut Self>, internal: async_buf_read::sealed::Internal, amt: usize) {
self.get_pin_mut().consume(internal, amt)
}
}

impl<W: AsyncWrite + AsyncBufRead> AsyncBufRead for BufWriter<W> {}

impl<W: fmt::Debug> fmt::Debug for BufWriter<W> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("BufWriter")
Expand Down
31 changes: 23 additions & 8 deletions tokio/src/io/util/chain.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::io::{AsyncBufRead, AsyncRead, ReadBuf};
use crate::io::{async_buf_read, AsyncBufRead, AsyncRead, ReadBuf};

use pin_project_lite::pin_project;
use std::fmt;
Expand Down Expand Up @@ -104,35 +104,50 @@ where
}
}

impl<T, U> AsyncBufRead for Chain<T, U>
impl<T, U> async_buf_read::sealed::AsyncBufReadPriv for Chain<T, U>
where
T: AsyncBufRead,
U: AsyncBufRead,
{
fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
fn poll_fill_buf(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
_: async_buf_read::sealed::Internal,
) -> Poll<io::Result<&[u8]>> {
let me = self.project();

if !*me.done_first {
match ready!(me.first.poll_fill_buf(cx)?) {
match ready!(me
.first
.poll_fill_buf(cx, async_buf_read::sealed::Internal)?)
{
buf if buf.is_empty() => {
*me.done_first = true;
}
buf => return Poll::Ready(Ok(buf)),
}
}
me.second.poll_fill_buf(cx)
me.second
.poll_fill_buf(cx, async_buf_read::sealed::Internal)
}

fn consume(self: Pin<&mut Self>, amt: usize) {
fn consume(self: Pin<&mut Self>, internal: async_buf_read::sealed::Internal, amt: usize) {
let me = self.project();
if !*me.done_first {
me.first.consume(amt)
me.first.consume(internal, amt)
} else {
me.second.consume(amt)
me.second.consume(internal, amt)
}
}
}

impl<T, U> AsyncBufRead for Chain<T, U>
where
T: AsyncBufRead,
U: AsyncBufRead,
{
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
Loading