diff --git a/src/frame.rs b/src/frame.rs index 0e5b2d1..bdd007c 100644 --- a/src/frame.rs +++ b/src/frame.rs @@ -12,6 +12,7 @@ //! LEN is a 16-bit unsigned integer in big endian byte order. //! +use std::cmp::Ordering; use std::io::Result; use std::io::IoSlice; use std::pin::Pin; @@ -28,6 +29,7 @@ macro_rules! ready { }; } +#[derive(Debug)] enum State { Len, Data(u16), @@ -52,6 +54,7 @@ pub struct UotStream { io: T, rd: State, wr: State, + buf: Vec, } impl UotStream { @@ -61,7 +64,8 @@ impl UotStream { Self { io, rd: State::new(), - wr: State::new(), + wr: State::Data(0), + buf: vec![], } } } @@ -88,42 +92,61 @@ where let this = self.get_mut(); loop { + if !this.buf.is_empty() { + buf.put_slice(this.buf.as_slice()); + }; + + let mut total = buf.filled().len(); + if (matches!(this.rd, State::Len) && total < 2) + || matches!(this.rd, State::Data(len) if total < len as usize) + { + let n = ready!(Pin::new(&mut this.io).poll_read(cx, buf)) + .map(|_| buf.filled().len() - total)?; + // EOF + if n == 0 { + return Poll::Ready(Ok(())); + } + total += n; + } + + // make it immutable + let total = total; + + // we can safely clear the buffer + this.buf.clear(); + match this.rd { State::Len => { - let mut len_be = [0u8; 2]; - let mut len_be_buf = ReadBuf::new(&mut len_be); - let mut total = 0; - - while total < 2 { - let n = ready!(Pin::new(&mut this.io).poll_read(cx, &mut len_be_buf)) - .map(|_| len_be_buf.filled().len() - total)?; - // EOF - if n == 0 { - this.rd = State::Fin; - return Poll::Ready(Ok(())); - } - total += n; + if total < 2 { + this.buf.reserve_exact(total); + this.buf.extend(buf.filled()); + buf.clear(); + continue; } - this.rd = State::Data(u16::from_be_bytes(len_be)); - } - State::Data(len) => { - debug_assert!(len as usize <= buf.remaining()); - let mut buf_limit = ReadBuf::new(buf.initialize_unfilled_to(len as usize)); - let n = ready!(Pin::new(&mut this.io).poll_read(cx, &mut buf_limit)) - .map(|_| buf_limit.filled().len())?; + this.rd = + State::Data(u16::from_be_bytes(buf.filled()[..2].try_into().unwrap())); - buf.advance(n); - if n == 0 { - this.rd = State::Fin; + if total > 2 { + this.buf.reserve_exact(buf.filled().len() - 2); + this.buf.extend(&buf.filled()[2..]); + } + buf.clear(); + } + State::Data(len) => match total.cmp(&(len as usize)) { + Ordering::Equal => { + this.rd = State::Len; return Poll::Ready(Ok(())); - } else if n as u16 == len { + } + Ordering::Less => {} + Ordering::Greater => { + this.buf.reserve_exact(buf.filled()[len as usize..].len()); + this.buf.extend(&buf.filled()[len as usize..]); + buf.set_filled(len as usize); this.rd = State::Len; return Poll::Ready(Ok(())); - } else { - this.rd = State::Data(len - n as u16); - }; - } + } + }, State::Fin => return Poll::Ready(Ok(())), } } @@ -139,53 +162,49 @@ where let this = self.get_mut(); - match this.wr { - State::Len => { - let len_be = (buf.len() as u16).to_be_bytes(); - let mut total = 0; - let mut iovec = &mut [IoSlice::new(&len_be), IoSlice::new(buf)][..]; - loop { - let n = ready!(Pin::new(&mut this.io).poll_write_vectored(cx, iovec))?; - total += n; - // write zero + loop { + match this.wr { + State::Len => { + unreachable!(); + } + State::Data(cursor) => { + let n = if cursor < 2 { + let len_be = &(buf.len() as u16).to_be_bytes()[cursor as usize..]; + let iovec = &mut [IoSlice::new(len_be), IoSlice::new(buf)][..]; + ready!(Pin::new(&mut this.io).poll_write_vectored(cx, iovec))? + } else { + ready!(Pin::new(&mut this.io).poll_write(cx, buf))? + }; + + let written_bytes = if cursor < 2 { + if n + cursor as usize > 2 { + n - (2 - cursor) as usize + } else { + 0 + } + } else { + n + }; + if n == 0 { + // EOF this.wr = State::Fin; return Poll::Ready(Ok(0)); } - // write partial len - #[allow(clippy::comparison_chain)] - if total < 2 { - iovec[0] = IoSlice::new(&len_be[total..]); - continue; - } else if total == 2 { - iovec = &mut iovec[1..]; - continue; + + if written_bytes == buf.len() { + this.wr = State::Data(0); + return Poll::Ready(Ok(written_bytes)); } else { - // write len + data - let write_n = total - 2; - if write_n == buf.len() { - this.wr = State::Len; - } else { - this.wr = State::Data((buf.len() - write_n) as u16); + this.wr = State::Data(n as u16 + cursor); + + if written_bytes != 0 { + return Poll::Ready(Ok(written_bytes)); } - return Poll::Ready(Ok(write_n)); } } + State::Fin => return Poll::Ready(Ok(0)), } - State::Data(len) => { - let n = ready!(Pin::new(&mut this.io).poll_write(cx, &buf[..len as usize]))?; - if n == 0 { - this.wr = State::Fin; - Poll::Ready(Ok(0)) - } else if n < len as usize { - this.wr = State::Data(len - n as u16); - Poll::Ready(Ok(n)) - } else { - this.wr = State::Len; - Poll::Ready(Ok(n)) - } - } - State::Fin => Poll::Ready(Ok(0)), } } @@ -290,7 +309,7 @@ mod test { let mut stream = UotStream::new(SlowStream { buf: Vec::with_capacity(65535), rlimit: 0, - wlimit: wlimit, + wlimit, cursor: 0, }); for i in 1..=512 {