diff --git a/src/common/mod.rs b/src/common/mod.rs index 664f8f93..442e295d 100644 --- a/src/common/mod.rs +++ b/src/common/mod.rs @@ -118,41 +118,7 @@ where } pub fn write_io(&mut self, cx: &mut Context) -> Poll> { - struct Writer<'a, 'b, T> { - io: &'a mut T, - cx: &'a mut Context<'b>, - } - - impl<'a, 'b, T: Unpin> Writer<'a, 'b, T> { - #[inline] - fn poll_with( - &mut self, - f: impl FnOnce(Pin<&mut T>, &mut Context<'_>) -> Poll>, - ) -> io::Result { - match f(Pin::new(self.io), self.cx) { - Poll::Ready(result) => result, - Poll::Pending => Err(io::ErrorKind::WouldBlock.into()), - } - } - } - - impl<'a, 'b, T: AsyncWrite + Unpin> Write for Writer<'a, 'b, T> { - #[inline] - fn write(&mut self, buf: &[u8]) -> io::Result { - self.poll_with(|io, cx| io.poll_write(cx, buf)) - } - - #[inline] - fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { - self.poll_with(|io, cx| io.poll_write_vectored(cx, bufs)) - } - - fn flush(&mut self) -> io::Result<()> { - self.poll_with(|io, cx| io.poll_flush(cx)) - } - } - - let mut writer = Writer { io: self.io, cx }; + let mut writer = SyncWriteAdapter { io: self.io, cx }; match self.session.write_tls(&mut writer) { Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending, @@ -360,5 +326,43 @@ impl<'a, 'b, T: AsyncRead + Unpin> Read for SyncReadAdapter<'a, 'b, T> { } } +/// An adapter that implements a [`Write`] interface for [`AsyncWrite`] types and an +/// associated [`Context`]. +/// +/// Turns `Poll::Pending` into `WouldBlock`. +pub struct SyncWriteAdapter<'a, 'b, T> { + pub io: &'a mut T, + pub cx: &'a mut Context<'b>, +} + +impl<'a, 'b, T: Unpin> SyncWriteAdapter<'a, 'b, T> { + #[inline] + fn poll_with( + &mut self, + f: impl FnOnce(Pin<&mut T>, &mut Context<'_>) -> Poll>, + ) -> io::Result { + match f(Pin::new(self.io), self.cx) { + Poll::Ready(result) => result, + Poll::Pending => Err(io::ErrorKind::WouldBlock.into()), + } + } +} + +impl<'a, 'b, T: AsyncWrite + Unpin> Write for SyncWriteAdapter<'a, 'b, T> { + #[inline] + fn write(&mut self, buf: &[u8]) -> io::Result { + self.poll_with(|io, cx| io.poll_write(cx, buf)) + } + + #[inline] + fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result { + self.poll_with(|io, cx| io.poll_write_vectored(cx, bufs)) + } + + fn flush(&mut self) -> io::Result<()> { + self.poll_with(|io, cx| io.poll_flush(cx)) + } +} + #[cfg(test)] mod test_stream;