diff --git a/tonic/src/codec/decode.rs b/tonic/src/codec/decode.rs index 413afdb45..153a9b6da 100644 --- a/tonic/src/codec/decode.rs +++ b/tonic/src/codec/decode.rs @@ -187,7 +187,9 @@ impl Streaming { } if let State::ReadBody { len, .. } = &self.state { - if buf.remaining() < *len { + // if we haven't read enough of the message then return and keep + // reading + if buf.remaining() < *len || self.buf.len() < *len + 5 { return Ok(None); } diff --git a/tonic/src/codec/tests.rs b/tonic/src/codec/tests.rs index a61d063ab..ac2ca8c2e 100644 --- a/tonic/src/codec/tests.rs +++ b/tonic/src/codec/tests.rs @@ -23,7 +23,8 @@ struct Msg { async fn decode() { let decoder = ProstDecoder::::default(); - let data = Vec::from(&[0u8; 1024][..]); + let data = vec![0u8; 10000]; + let data_len = data.len(); let msg = Msg { data }; let mut buf = BytesMut::new(); @@ -34,11 +35,20 @@ async fn decode() { buf.put_u32_be(len as u32); msg.encode(&mut buf).unwrap(); - let body = MockBody(buf.freeze(), 0, 100); + let body = MockBody { + data: buf.freeze(), + partial_len: 10005, + count: 0, + }; let mut stream = Streaming::new_request(decoder, body); - while let Some(_) = stream.message().await.unwrap() {} + let mut i = 0usize; + while let Some(msg) = stream.message().await.unwrap() { + assert_eq!(msg.data.len(), data_len); + i += 1; + } + assert_eq!(i, 1); } #[tokio::test] @@ -61,7 +71,15 @@ async fn encode() { } #[derive(Debug)] -struct MockBody(Bytes, usize, usize); +struct MockBody { + data: Bytes, + + // the size of the partial message to send + partial_len: usize, + + // the number of times we've sent + count: usize, +} impl Body for MockBody { type Data = Data; @@ -69,12 +87,27 @@ impl Body for MockBody { fn poll_data( mut self: Pin<&mut Self>, - _cx: &mut Context<'_>, + cx: &mut Context<'_>, ) -> Poll>> { - if self.1 > self.2 { - self.1 += 1; - let data = Data(self.0.clone().into_buf()); - Poll::Ready(Some(Ok(data))) + // every other call to poll_data returns data + let should_send = self.count % 2 == 0; + let data_len = self.data.len(); + let partial_len = self.partial_len; + let count = self.count; + if data_len > 0 { + let result = if should_send { + let response = self + .data + .split_to(if count == 0 { partial_len } else { data_len }) + .into_buf(); + Poll::Ready(Some(Ok(Data(response)))) + } else { + cx.waker().wake_by_ref(); + Poll::Pending + }; + // make some fake progress + self.count += 1; + result } else { Poll::Ready(None) }