diff --git a/src/libstd/io/buffered.rs b/src/libstd/io/buffered.rs index cd7a50d07e268..c15a1c8328c03 100644 --- a/src/libstd/io/buffered.rs +++ b/src/libstd/io/buffered.rs @@ -652,6 +652,7 @@ impl fmt::Display for IntoInnerError { #[stable(feature = "rust1", since = "1.0.0")] pub struct LineWriter { inner: BufWriter, + need_flush: bool, } impl LineWriter { @@ -692,7 +693,10 @@ impl LineWriter { /// ``` #[stable(feature = "rust1", since = "1.0.0")] pub fn with_capacity(cap: usize, inner: W) -> LineWriter { - LineWriter { inner: BufWriter::with_capacity(cap, inner) } + LineWriter { + inner: BufWriter::with_capacity(cap, inner), + need_flush: false, + } } /// Gets a reference to the underlying writer. @@ -759,7 +763,10 @@ impl LineWriter { #[stable(feature = "rust1", since = "1.0.0")] pub fn into_inner(self) -> Result>> { self.inner.into_inner().map_err(|IntoInnerError(buf, e)| { - IntoInnerError(LineWriter { inner: buf }, e) + IntoInnerError(LineWriter { + inner: buf, + need_flush: false, + }, e) }) } } @@ -767,20 +774,46 @@ impl LineWriter { #[stable(feature = "rust1", since = "1.0.0")] impl Write for LineWriter { fn write(&mut self, buf: &[u8]) -> io::Result { - match memchr::memrchr(b'\n', buf) { - Some(i) => { - let n = self.inner.write(&buf[..i + 1])?; - if n != i + 1 || self.inner.flush().is_err() { - // Do not return errors on partial writes. - return Ok(n); - } - self.inner.write(&buf[i + 1..]).map(|i| n + i) - } - None => self.inner.write(buf), + if self.need_flush { + self.flush()?; + } + + // Find the last newline character in the buffer provided. If found then + // we're going to write all the data up to that point and then flush, + // otherewise we just write the whole block to the underlying writer. + let i = match memchr::memrchr(b'\n', buf) { + Some(i) => i, + None => return self.inner.write(buf), + }; + + + // Ok, we're going to write a partial amount of the data given first + // followed by flushing the newline. After we've successfully written + // some data then we *must* report that we wrote that data, so future + // errors are ignored. We set our internal `need_flush` flag, though, in + // case flushing fails and we need to try it first next time. + let n = self.inner.write(&buf[..i + 1])?; + self.need_flush = true; + if self.flush().is_err() || n != i + 1 { + return Ok(n) + } + + // At this point we successfully wrote `i + 1` bytes and flushed it out, + // meaning that the entire line is now flushed out on the screen. While + // we can attempt to finish writing the rest of the data provided. + // Remember though that we ignore errors here as we've successfully + // written data, so we need to report that. + match self.inner.write(&buf[i + 1..]) { + Ok(i) => Ok(n + i), + Err(_) => Ok(n), } } - fn flush(&mut self) -> io::Result<()> { self.inner.flush() } + fn flush(&mut self) -> io::Result<()> { + self.inner.flush()?; + self.need_flush = false; + Ok(()) + } } #[stable(feature = "rust1", since = "1.0.0")] @@ -1153,4 +1186,44 @@ mod tests { BufWriter::new(io::sink()) }); } + + struct AcceptOneThenFail { + written: bool, + flushed: bool, + } + + impl Write for AcceptOneThenFail { + fn write(&mut self, data: &[u8]) -> io::Result { + if !self.written { + assert_eq!(data, b"a\nb\n"); + self.written = true; + Ok(data.len()) + } else { + Err(io::Error::new(io::ErrorKind::NotFound, "test")) + } + } + + fn flush(&mut self) -> io::Result<()> { + assert!(self.written); + assert!(!self.flushed); + self.flushed = true; + Err(io::Error::new(io::ErrorKind::Other, "test")) + } + } + + #[test] + fn erroneous_flush_retried() { + let a = AcceptOneThenFail { + written: false, + flushed: false, + }; + + let mut l = LineWriter::new(a); + assert_eq!(l.write(b"a\nb\na").unwrap(), 4); + assert!(l.get_ref().written); + assert!(l.get_ref().flushed); + l.get_mut().flushed = false; + + assert_eq!(l.write(b"a").unwrap_err().kind(), io::ErrorKind::Other) + } }