Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: handle edge cases in frame read and write #1

Merged
merged 1 commit into from
May 14, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 87 additions & 68 deletions src/frame.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
//! LEN is a 16-bit unsigned integer in big endian byte order.
//!

use std::cmp::Ordering;
use std::io::Result;
use std::io::IoSlice;
use std::pin::Pin;
Expand All @@ -28,6 +29,7 @@ macro_rules! ready {
};
}

#[derive(Debug)]
enum State {
Len,
Data(u16),
Expand All @@ -52,6 +54,7 @@ pub struct UotStream<T> {
io: T,
rd: State,
wr: State,
buf: Vec<u8>,
}

impl<T> UotStream<T> {
Expand All @@ -61,7 +64,8 @@ impl<T> UotStream<T> {
Self {
io,
rd: State::new(),
wr: State::new(),
wr: State::Data(0),
buf: vec![],
}
}
}
Expand All @@ -88,42 +92,61 @@ where
let this = self.get_mut();

loop {
if !this.buf.is_empty() {
buf.put_slice(this.buf.as_slice());
};

let mut total = buf.filled().len();
if (matches!(this.rd, State::Len) && total < 2)
|| matches!(this.rd, State::Data(len) if total < len as usize)
{
let n = ready!(Pin::new(&mut this.io).poll_read(cx, buf))
.map(|_| buf.filled().len() - total)?;
// EOF
if n == 0 {
return Poll::Ready(Ok(()));
}
total += n;
}

// make it immutable
let total = total;

// we can safely clear the buffer
this.buf.clear();

match this.rd {
State::Len => {
let mut len_be = [0u8; 2];
let mut len_be_buf = ReadBuf::new(&mut len_be);
let mut total = 0;

while total < 2 {
let n = ready!(Pin::new(&mut this.io).poll_read(cx, &mut len_be_buf))
.map(|_| len_be_buf.filled().len() - total)?;
// EOF
if n == 0 {
this.rd = State::Fin;
return Poll::Ready(Ok(()));
}
total += n;
if total < 2 {
this.buf.reserve_exact(total);
this.buf.extend(buf.filled());
buf.clear();
continue;
}

this.rd = State::Data(u16::from_be_bytes(len_be));
}
State::Data(len) => {
debug_assert!(len as usize <= buf.remaining());
let mut buf_limit = ReadBuf::new(buf.initialize_unfilled_to(len as usize));
let n = ready!(Pin::new(&mut this.io).poll_read(cx, &mut buf_limit))
.map(|_| buf_limit.filled().len())?;
this.rd =
State::Data(u16::from_be_bytes(buf.filled()[..2].try_into().unwrap()));

buf.advance(n);
if n == 0 {
this.rd = State::Fin;
if total > 2 {
this.buf.reserve_exact(buf.filled().len() - 2);
this.buf.extend(&buf.filled()[2..]);
}
buf.clear();
}
State::Data(len) => match total.cmp(&(len as usize)) {
Ordering::Equal => {
this.rd = State::Len;
return Poll::Ready(Ok(()));
} else if n as u16 == len {
}
Ordering::Less => {}
Ordering::Greater => {
this.buf.reserve_exact(buf.filled()[len as usize..].len());
this.buf.extend(&buf.filled()[len as usize..]);
buf.set_filled(len as usize);
this.rd = State::Len;
return Poll::Ready(Ok(()));
} else {
this.rd = State::Data(len - n as u16);
};
}
}
},
State::Fin => return Poll::Ready(Ok(())),
}
}
Expand All @@ -139,53 +162,49 @@ where

let this = self.get_mut();

match this.wr {
State::Len => {
let len_be = (buf.len() as u16).to_be_bytes();
let mut total = 0;
let mut iovec = &mut [IoSlice::new(&len_be), IoSlice::new(buf)][..];
loop {
let n = ready!(Pin::new(&mut this.io).poll_write_vectored(cx, iovec))?;
total += n;
// write zero
loop {
match this.wr {
State::Len => {
unreachable!();
}
State::Data(cursor) => {
let n = if cursor < 2 {
let len_be = &(buf.len() as u16).to_be_bytes()[cursor as usize..];
let iovec = &mut [IoSlice::new(len_be), IoSlice::new(buf)][..];
ready!(Pin::new(&mut this.io).poll_write_vectored(cx, iovec))?
} else {
ready!(Pin::new(&mut this.io).poll_write(cx, buf))?
};

let written_bytes = if cursor < 2 {
if n + cursor as usize > 2 {
n - (2 - cursor) as usize
} else {
0
}
} else {
n
};

if n == 0 {
// EOF
this.wr = State::Fin;
return Poll::Ready(Ok(0));
}
// write partial len
#[allow(clippy::comparison_chain)]
if total < 2 {
iovec[0] = IoSlice::new(&len_be[total..]);
continue;
} else if total == 2 {
iovec = &mut iovec[1..];
continue;

if written_bytes == buf.len() {
this.wr = State::Data(0);
return Poll::Ready(Ok(written_bytes));
} else {
// write len + data
let write_n = total - 2;
if write_n == buf.len() {
this.wr = State::Len;
} else {
this.wr = State::Data((buf.len() - write_n) as u16);
this.wr = State::Data(n as u16 + cursor);

if written_bytes != 0 {
return Poll::Ready(Ok(written_bytes));
}
return Poll::Ready(Ok(write_n));
}
}
State::Fin => return Poll::Ready(Ok(0)),
}
State::Data(len) => {
let n = ready!(Pin::new(&mut this.io).poll_write(cx, &buf[..len as usize]))?;
if n == 0 {
this.wr = State::Fin;
Poll::Ready(Ok(0))
} else if n < len as usize {
this.wr = State::Data(len - n as u16);
Poll::Ready(Ok(n))
} else {
this.wr = State::Len;
Poll::Ready(Ok(n))
}
}
State::Fin => Poll::Ready(Ok(0)),
}
}

Expand Down Expand Up @@ -290,7 +309,7 @@ mod test {
let mut stream = UotStream::new(SlowStream {
buf: Vec::with_capacity(65535),
rlimit: 0,
wlimit: wlimit,
wlimit,
cursor: 0,
});
for i in 1..=512 {
Expand Down