introduce ExecutionState::ProducingPartitionedOutput to process partitioned outputs
Rachelint committed Sep 18, 2024
1 parent bec7a3a commit 7214b3e
Showing 2 changed files with 46 additions and 12 deletions.
11 changes: 9 additions & 2 deletions datafusion/physical-plan/src/aggregates/group_values/mod.rs
@@ -162,9 +162,16 @@ pub trait GroupValues: Send {
     fn clear_shrink(&mut self, batch: &RecordBatch);
 }
 
-pub fn new_group_values(schema: SchemaRef, partitioning_group_values: bool, num_partitions: usize) -> Result<GroupValuesLike> {
+pub fn new_group_values(
+    schema: SchemaRef,
+    partitioning_group_values: bool,
+    num_partitions: usize,
+) -> Result<GroupValuesLike> {
     let group_values = if partitioning_group_values && schema.fields.len() > 1 {
-        GroupValuesLike::Partitioned(Box::new(PartitionedGroupValuesRows::try_new(schema, num_partitions)?))
+        GroupValuesLike::Partitioned(Box::new(PartitionedGroupValuesRows::try_new(
+            schema,
+            num_partitions,
+        )?))
     } else {
         GroupValuesLike::Single(new_single_group_values(schema)?)
     };
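
The `GroupValuesLike` type returned above is not defined in this diff. As a point of reference, here is a minimal sketch of the shape its call sites imply (the `Single`/`Partitioned` constructors here and the `is_partitioned()` checks in row_hash.rs below); the real definition elsewhere in group_values/mod.rs may differ:

```rust
// Hypothetical sketch inferred from the call sites above; not the actual
// definition from this commit.
pub enum GroupValuesLike {
    /// A single group-value store for the whole input stream.
    Single(Box<dyn GroupValues>),
    /// One group-value store per output partition.
    Partitioned(Box<PartitionedGroupValuesRows>),
}

impl GroupValuesLike {
    /// True when group values are maintained per partition.
    pub fn is_partitioned(&self) -> bool {
        matches!(self, Self::Partitioned(_))
    }
}
```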
47 changes: 37 additions & 10 deletions datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -64,6 +64,8 @@ pub(crate) enum ExecutionState {
     /// When producing output, the remaining rows to output are stored
     /// here and are sliced off as needed in batch_size chunks
     ProducingOutput(RecordBatch),
+
+    ProducingPartitionedOutput(Vec<RecordBatch>),
     /// Produce intermediate aggregate state for each input row without
     /// aggregation.
     ///
@@ -677,7 +679,9 @@ impl Stream for GroupedHashAggregateStream {
                             }
 
                             if let Some(to_emit) = self.group_ordering.emit_to() {
-                                let batch = extract_ok!(self.emit(to_emit, false));
+                                let mut batch = extract_ok!(self.emit(to_emit, false));
+                                assert_eq!(batch.len(), 1);
+                                let batch = batch.pop().unwrap();
                                 self.exec_state = ExecutionState::ProducingOutput(batch);
                                 timer.done();
                                 // make sure the exec_state just set is not overwritten below
@@ -759,6 +763,7 @@ impl Stream for GroupedHashAggregateStream {
                     let _ = self.update_memory_reservation();
                     return Poll::Ready(None);
                 }
+                ExecutionState::ProducingPartitionedOutput(_) => todo!(),
             }
         }
     }
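
The new match arm is left as `todo!()` in this commit. The sketch below shows, in a self-contained form, the draining behavior that arm presumably needs: hand out one `RecordBatch` per poll until the vector is empty. The `OutputState` enum and `next_output` function are illustrative stand-ins, not DataFusion APIs:

```rust
use arrow::record_batch::RecordBatch;

/// Simplified stand-in for the two output-producing execution states.
enum OutputState {
    /// A single remaining batch (sliced in batch_size chunks in the real code).
    Producing(RecordBatch),
    /// One pending batch per partition.
    ProducingPartitioned(Vec<RecordBatch>),
    Done,
}

/// Return the next output batch, draining partitioned output one
/// RecordBatch at a time and switching to Done once exhausted.
fn next_output(state: &mut OutputState) -> Option<RecordBatch> {
    match state {
        OutputState::Producing(batch) => {
            let out = batch.clone();
            *state = OutputState::Done;
            Some(out)
        }
        OutputState::ProducingPartitioned(batches) => {
            let out = batches.pop();
            if batches.is_empty() {
                *state = OutputState::Done;
            }
            out
        }
        OutputState::Done => None,
    }
}
```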
@@ -1101,7 +1106,9 @@ impl GroupedHashAggregateStream {

     /// Emit all rows, sort them, and store them on disk.
     fn spill(&mut self) -> Result<()> {
-        let emit = self.emit(EmitTo::All, true)?;
+        let mut emit = self.emit(EmitTo::All, true)?;
+        assert_eq!(emit.len(), 1);
+        let emit = emit.pop().unwrap();
         let sorted = sort_batch(&emit, &self.spill_state.spill_expr, None)?;
         let spillfile = self.runtime.disk_manager.create_tmp_file("HashAggSpill")?;
         // TODO: slice large `sorted` and write to multiple files in parallel
@@ -1138,9 +1145,15 @@
             && matches!(self.mode, AggregateMode::Partial)
             && self.update_memory_reservation().is_err()
         {
-            let n = self.group_values.len() / self.batch_size * self.batch_size;
-            let batch = self.emit(EmitTo::First(n), false)?;
-            self.exec_state = ExecutionState::ProducingOutput(batch);
+            if !self.group_values.is_partitioned() {
+                let n = self.group_values.len() / self.batch_size * self.batch_size;
+                let mut batch = self.emit(EmitTo::First(n), false)?;
+                let batch = batch.pop().unwrap();
+                self.exec_state = ExecutionState::ProducingOutput(batch);
+            } else {
+                let batches = self.emit(EmitTo::All, false)?;
+                self.exec_state = ExecutionState::ProducingPartitionedOutput(batches);
+            }
         }
         Ok(())
     }
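
In the unpartitioned branch above, `n` rounds the current group count down to a whole number of `batch_size` chunks, so early emission under memory pressure only produces full batches. A standalone illustration of that arithmetic (not DataFusion code):

```rust
/// Round `num_groups` down to a multiple of `batch_size`.
fn emit_count(num_groups: usize, batch_size: usize) -> usize {
    (num_groups / batch_size) * batch_size
}

fn main() {
    assert_eq!(emit_count(20_000, 8192), 16_384); // two full batches
    assert_eq!(emit_count(5_000, 8192), 0); // too few groups to emit early
}
```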
@@ -1150,7 +1163,9 @@
     /// Conduct a streaming merge sort between the batch and spilled data. Since the stream is fully
     /// sorted, set `self.group_ordering` to Full, then later we can read with [`EmitTo::First`].
     fn update_merged_stream(&mut self) -> Result<()> {
-        let batch = self.emit(EmitTo::All, true)?;
+        let mut batch = self.emit(EmitTo::All, true)?;
+        assert_eq!(batch.len(), 1);
+        let batch = batch.pop().unwrap();
         // clear up memory for streaming_merge
         self.clear_all();
         self.update_memory_reservation()?;
@@ -1198,8 +1213,14 @@
         let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
         let timer = elapsed_compute.timer();
         self.exec_state = if self.spill_state.spills.is_empty() {
-            let batch = self.emit(EmitTo::All, false)?;
-            ExecutionState::ProducingOutput(batch)
+            if !self.group_values.is_partitioned() {
+                let mut batch = self.emit(EmitTo::All, false)?;
+                let batch = batch.pop().unwrap();
+                ExecutionState::ProducingOutput(batch)
+            } else {
+                let batches = self.emit(EmitTo::All, false)?;
+                ExecutionState::ProducingPartitionedOutput(batches)
+            }
         } else {
             // If spill files exist, stream-merge them.
             self.update_merged_stream()?;
@@ -1231,8 +1252,14 @@
     fn switch_to_skip_aggregation(&mut self) -> Result<()> {
         if let Some(probe) = self.skip_aggregation_probe.as_mut() {
             if probe.should_skip() {
-                let batch = self.emit(EmitTo::All, false)?;
-                self.exec_state = ExecutionState::ProducingOutput(batch);
+                if !self.group_values.is_partitioned() {
+                    let mut batch = self.emit(EmitTo::All, false)?;
+                    let batch = batch.pop().unwrap();
+                    self.exec_state = ExecutionState::ProducingOutput(batch);
+                } else {
+                    let batches = self.emit(EmitTo::All, false)?;
+                    self.exec_state = ExecutionState::ProducingPartitionedOutput(batches);
+                }
             }
         }
 
