Skip to content

Commit

Permalink
implement guarded flush for BufStream
Browse files Browse the repository at this point in the history
closes #122
  • Loading branch information
abonander committed Mar 6, 2020
1 parent 06e184c commit 15369c9
Showing 1 changed file with 42 additions and 7 deletions.
49 changes: 42 additions & 7 deletions sqlx-core/src/io/buf_stream.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
use std::io;
use std::future::Future;
use std::io::{self, BufRead};
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::task::{Context, Poll};

use futures_util::ready;

use crate::runtime::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

Expand All @@ -20,6 +25,11 @@ pub struct BufStream<S> {
rbuf_windex: usize,
}

pub struct GuardedFlush<'a, S: 'a> {
stream: &'a mut S,
buf: io::Cursor<&'a mut Vec<u8>>,
}

impl<S> BufStream<S>
where
S: AsyncRead + AsyncWrite + Unpin,
Expand All @@ -46,13 +56,12 @@ where
}

#[inline]
pub async fn flush(&mut self) -> io::Result<()> {
if !self.wbuf.is_empty() {
self.stream.write_all(&self.wbuf).await?;
self.wbuf.clear();
#[must_use = "write buffer is cleared on-drop even if future is not polled"]
pub fn flush(&mut self) -> GuardedFlush<S> {
GuardedFlush {
stream: &mut self.stream,
buf: io::Cursor::new(&mut self.wbuf),
}

Ok(())
}

#[inline]
Expand Down Expand Up @@ -156,3 +165,29 @@ macro_rules! ret_if_none {
}
};
}

impl<'a, S: AsyncWrite + Unpin> Future for GuardedFlush<'a, S> {
type Output = io::Result<()>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
loop {
let buf = self.buf.fill_buf()?;

if !buf.is_empty() {
let written = ready!(self.stream.poll_write(cx)?);
self.buf.consume(written);
} else {
break;
}
}

self.stream.poll_flush(cx)
}
}

impl<'a, S> Drop for GuardedFlush<'a, S> {
fn drop(&mut self) {
// clear the buffer regardless of whether the flush succeeded or not
self.buf.get_mut().clear();
}
}

0 comments on commit 15369c9

Please sign in to comment.