diff --git a/analytic_engine/src/sst/parquet/builder.rs b/analytic_engine/src/sst/parquet/builder.rs index 33831fb51b..1b9007af91 100644 --- a/analytic_engine/src/sst/parquet/builder.rs +++ b/analytic_engine/src/sst/parquet/builder.rs @@ -2,12 +2,9 @@ //! Sst builder implementation based on parquet. -use std::{ - collections::VecDeque, - sync::{ - atomic::{AtomicUsize, Ordering}, - Arc, - }, +use std::sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, }; use async_trait::async_trait; @@ -67,80 +64,72 @@ impl RecordBytesReader { // Partition record batch stream into batch vector with exactly given // `num_rows_per_row_group` async fn partition_record_batch(&mut self) -> Result<()> { - let mut fetched_row_num = 0; - let mut pending_record_batch: VecDeque = Default::default(); - let mut current_batch = Vec::new(); - let mut remaining = self.num_rows_per_row_group; // how many records are left for current_batch - - while let Some(record_batch) = self.record_stream.next().await { - let record_batch = record_batch.context(PollRecordBatch)?; - - debug_assert!( - !record_batch.is_empty(), - "found empty record batch, request id:{}", - self.request_id - ); - - fetched_row_num += record_batch.num_rows(); - pending_record_batch.push_back(record_batch); - - // reach batch limit, append to self and reset counter and pending batch - // Note: pending_record_batch may contains multiple batches - while fetched_row_num >= self.num_rows_per_row_group { - match pending_record_batch.pop_front() { - // accumulated records is enough for one batch - Some(next) if next.num_rows() >= remaining => { - debug!( - "enough target:{}, remaining:{}, fetched:{}, current:{}", - self.num_rows_per_row_group, - remaining, - fetched_row_num, - next.num_rows(), - ); - current_batch.push(next.slice(0, remaining)); - let left = next.num_rows() - remaining; - if left > 0 { - pending_record_batch.push_front(next.slice(remaining, left)); - } + let mut prev_record_batch: Option = None; - self.partitioned_record_batch - .push(std::mem::take(&mut current_batch)); - fetched_row_num -= self.num_rows_per_row_group; - remaining = self.num_rows_per_row_group; - } - // not enough for one batch - Some(next) => { - debug!( - "not enough target:{}, remaining:{}, fetched:{}, current:{}", - self.num_rows_per_row_group, - remaining, - fetched_row_num, - next.num_rows(), - ); - remaining -= next.num_rows(); - current_batch.push(next); - } - // nothing left, put back to pending_record_batch - _ => { - for records in std::mem::take(&mut current_batch) { - pending_record_batch.push_front(records); - } - break; - } - } + loop { + let row_group = self.fetch_next_row_group(&mut prev_record_batch).await?; + if row_group.is_empty() { + break; } + self.partitioned_record_batch.push(row_group); } - // collect remaining records into one batch - let mut remaining = Vec::with_capacity(pending_record_batch.len()); - while let Some(batch) = pending_record_batch.pop_front() { - remaining.push(batch); - } - if !remaining.is_empty() { - self.partitioned_record_batch.push(remaining); + Ok(()) + } + + /// Fetch an integral row group from the `self.record_stream`. + /// + /// Except the last one, every row group is ensured to contains exactly + /// `self.num_rows_per_row_group`. As for the last one, it will cover all + /// the left rows. + async fn fetch_next_row_group( + &mut self, + prev_record_batch: &mut Option, + ) -> Result> { + let mut curr_row_group = vec![]; + // Used to record the number of remaining rows to fill `curr_row_group`. + let mut remaining = self.num_rows_per_row_group; + + // Keep filling `curr_row_group` until `remaining` is zero. + while remaining > 0 { + // Use the `prev_record_batch` to fill `curr_row_group` if possible. + if let Some(v) = prev_record_batch { + let total_rows = v.num_rows(); + if total_rows <= remaining { + // The whole record batch is part of the `curr_row_group`, and let's feed it + // into `curr_row_group`. + curr_row_group.push(prev_record_batch.take().unwrap()); + remaining -= total_rows; + } else { + // Only first `remaining` rows of the record batch belongs to `curr_row_group`, + // the rest should be put to `prev_record_batch` for next row group. + curr_row_group.push(v.slice(0, remaining)); + *v = v.slice(remaining, total_rows - remaining); + remaining = 0; + } + + continue; + } + + // Previous record batch has been exhausted, and let's fetch next record batch. + match self.record_stream.next().await { + Some(v) => { + let v = v.context(PollRecordBatch)?; + debug_assert!( + !v.is_empty(), + "found empty record batch, request id:{}", + self.request_id + ); + + // Updated the exhausted `prev_record_batch`, and let next loop to continue to + // fill `curr_row_group`. + prev_record_batch.replace(v); + } + None => break, + }; } - Ok(()) + Ok(curr_row_group) } fn build_bloom_filter(&self) -> BloomFilter { @@ -339,15 +328,12 @@ mod tests { bloom_filter: Default::default(), }; - let mut counter = 10; - let record_batch_stream = Box::new(stream::poll_fn(move |ctx| -> Poll> { - counter -= 1; + let mut counter = 5; + let record_batch_stream = Box::new(stream::poll_fn(move |_| -> Poll> { if counter == 0 { return Poll::Ready(None); - } else if counter % 2 == 0 { - ctx.waker().wake_by_ref(); - return Poll::Pending; } + counter -= 1; // reach here when counter is 9 7 5 3 1 let ts = 100 + counter; @@ -433,7 +419,7 @@ mod tests { let mut stream = reader.read().await.unwrap(); let mut expect_rows = vec![]; - for counter in &[9, 7, 5, 3, 1] { + for counter in &[4, 3, 2, 1, 0] { expect_rows.push(build_row(b"a", 100 + counter, 10.0, "v4")); expect_rows.push(build_row(b"b", 100 + counter, 10.0, "v4")); expect_rows.push(build_row(b"c", 100 + counter, 10.0, "v4")); @@ -447,22 +433,26 @@ mod tests { // rows per group: 10 let testcases = vec![ // input, expected - (vec![], vec![]), - (vec![10, 10], vec![10, 10]), - (vec![10, 10, 1], vec![10, 10, 1]), - (vec![10, 10, 21], vec![10, 10, 10, 10, 1]), - (vec![5, 6, 10], vec![10, 10, 1]), - (vec![5, 4, 4, 30], vec![10, 10, 10, 10, 3]), - (vec![20, 7, 23, 20], vec![10, 10, 10, 10, 10, 10, 10]), - (vec![21], vec![10, 10, 1]), + (10, vec![], vec![]), + (10, vec![10, 10], vec![10, 10]), + (10, vec![10, 10, 1], vec![10, 10, 1]), + (10, vec![10, 10, 21], vec![10, 10, 10, 10, 1]), + (10, vec![5, 6, 10], vec![10, 10, 1]), + (10, vec![5, 4, 4, 30], vec![10, 10, 10, 10, 3]), + (10, vec![20, 7, 23, 20], vec![10, 10, 10, 10, 10, 10, 10]), + (10, vec![21], vec![10, 10, 1]), + (10, vec![2, 2, 2, 2, 2], vec![10]), + (4, vec![3, 3, 3, 3, 3], vec![4, 4, 4, 3]), + (5, vec![3, 3, 3, 3, 3], vec![5, 5, 5]), ]; - for (input, expected) in testcases { - test_partition_record_batch_inner(input, expected).await; + for (num_rows_per_group, input, expected) in testcases { + test_partition_record_batch_inner(num_rows_per_group, input, expected).await; } } async fn test_partition_record_batch_inner( + num_rows_per_row_group: usize, input_row_nums: Vec, expected_row_nums: Vec, ) { @@ -488,7 +478,7 @@ mod tests { let mut reader = RecordBytesReader { request_id: RequestId::next_id(), record_stream: record_batch_stream, - num_rows_per_row_group: 10, + num_rows_per_row_group, compression: Compression::UNCOMPRESSED, meta_data: SstMetaData { min_key: Default::default(),