Skip to content
This repository has been archived by the owner on Feb 4, 2021. It is now read-only.

Commit

Permalink
fix(codec): Properly decode partial DATA frames (hyperium#83)
Browse files Browse the repository at this point in the history
  • Loading branch information
Phillip Cloud authored and rabbitinspace committed Jan 1, 2020
1 parent 5d0a795 commit 38dc8ab
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 10 deletions.
4 changes: 3 additions & 1 deletion tonic/src/codec/decode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,9 @@ impl<T> Streaming<T> {
}

if let State::ReadBody { len, .. } = &self.state {
if buf.remaining() < *len {
// if we haven't read enough of the message then return and keep
// reading
if buf.remaining() < *len || self.buf.len() < *len + 5 {
return Ok(None);
}

Expand Down
51 changes: 42 additions & 9 deletions tonic/src/codec/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ struct Msg {
async fn decode() {
let decoder = ProstDecoder::<Msg>::default();

let data = Vec::from(&[0u8; 1024][..]);
let data = vec![0u8; 10000];
let data_len = data.len();
let msg = Msg { data };

let mut buf = BytesMut::new();
Expand All @@ -34,11 +35,20 @@ async fn decode() {
buf.put_u32_be(len as u32);
msg.encode(&mut buf).unwrap();

let body = MockBody(buf.freeze(), 0, 100);
let body = MockBody {
data: buf.freeze(),
partial_len: 10005,
count: 0,
};

let mut stream = Streaming::new_request(decoder, body);

while let Some(_) = stream.message().await.unwrap() {}
let mut i = 0usize;
while let Some(msg) = stream.message().await.unwrap() {
assert_eq!(msg.data.len(), data_len);
i += 1;
}
assert_eq!(i, 1);
}

#[tokio::test]
Expand All @@ -61,20 +71,43 @@ async fn encode() {
}

#[derive(Debug)]
struct MockBody(Bytes, usize, usize);
struct MockBody {
data: Bytes,

// the size of the partial message to send
partial_len: usize,

// the number of times we've sent
count: usize,
}

impl Body for MockBody {
type Data = Data;
type Error = Status;

fn poll_data(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Self::Data, Self::Error>>> {
if self.1 > self.2 {
self.1 += 1;
let data = Data(self.0.clone().into_buf());
Poll::Ready(Some(Ok(data)))
// every other call to poll_data returns data
let should_send = self.count % 2 == 0;
let data_len = self.data.len();
let partial_len = self.partial_len;
let count = self.count;
if data_len > 0 {
let result = if should_send {
let response = self
.data
.split_to(if count == 0 { partial_len } else { data_len })
.into_buf();
Poll::Ready(Some(Ok(Data(response))))
} else {
cx.waker().wake_by_ref();
Poll::Pending
};
// make some fake progress
self.count += 1;
result
} else {
Poll::Ready(None)
}
Expand Down

0 comments on commit 38dc8ab

Please sign in to comment.