diff --git a/ipa-core/src/bin/report_collector.rs b/ipa-core/src/bin/report_collector.rs index 1a0d6755c..9b1d7e84a 100644 --- a/ipa-core/src/bin/report_collector.rs +++ b/ipa-core/src/bin/report_collector.rs @@ -16,8 +16,8 @@ use ipa_core::{ cli::{ playbook::{ make_clients, make_sharded_clients, playbook_oprf_ipa, run_hybrid_query_and_validate, - run_query_and_validate, validate, validate_dp, HybridQueryResult, InputSource, - RoundRobinSubmission, StreamingSubmission, + run_query_and_validate, validate, validate_dp, BufferedRoundRobinSubmission, + HybridQueryResult, InputSource, StreamingSubmission, }, CsvSerializer, IpaQueryResult, Verbosity, }, @@ -370,7 +370,7 @@ fn inputs_from_encrypted_inputs( ] .map(|path| { let file = File::open(path).unwrap_or_else(|e| panic!("unable to open file {path:?}. {e}")); - RoundRobinSubmission::new(BufReader::new(file)) + BufferedRoundRobinSubmission::new(BufReader::new(file)) }) .map(|s| s.into_byte_streams(shard_count)); diff --git a/ipa-core/src/cli/playbook/mod.rs b/ipa-core/src/cli/playbook/mod.rs index fc8b46e9e..46b1645fb 100644 --- a/ipa-core/src/cli/playbook/mod.rs +++ b/ipa-core/src/cli/playbook/mod.rs @@ -22,7 +22,7 @@ use tokio::time::sleep; pub use self::{ hybrid::{run_hybrid_query_and_validate, HybridQueryResult}, ipa::{playbook_oprf_ipa, run_query_and_validate}, - streaming::{RoundRobinSubmission, StreamingSubmission}, + streaming::{BufferedRoundRobinSubmission, StreamingSubmission}, }; use crate::{ cli::config_parse::HelperNetworkConfigParseExt, diff --git a/ipa-core/src/cli/playbook/streaming.rs b/ipa-core/src/cli/playbook/streaming.rs index f246582ee..7ee743aa4 100644 --- a/ipa-core/src/cli/playbook/streaming.rs +++ b/ipa-core/src/cli/playbook/streaming.rs @@ -1,5 +1,6 @@ use std::{ io::BufRead, + num::NonZeroUsize, pin::Pin, task::{Context, Poll, Waker}, }; @@ -9,7 +10,7 @@ use futures::Stream; use crate::{ error::BoxError, - helpers::BytesStream, + helpers::{BufferedBytesStream, BytesStream}, sync::{Arc, Mutex}, }; @@ -20,6 +21,44 @@ pub trait StreamingSubmission { fn into_byte_streams(self, count: usize) -> Vec; } +/// Same as [`RoundRobinSubmission`] but buffers the destination stream +/// until it accumulates at least `buf_size` bytes of data +pub struct BufferedRoundRobinSubmission { + inner: R, + buf_size: NonZeroUsize, +} + +impl BufferedRoundRobinSubmission { + // Standard buffer size for file and network is 8Kb, so we are aligning this value with it. + // Tokio and standard bufer also use 8Kb buffers. + // If other value gives better performance, we should use it instead + const DEFAULT_BUF_SIZE: NonZeroUsize = NonZeroUsize::new(8192).unwrap(); + + /// Create a new instance with the default buffer size. + pub fn new(read_from: R) -> Self { + Self::new_with_buf_size(read_from, Self::DEFAULT_BUF_SIZE) + } + + /// Creates a new instance with the specified buffer size. All streams created + /// using [`StreamingSubmission::into_byte_streams`] will have their own buffer set. + fn new_with_buf_size(read_from: R, buf_size: NonZeroUsize) -> Self { + Self { + inner: read_from, + buf_size, + } + } +} + +impl StreamingSubmission for BufferedRoundRobinSubmission { + fn into_byte_streams(self, count: usize) -> Vec { + RoundRobinSubmission::new(self.inner) + .into_byte_streams(count) + .into_iter() + .map(|s| BufferedBytesStream::new(s, self.buf_size)) + .collect() + } +} + /// Round-Robin strategy to read off the provided buffer /// and distribute them. Inputs is expected to be hex-encoded /// and delimited by newlines. The output streams will have @@ -149,6 +188,7 @@ impl State { #[cfg(all(test, unit_test))] mod tests { use std::{ + collections::HashSet, fs::File, io::{BufReader, Write}, iter, @@ -159,24 +199,98 @@ mod tests { use tempfile::TempDir; use crate::{ - cli::playbook::streaming::{RoundRobinSubmission, StreamingSubmission}, + cli::playbook::streaming::{ + BufferedRoundRobinSubmission, RoundRobinSubmission, StreamingSubmission, + }, helpers::BytesStream, test_executor::run, }; - async fn drain_all(streams: Vec) -> Vec { + async fn drain_all_buffered( + streams: Vec, + buf_size: Option, + ) -> Vec> { let mut futs = FuturesOrdered::default(); for s in streams { - futs.push_back(s.try_fold(String::new(), |mut acc, chunk| async move { - // remove RLE decoding - let len = usize::from(u16::from_le_bytes(chunk[..2].try_into().unwrap())); - assert_eq!(len, chunk.len() - 2); - acc.push_str(&String::from_utf8_lossy(&chunk[2..])); - Ok(acc) - })); + futs.push_back(s.try_fold( + (Vec::new(), HashSet::new(), 0, 0), + |(mut acc, mut sizes, mut leftover, mut pending_len), mut chunk| async move { + // keep track of chunk sizes we've seen from the stream. Only the last chunk + // can have size that is not equal to `buf_size` + sizes.insert(chunk.len()); + + // if we have a leftover from previous buffer, push it first + if leftover > 0 { + let next_chunk = std::cmp::min(leftover, chunk.len()); + leftover -= next_chunk; + acc.extend(&chunk.split_to(next_chunk)); + } + + while !chunk.is_empty() { + // remove RLE decoding + let len = if pending_len > 0 { + // len (2 byte value) can be fragmented as well + let next_byte = + u8::from_le_bytes(chunk.split_to(1).as_ref().try_into().unwrap()); + let r = u16::from_le_bytes([pending_len, next_byte]); + pending_len = 0; + r + } else if chunk.len() > 1 { + let len = + u16::from_le_bytes(chunk.split_to(2).as_ref().try_into().unwrap()); + len + } else { + pending_len = + u8::from_le_bytes(chunk.split_to(1).as_ref().try_into().unwrap()); + assert!(chunk.is_empty()); + break; + }; + + let len = usize::from(len); + + // the next item may span across multiple buffers + let take_len = if len > chunk.len() { + leftover = len - chunk.len(); + chunk.len() + } else { + len + }; + acc.extend(&chunk.split_to(take_len)); + } + + Ok((acc, sizes, leftover, pending_len)) + }, + )); } + futs.try_collect::>() + .await + .unwrap() + .into_iter() + .map(|(s, sizes, leftover, pending_len)| { + assert_eq!(0, leftover); + assert_eq!(0, pending_len); + + // We can have only one chunk that can be at or less than `buf_size`. + // If there are multiple chunks, then at least one must have `buf_size` and there + // can be at most two chunks. + if let Some(buf_size) = buf_size { + assert!(sizes.len() <= 2); + if sizes.len() > 1 { + assert!(sizes.contains(&buf_size)); + } + } - futs.try_collect::>().await.unwrap() + s + }) + .collect() + } + + async fn drain_all(streams: Vec) -> Vec { + drain_all_buffered(streams, None) + .await + .into_iter() + .map(|v| String::from_utf8_lossy(&v).to_string()) + .collect() } fn encoded>>(input: I) -> Vec { @@ -188,6 +302,12 @@ mod tests { run(|| verify_one(vec!["foo", "bar", "baz", "qux", "quux"], 3)); } + #[test] + fn basic_buffered() { + run(|| verify_buffered(vec!["foo", "bar", "baz", "qux", "quux"], 1, 1)); + run(|| verify_buffered(vec!["foo", "bar", "baz", "qux", "quux"], 3, 5)); + } + #[test] #[should_panic(expected = "InvalidHexCharacter")] fn non_hex() { @@ -272,6 +392,24 @@ mod tests { assert_eq!(expected, drain_all(streams).await); } + /// The reason we work with bytes is that string character may span multiple bytes, + /// making [`String::from_utf8`] method work incorrectly as it is not commutative with + /// buffering. + async fn verify_buffered>(input: Vec, count: usize, buf_size: usize) { + assert!(count > 0); + let data = encoded(input.iter().map(AsRef::as_ref)).join("\n"); + let streams = BufferedRoundRobinSubmission::new_with_buf_size( + data.as_bytes(), + buf_size.try_into().unwrap(), + ) + .into_byte_streams(count); + let mut expected: Vec> = vec![vec![]; count]; + for (i, next) in input.into_iter().enumerate() { + expected[i % count].extend(next.as_ref()); + } + assert_eq!(expected, drain_all_buffered(streams, Some(buf_size)).await); + } + proptest! { #[test] fn proptest_round_robin(input: Vec, count in 1_usize..953) { @@ -279,5 +417,12 @@ mod tests { verify_one(input, count).await; }); } + + #[test] + fn proptest_round_robin_buffered(input: Vec>, count in 1_usize..953, buf_size in 1_usize..1024) { + run(move || async move { + verify_buffered(input, count, buf_size).await; + }); + } } } diff --git a/ipa-core/src/helpers/mod.rs b/ipa-core/src/helpers/mod.rs index 2b8e27868..eae6d8439 100644 --- a/ipa-core/src/helpers/mod.rs +++ b/ipa-core/src/helpers/mod.rs @@ -75,11 +75,11 @@ pub use transport::{ InMemoryTransportError, }; pub use transport::{ - make_owned_handler, query, routing, ApiError, BodyStream, BroadcastError, BytesStream, - HandlerBox, HandlerRef, HelperResponse, Identity as TransportIdentity, LengthDelimitedStream, - LogErrors, NoQueryId, NoResourceIdentifier, NoStep, QueryIdBinding, ReceiveRecords, - RecordsStream, RequestHandler, RouteParams, SingleRecordStream, StepBinding, StreamCollection, - StreamKey, Transport, WrappedBoxBodyStream, + make_owned_handler, query, routing, ApiError, BodyStream, BroadcastError, BufferedBytesStream, + BytesStream, HandlerBox, HandlerRef, HelperResponse, Identity as TransportIdentity, + LengthDelimitedStream, LogErrors, NoQueryId, NoResourceIdentifier, NoStep, QueryIdBinding, + ReceiveRecords, RecordsStream, RequestHandler, RouteParams, SingleRecordStream, StepBinding, + StreamCollection, StreamKey, Transport, WrappedBoxBodyStream, }; use typenum::{Const, ToUInt, Unsigned, U8}; use x25519_dalek::PublicKey; diff --git a/ipa-core/src/helpers/transport/mod.rs b/ipa-core/src/helpers/transport/mod.rs index 00021c62c..84289d946 100644 --- a/ipa-core/src/helpers/transport/mod.rs +++ b/ipa-core/src/helpers/transport/mod.rs @@ -35,8 +35,8 @@ pub use receive::{LogErrors, ReceiveRecords}; #[cfg(feature = "web-app")] pub use stream::WrappedAxumBodyStream; pub use stream::{ - BodyStream, BytesStream, LengthDelimitedStream, RecordsStream, SingleRecordStream, - StreamCollection, StreamKey, WrappedBoxBodyStream, + BodyStream, BufferedBytesStream, BytesStream, LengthDelimitedStream, RecordsStream, + SingleRecordStream, StreamCollection, StreamKey, WrappedBoxBodyStream, }; /// An identity of a peer that can be communicated with using [`Transport`]. There are currently two diff --git a/ipa-core/src/helpers/transport/stream/buffered.rs b/ipa-core/src/helpers/transport/stream/buffered.rs index 7efc12112..b0b332fb1 100644 --- a/ipa-core/src/helpers/transport/stream/buffered.rs +++ b/ipa-core/src/helpers/transport/stream/buffered.rs @@ -6,7 +6,7 @@ use std::{ }; use bytes::Bytes; -use futures::Stream; +use futures::{stream::Fuse, Stream, StreamExt}; use pin_project::pin_project; use crate::helpers::BytesStream; @@ -19,7 +19,7 @@ use crate::helpers::BytesStream; pub struct BufferedBytesStream { /// Inner stream to poll #[pin] - inner: S, + inner: Fuse, /// Buffer of bytes pending release buffer: Vec, /// Number of bytes released per single poll. @@ -28,10 +28,10 @@ pub struct BufferedBytesStream { sz: usize, } -impl BufferedBytesStream { - fn new(inner: S, buf_size: NonZeroUsize) -> Self { +impl BufferedBytesStream { + pub fn new(inner: S, buf_size: NonZeroUsize) -> Self { Self { - inner, + inner: inner.fuse(), buffer: Vec::with_capacity(buf_size.get()), sz: buf_size.get(), } diff --git a/ipa-core/src/helpers/transport/stream/mod.rs b/ipa-core/src/helpers/transport/stream/mod.rs index ac39fbb22..9e6e32a63 100644 --- a/ipa-core/src/helpers/transport/stream/mod.rs +++ b/ipa-core/src/helpers/transport/stream/mod.rs @@ -1,7 +1,6 @@ #[cfg(feature = "web-app")] mod axum_body; mod box_body; -#[allow(dead_code)] mod buffered; mod collection; mod input; @@ -14,6 +13,7 @@ use std::{ #[cfg(feature = "web-app")] pub use axum_body::WrappedAxumBodyStream; pub use box_body::WrappedBoxBodyStream; +pub use buffered::BufferedBytesStream; use bytes::Bytes; pub use collection::{StreamCollection, StreamKey}; use futures::{stream::iter, Stream};