From 283deba789f2d3574edafeafe4937e32cbf86385 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Mon, 16 Dec 2024 16:10:05 -0800 Subject: [PATCH 1/3] Create a buffered wrapper around BytesStream The need for it is driven by the behavior we're observing from Report Collector sending bytes down to individual shards. It writes data as it becomes available and Hyper does not accumulate it before sending. On the receiver side we are seeing chunks of size 1 received and that creates thrashing on sender/receiver side. This change paves the path to use buffering on RC side. --- .../src/helpers/transport/stream/buffered.rs | 313 ++++++++++++++++++ ipa-core/src/helpers/transport/stream/mod.rs | 2 + 2 files changed, 315 insertions(+) create mode 100644 ipa-core/src/helpers/transport/stream/buffered.rs diff --git a/ipa-core/src/helpers/transport/stream/buffered.rs b/ipa-core/src/helpers/transport/stream/buffered.rs new file mode 100644 index 000000000..2bbe4a14e --- /dev/null +++ b/ipa-core/src/helpers/transport/stream/buffered.rs @@ -0,0 +1,313 @@ +use std::{ + mem, + num::NonZeroUsize, + pin::Pin, + task::{Context, Poll}, +}; + +use bytes::Bytes; +use futures::Stream; +use pin_project::pin_project; + +use crate::helpers::BytesStream; + +/// An adaptor to buffer items coming from the upstream +/// [`BytesStream`](BytesStream) until the buffer is full, or the upstream is +/// done. This may need to be used when writing into HTTP streams as Hyper +/// does not provide any buffering functionality and we turn NODELAY on +#[pin_project] +pub struct BufferedBytesStream { + /// Inner stream to poll + #[pin] + inner: S, + /// Buffer of bytes pending release + buffer: Vec, + /// Number of bytes released per single poll. + /// All items except the last one are guaranteed to have + /// exactly this number of bytes written to them. + sz: usize, +} + +impl BufferedBytesStream { + fn new(inner: S, buf_size: NonZeroUsize) -> Self { + Self { + inner, + buffer: Vec::with_capacity(buf_size.get()), + sz: buf_size.get(), + } + } +} + +impl Stream for BufferedBytesStream { + type Item = S::Item; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + fn take_next(buf: &mut Vec) -> Vec { + mem::replace(buf, Vec::with_capacity(buf.len())) + } + + let mut this = self.as_mut().project(); + loop { + // If we are at capacity, return what we have + if this.buffer.len() >= *this.sz { + // if we have more than we need in the buffer, split it + // otherwise, return the whole buffer to the reader + let next = if this.buffer.len() > *this.sz { + this.buffer.drain(..*this.sz).collect() + } else { + take_next(this.buffer) + }; + break Poll::Ready(Some(Ok(Bytes::from(next)))); + } + + match this.inner.as_mut().poll_next(cx) { + Poll::Ready(Some(item)) => { + // Received next portion of data, buffer it + match item { + Ok(bytes) => { + this.buffer.extend(bytes); + } + Err(e) => { + break Poll::Ready(Some(Err(e))); + } + } + } + Poll::Ready(None) => { + // yield what we have because the upstream is done + let next = if this.buffer.is_empty() { + None + } else { + Some(Ok(Bytes::from(take_next(this.buffer)))) + }; + + break Poll::Ready(next); + } + Poll::Pending => { + // we don't have enough data in the buffer (otherwise we wouldn't be here) + break Poll::Pending; + } + } + } + } +} + +#[cfg(all(test, unit_test))] +mod tests { + use std::{ + cmp::min, + mem, + num::NonZeroUsize, + pin::Pin, + sync::{Arc, Mutex}, + task::Poll, + }; + + use bytes::Bytes; + use futures::{stream::TryStreamExt, FutureExt, Stream, StreamExt}; + use proptest::{ + prop_compose, proptest, + strategy::{Just, Strategy}, + }; + + use crate::{ + error::BoxError, helpers::transport::stream::buffered::BufferedBytesStream, + test_executor::run, + }; + + #[test] + fn success() { + run(|| async move { + verify_success(infallible_stream(11, 2), 3).await; + // verify_success(infallible_stream(12, 3), 3).await; + // verify_success(infallible_stream(12, 5), 12).await; + // verify_success(infallible_stream(12, 12), 12).await; + // verify_success(infallible_stream(24, 12), 12).await; + // verify_success(infallible_stream(24, 12), 1).await; + }); + } + + #[test] + fn fails_on_first_error() { + run(|| async move { + let stream = fallible_stream(12, 3, 5); + let mut buffered = BufferedBytesStream::new(stream, NonZeroUsize::try_from(2).unwrap()); + let mut buf = Vec::new(); + while let Some(next) = buffered.next().await { + match next { + Ok(bytes) => { + assert_eq!(2, bytes.len()); + buf.extend(bytes); + } + Err(_) => { + break; + } + } + } + + // we could only receive 2 bytes from the stream and here is why. + // first read puts 3 bytes into the buffer and we take 2 bytes off it. + // second read does not have sufficient bytes in the buffer, and we need + // to read from the stream again. Next read results in an error and we + // return it immediately + assert_eq!(2, buf.len()); + }); + } + + #[test] + fn pending() { + let status = Arc::new(Mutex::new(vec![1, 2])); + let stream = futures::stream::poll_fn({ + let status = Arc::clone(&status); + move |_cx| { + let mut vec = status.lock().unwrap(); + if vec.is_empty() { + Poll::Pending + } else { + Poll::Ready(Some(Ok(Bytes::from(mem::take(&mut *vec))))) + } + } + }); + + let mut buffered = BufferedBytesStream::new(stream, NonZeroUsize::try_from(4).unwrap()); + let mut fut = std::pin::pin!(buffered.next()); + assert!(fut.as_mut().now_or_never().is_none()); + + status.lock().unwrap().extend([3, 4]); + let actual = fut.now_or_never().flatten().unwrap().unwrap(); + assert_eq!(Bytes::from(vec![1, 2, 3, 4]), actual); + } + + async fn verify_success(input: TestStream, chunk_size: usize) { + let total_size = input.total_size; + assert!(total_size >= chunk_size); + let expected = input.clone(); + let mut buffered = BufferedBytesStream::new(input, chunk_size.try_into().unwrap()); + + let mut last_chunk_size = None; + let mut actual = Vec::new(); + while let Ok(Some(bytes)) = buffered.try_next().await { + assert!(bytes.len() <= chunk_size); + // All chunks except the last one must be exactly of `chunk_size` size. + if let Some(last) = last_chunk_size { + assert_eq!(last, chunk_size); + } + last_chunk_size = Some(bytes.len()); + actual.extend(bytes); + } + + // compare with what the original stream returned + assert_eq!(actual.len(), total_size); + let expected = expected + .try_collect::>() + .await + .unwrap() + .into_iter() + .flatten() + .collect::>(); + assert_eq!(expected, actual); + } + + #[derive(Debug, Clone)] + struct TestStream { + total_size: usize, + remaining: usize, + chunk: usize, + } + + struct FallibleTestStream { + total_size: usize, + remaining: usize, + chunk: usize, + error_after: usize, + } + + fn infallible_stream(total_size: usize, chunk: usize) -> TestStream { + TestStream { + total_size, + remaining: total_size, + chunk, + } + } + + fn fallible_stream(total_size: usize, chunk: usize, error_after: usize) -> FallibleTestStream { + FallibleTestStream { + total_size, + remaining: total_size, + chunk, + error_after, + } + } + + impl Stream for TestStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + if self.remaining == 0 { + return Poll::Ready(None); + } + let next_chunk_size = min(self.remaining, self.chunk); + let next_chunk = (0..next_chunk_size) + .map(|v| u8::try_from(v % 256).unwrap()) + .collect::>(); + + self.remaining -= next_chunk_size; + Poll::Ready(Some(Ok(Bytes::from(next_chunk)))) + } + } + + impl Stream for FallibleTestStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + _cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + if self.remaining == 0 { + return Poll::Ready(None); + } + let next_chunk_size = min(self.remaining, self.chunk); + let next_chunk = (0..next_chunk_size) + .map(|v| u8::try_from(v % 256).unwrap()) + .collect::>(); + + self.remaining -= next_chunk_size; + if self.total_size - self.remaining >= self.error_after { + Poll::Ready(Some(Err("error".into()))) + } else { + Poll::Ready(Some(Ok(Bytes::from(next_chunk)))) + } + } + } + + prop_compose! { + fn arb_infallible_stream(max_size: u16) + (total_size in 1..max_size) + (total_size in Just(total_size), chunk in 1..total_size) + -> TestStream { + TestStream { + total_size: total_size as usize, + remaining: total_size as usize, + chunk: chunk as usize, + } + } + } + + fn stream_and_chunk() -> impl Strategy { + arb_infallible_stream(24231).prop_flat_map(|stream| { + let len = stream.total_size; + (Just(stream), 1..len) + }) + } + + proptest! { + #[test] + fn proptest_success((stream, chunk) in stream_and_chunk()) { + run(move || async move { + verify_success(stream, chunk).await; + }); + } + } +} diff --git a/ipa-core/src/helpers/transport/stream/mod.rs b/ipa-core/src/helpers/transport/stream/mod.rs index 5925b62f5..ac39fbb22 100644 --- a/ipa-core/src/helpers/transport/stream/mod.rs +++ b/ipa-core/src/helpers/transport/stream/mod.rs @@ -1,6 +1,8 @@ #[cfg(feature = "web-app")] mod axum_body; mod box_body; +#[allow(dead_code)] +mod buffered; mod collection; mod input; From 9a0cfc8653822065d67ac1670239b60ffee44341 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Mon, 16 Dec 2024 17:40:21 -0800 Subject: [PATCH 2/3] Uncomment the code --- ipa-core/src/helpers/transport/stream/buffered.rs | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ipa-core/src/helpers/transport/stream/buffered.rs b/ipa-core/src/helpers/transport/stream/buffered.rs index 2bbe4a14e..8f68f916c 100644 --- a/ipa-core/src/helpers/transport/stream/buffered.rs +++ b/ipa-core/src/helpers/transport/stream/buffered.rs @@ -118,11 +118,11 @@ mod tests { fn success() { run(|| async move { verify_success(infallible_stream(11, 2), 3).await; - // verify_success(infallible_stream(12, 3), 3).await; - // verify_success(infallible_stream(12, 5), 12).await; - // verify_success(infallible_stream(12, 12), 12).await; - // verify_success(infallible_stream(24, 12), 12).await; - // verify_success(infallible_stream(24, 12), 1).await; + verify_success(infallible_stream(12, 3), 3).await; + verify_success(infallible_stream(12, 5), 12).await; + verify_success(infallible_stream(12, 12), 12).await; + verify_success(infallible_stream(24, 12), 12).await; + verify_success(infallible_stream(24, 12), 1).await; }); } From 0b8254f3ad04f5f9e84664ef8db3799ad29267c2 Mon Sep 17 00:00:00 2001 From: Alex Koshelev Date: Mon, 16 Dec 2024 18:43:19 -0800 Subject: [PATCH 3/3] Feedback --- .../src/helpers/transport/stream/buffered.rs | 52 +++++++++---------- 1 file changed, 25 insertions(+), 27 deletions(-) diff --git a/ipa-core/src/helpers/transport/stream/buffered.rs b/ipa-core/src/helpers/transport/stream/buffered.rs index 8f68f916c..7efc12112 100644 --- a/ipa-core/src/helpers/transport/stream/buffered.rs +++ b/ipa-core/src/helpers/transport/stream/buffered.rs @@ -99,15 +99,18 @@ mod tests { num::NonZeroUsize, pin::Pin, sync::{Arc, Mutex}, + task, task::Poll, }; use bytes::Bytes; use futures::{stream::TryStreamExt, FutureExt, Stream, StreamExt}; + use pin_project::pin_project; use proptest::{ prop_compose, proptest, strategy::{Just, Strategy}, }; + use task::Context; use crate::{ error::BoxError, helpers::transport::stream::buffered::BufferedBytesStream, @@ -214,10 +217,10 @@ mod tests { chunk: usize, } + #[pin_project] struct FallibleTestStream { - total_size: usize, - remaining: usize, - chunk: usize, + #[pin] + inner: TestStream, error_after: usize, } @@ -231,9 +234,11 @@ mod tests { fn fallible_stream(total_size: usize, chunk: usize, error_after: usize) -> FallibleTestStream { FallibleTestStream { - total_size, - remaining: total_size, - chunk, + inner: TestStream { + total_size, + remaining: total_size, + chunk, + }, error_after, } } @@ -241,10 +246,7 @@ mod tests { impl Stream for TestStream { type Item = Result; - fn poll_next( - mut self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { + fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { if self.remaining == 0 { return Poll::Ready(None); } @@ -261,23 +263,19 @@ mod tests { impl Stream for FallibleTestStream { type Item = Result; - fn poll_next( - mut self: Pin<&mut Self>, - _cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - if self.remaining == 0 { - return Poll::Ready(None); - } - let next_chunk_size = min(self.remaining, self.chunk); - let next_chunk = (0..next_chunk_size) - .map(|v| u8::try_from(v % 256).unwrap()) - .collect::>(); - - self.remaining -= next_chunk_size; - if self.total_size - self.remaining >= self.error_after { - Poll::Ready(Some(Err("error".into()))) - } else { - Poll::Ready(Some(Ok(Bytes::from(next_chunk)))) + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + match this.inner.as_mut().poll_next(cx) { + Poll::Ready(Some(Ok(bytes))) => { + if this.inner.total_size - this.inner.remaining >= *this.error_after { + Poll::Ready(Some(Err("error".into()))) + } else { + Poll::Ready(Some(Ok(bytes))) + } + } + Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))), + Poll::Ready(None) => Poll::Ready(None), + Poll::Pending => Poll::Pending, } } }