Skip to content

Commit

Permalink
use PartitionedOutput.
Browse files Browse the repository at this point in the history
  • Loading branch information
Rachelint committed Sep 19, 2024
1 parent 4108065 commit 986ba1a
Showing 1 changed file with 48 additions and 11 deletions.
59 changes: 48 additions & 11 deletions datafusion/physical-plan/src/aggregates/row_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ pub(crate) enum ExecutionState {
/// here and are sliced off as needed in batch_size chunks
ProducingOutput(RecordBatch),

ProducingPartitionedOutput(Vec<Option<RecordBatch>>),
ProducingPartitionedOutput(PartitionedOutput),
/// Produce intermediate aggregate state for each input row without
/// aggregation.
///
Expand All @@ -78,6 +78,7 @@ pub(crate) enum ExecutionState {
use super::order::GroupOrdering;
use super::AggregateExec;

#[derive(Debug, Clone, Default)]
struct PartitionedOutput {
partitions: Vec<Option<RecordBatch>>,
start_idx: usize,
Expand Down Expand Up @@ -133,7 +134,7 @@ impl PartitionedOutput {
// cut off `batch_size` rows as `output`,
// and set back `remaining`.
let size = self.batch_size;
let num_remaining = batch.num_rows() - size;
let num_remaining = partition_batch.num_rows() - size;
let remaining = partition_batch.slice(size, num_remaining);
let output = partition_batch.slice(0, size);
self.partitions[part_idx] = Some(remaining);
Expand Down Expand Up @@ -718,7 +719,7 @@ impl Stream for GroupedHashAggregateStream {
let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();

loop {
match &self.exec_state {
match &mut self.exec_state {
ExecutionState::ReadingInput => 'reading_input: {
match ready!(self.input.poll_next_unpin(cx)) {
// new batch to aggregate
Expand Down Expand Up @@ -800,6 +801,7 @@ impl Stream for GroupedHashAggregateStream {
ExecutionState::ProducingOutput(batch) => {
// slice off a part of the batch, if needed
let output_batch;
let batch = batch.clone();
let size = self.batch_size;
(self.exec_state, output_batch) = if batch.num_rows() <= size {
(
Expand All @@ -810,7 +812,7 @@ impl Stream for GroupedHashAggregateStream {
} else {
ExecutionState::ReadingInput
},
batch.clone(),
batch,
)
} else {
// output first batch_size rows
Expand All @@ -833,7 +835,30 @@ impl Stream for GroupedHashAggregateStream {
return Poll::Ready(None);
}

ExecutionState::ProducingPartitionedOutput(_) => todo!(),
ExecutionState::ProducingPartitionedOutput(parts) => {
// slice off a part of the batch, if needed
let batch_opt = parts.next_batch();
if let Some(batch) = batch_opt {
// output first batch_size rows
let size = self.batch_size;
let num_remaining = batch.num_rows() - size;
let remaining = batch.slice(size, num_remaining);
let output = batch.slice(0, size);
self.exec_state = ExecutionState::ProducingOutput(remaining);

return Poll::Ready(Some(Ok(
output.record_output(&self.baseline_metrics)
)));
} else {
self.exec_state = if self.input_done {
ExecutionState::Done
} else if self.should_skip_aggregation() {
ExecutionState::SkippingAggregation
} else {
ExecutionState::ReadingInput
};
}
}
}
}
}
Expand Down Expand Up @@ -1222,8 +1247,12 @@ impl GroupedHashAggregateStream {
self.exec_state = ExecutionState::ProducingOutput(batch);
} else {
let batches = self.emit(EmitTo::All, false)?;
let batches = batches.into_iter().map(|batch| Some(batch)).collect();
self.exec_state = ExecutionState::ProducingPartitionedOutput(batches);
self.exec_state =
ExecutionState::ProducingPartitionedOutput(PartitionedOutput::new(
batches,
self.batch_size,
self.group_values.num_partitions(),
));
}
}
Ok(())
Expand Down Expand Up @@ -1290,8 +1319,11 @@ impl GroupedHashAggregateStream {
ExecutionState::ProducingOutput(batch)
} else {
let batches = self.emit(EmitTo::All, false)?;
let batches = batches.into_iter().map(|batch| Some(batch)).collect();
ExecutionState::ProducingPartitionedOutput(batches)
ExecutionState::ProducingPartitionedOutput(PartitionedOutput::new(
batches,
self.batch_size,
self.group_values.num_partitions(),
))
}
} else {
// If spill files exist, stream-merge them.
Expand Down Expand Up @@ -1330,8 +1362,13 @@ impl GroupedHashAggregateStream {
self.exec_state = ExecutionState::ProducingOutput(batch);
} else {
let batches = self.emit(EmitTo::All, false)?;
let batches = batches.into_iter().map(|batch| Some(batch)).collect();
self.exec_state = ExecutionState::ProducingPartitionedOutput(batches);
self.exec_state = ExecutionState::ProducingPartitionedOutput(
PartitionedOutput::new(
batches,
self.batch_size,
self.group_values.num_partitions(),
),
);
}
}
}
Expand Down

0 comments on commit 986ba1a

Please sign in to comment.