Skip to content

Commit

Permalink
introduce PartitionedOutput.
Browse files Browse the repository at this point in the history
  • Loading branch information
Rachelint committed Sep 19, 2024
1 parent 553c6a3 commit 4108065
Showing 1 changed file with 56 additions and 10 deletions.
66 changes: 56 additions & 10 deletions datafusion/physical-plan/src/aggregates/row_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

use std::sync::Arc;
use std::task::{Context, Poll};
use std::vec;
use std::{mem, vec};

use crate::aggregates::group_values::{new_group_values, GroupValuesLike};
use crate::aggregates::order::GroupOrderingFull;
Expand Down Expand Up @@ -79,28 +79,74 @@ use super::order::GroupOrdering;
use super::AggregateExec;

struct PartitionedOutput {
batches: Vec<Option<RecordBatch>>,
current_idx: usize,
exhausted: bool
partitions: Vec<Option<RecordBatch>>,
start_idx: usize,
batch_size: usize,
num_partitions: usize,
}

impl PartitionedOutput {
pub fn new(batches: Vec<RecordBatch>) -> Self {
let batches = batches.into_iter().map(|batch| Some(batch)).collect();
pub fn new(
src_batches: Vec<RecordBatch>,
batch_size: usize,
num_partitions: usize,
) -> Self {
let partitions = src_batches.into_iter().map(|batch| Some(batch)).collect();

Self {
batches,
current_idx: 0,
exhausted: false,
partitions,
start_idx: 0,
batch_size,
num_partitions,
}
}

pub fn next_batch(&mut self) -> Option<RecordBatch> {
let mut current_idx = self.start_idx;
loop {
// If found a partition having data,
let batch_opt = if self.partitions[current_idx].is_some() {
Some(self.extract_batch_from_partition(current_idx))
} else {
None
};

// Advance the `current_idx`
current_idx = (current_idx + 1) % self.num_partitions;

if batch_opt.is_some() {
// If found batch, we update the `start_idx` and return it
self.start_idx = current_idx;
return batch_opt;
} else if self.start_idx == current_idx {
// If not found, and has loop to end, we return None
return batch_opt;
}
// Otherwise, we loop to check next partition
}
}

pub fn extract_batch_from_partition(&mut self, part_idx: usize) -> RecordBatch {
let partition_batch = mem::take(&mut self.partitions[part_idx]).unwrap();
if partition_batch.num_rows() > self.batch_size {
// If still the exist rows num > `batch_size`,
// cut off `batch_size` rows as `output``,
// and set back `remaining`.
let size = self.batch_size;
let num_remaining = batch.num_rows() - size;
let remaining = partition_batch.slice(size, num_remaining);
let output = partition_batch.slice(0, size);
self.partitions[part_idx] = Some(remaining);

output
} else {
// If they are the last rows in `partition_batch`, just return,
// because `partition_batch` has been set to `None`.
partition_batch
}
}
}


/// This encapsulates the spilling state
struct SpillState {
// ========================================================================
Expand Down

0 comments on commit 4108065

Please sign in to comment.