Skip to content

Commit

Permalink
Extract SyncWriteAdapter from write_io() body
Browse files Browse the repository at this point in the history
  • Loading branch information
ctz authored and djc committed Mar 5, 2024
1 parent 6f18143 commit 096b161
Showing 1 changed file with 39 additions and 35 deletions.
74 changes: 39 additions & 35 deletions src/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,41 +118,7 @@ where
}

pub fn write_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
struct Writer<'a, 'b, T> {
io: &'a mut T,
cx: &'a mut Context<'b>,
}

impl<'a, 'b, T: Unpin> Writer<'a, 'b, T> {
#[inline]
fn poll_with<U>(
&mut self,
f: impl FnOnce(Pin<&mut T>, &mut Context<'_>) -> Poll<io::Result<U>>,
) -> io::Result<U> {
match f(Pin::new(self.io), self.cx) {
Poll::Ready(result) => result,
Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
}
}
}

impl<'a, 'b, T: AsyncWrite + Unpin> Write for Writer<'a, 'b, T> {
#[inline]
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.poll_with(|io, cx| io.poll_write(cx, buf))
}

#[inline]
fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
self.poll_with(|io, cx| io.poll_write_vectored(cx, bufs))
}

fn flush(&mut self) -> io::Result<()> {
self.poll_with(|io, cx| io.poll_flush(cx))
}
}

let mut writer = Writer { io: self.io, cx };
let mut writer = SyncWriteAdapter { io: self.io, cx };

match self.session.write_tls(&mut writer) {
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
Expand Down Expand Up @@ -360,5 +326,43 @@ impl<'a, 'b, T: AsyncRead + Unpin> Read for SyncReadAdapter<'a, 'b, T> {
}
}

/// An adapter that implements a [`Write`] interface for [`AsyncWrite`] types and an
/// associated [`Context`].
///
/// Turns `Poll::Pending` into `WouldBlock`.
pub struct SyncWriteAdapter<'a, 'b, T> {
pub io: &'a mut T,
pub cx: &'a mut Context<'b>,
}

impl<'a, 'b, T: Unpin> SyncWriteAdapter<'a, 'b, T> {
#[inline]
fn poll_with<U>(
&mut self,
f: impl FnOnce(Pin<&mut T>, &mut Context<'_>) -> Poll<io::Result<U>>,
) -> io::Result<U> {
match f(Pin::new(self.io), self.cx) {
Poll::Ready(result) => result,
Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
}
}
}

impl<'a, 'b, T: AsyncWrite + Unpin> Write for SyncWriteAdapter<'a, 'b, T> {
#[inline]
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.poll_with(|io, cx| io.poll_write(cx, buf))
}

#[inline]
fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
self.poll_with(|io, cx| io.poll_write_vectored(cx, bufs))
}

fn flush(&mut self) -> io::Result<()> {
self.poll_with(|io, cx| io.poll_flush(cx))
}
}

#[cfg(test)]
mod test_stream;

0 comments on commit 096b161

Please sign in to comment.