Skip to content

Commit

Permalink
rfc: optional skipping partial aggregation
Browse files Browse the repository at this point in the history
  • Loading branch information
korowa committed Jul 28, 2024
1 parent a721be1 commit 9b7e4c8
Show file tree
Hide file tree
Showing 9 changed files with 671 additions and 4 deletions.
9 changes: 9 additions & 0 deletions datafusion/common/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,15 @@ config_namespace! {

/// Should DataFusion keep the columns used for partition_by in the output RecordBatches
pub keep_partition_by_columns: bool, default = false

/// Aggregation ratio (number of distinct groups / number of input rows)
/// threshold for skipping partial aggregation. If the value is greater
/// then partial aggregation will skip aggregation for further input
pub skip_partial_aggregation_probe_ratio_threshold: f64, default = 0.8

/// Number of input rows partial aggregation partition should process, before
/// aggregation ratio check and trying to switch to skipping aggregation mode
pub skip_partial_aggregation_probe_rows_threshold: usize, default = 100_000
}
}

Expand Down
20 changes: 19 additions & 1 deletion datafusion/expr/src/groups_accumulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
//! Vectorized [`GroupsAccumulator`]

use arrow_array::{ArrayRef, BooleanArray};
use datafusion_common::Result;
use datafusion_common::{not_impl_err, Result};

/// Describes how many rows should be emitted during grouping.
#[derive(Debug, Clone, Copy)]
Expand Down Expand Up @@ -158,6 +158,24 @@ pub trait GroupsAccumulator: Send {
total_num_groups: usize,
) -> Result<()>;

/// Converts input batch to intermediate aggregate state,
/// without grouping (each input row considered as a separate
/// group).
fn convert_to_state(
&self,
_values: &[ArrayRef],
_opt_filter: Option<&BooleanArray>,
) -> Result<Vec<ArrayRef>> {
not_impl_err!("Input batch conversion to state not implemented")
}

/// Returns `true` is groups accumulator supports input batch
/// to intermediate aggregate state conversion (`convert_to_state`
/// method is implemented).
fn convert_to_state_supported(&self) -> bool {
false
}

/// Amount of memory used to store the state of this accumulator,
/// in bytes. This function is called once per batch, so it should
/// be `O(n)` to compute, not `O(num_groups)`
Expand Down
45 changes: 44 additions & 1 deletion datafusion/functions-aggregate/src/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ use arrow::{
};

use arrow::{
array::{Array, BooleanArray, Int64Array, PrimitiveArray},
array::{Array, BooleanArray, Int64Array, Int64Builder, PrimitiveArray},
buffer::BooleanBuffer,
};
use datafusion_common::{
Expand Down Expand Up @@ -433,6 +433,49 @@ impl GroupsAccumulator for CountGroupsAccumulator {
Ok(vec![Arc::new(counts) as ArrayRef])
}

fn convert_to_state(
&self,
values: &[ArrayRef],
opt_filter: Option<&BooleanArray>,
) -> Result<Vec<ArrayRef>> {
let values = &values[0];

let state_array = match (values.logical_nulls(), opt_filter) {
(Some(nulls), None) => {
let mut builder = Int64Builder::with_capacity(values.len());
nulls
.into_iter()
.for_each(|is_valid| builder.append_value(is_valid as i64));
builder.finish()
}
(Some(nulls), Some(filter)) => {
let mut builder = Int64Builder::with_capacity(values.len());
nulls.into_iter().zip(filter.iter()).for_each(
|(is_valid, filter_value)| {
builder.append_value(
(is_valid && filter_value.is_some_and(|val| val)) as i64,
)
},
);
builder.finish()
}
(None, Some(filter)) => {
let mut builder = Int64Builder::with_capacity(values.len());
filter.into_iter().for_each(|filter_value| {
builder.append_value(filter_value.is_some_and(|val| val) as i64)
});
builder.finish()
}
(None, None) => Int64Array::from_value(1, values.len()),
};

Ok(vec![Arc::new(state_array)])
}

fn convert_to_state_supported(&self) -> bool {
true
}

fn size(&self) -> usize {
self.counts.capacity() * std::mem::size_of::<usize>()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

use std::sync::Arc;

use arrow::array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray};
use arrow::array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray, PrimitiveBuilder};
use arrow::datatypes::ArrowPrimitiveType;
use arrow::datatypes::DataType;
use datafusion_common::Result;
Expand Down Expand Up @@ -134,6 +134,46 @@ where
self.update_batch(values, group_indices, opt_filter, total_num_groups)
}

fn convert_to_state(
&self,
values: &[ArrayRef],
opt_filter: Option<&BooleanArray>,
) -> Result<Vec<ArrayRef>> {
let values = values[0].as_primitive::<T>();
let mut state = PrimitiveBuilder::<T>::with_capacity(values.len())
.with_data_type(self.data_type.clone());

match opt_filter {
Some(filter) => {
values
.iter()
.zip(filter.iter())
.for_each(|(val, filter_val)| match (val, filter_val) {
(Some(val), Some(true)) => {
let mut state_val = self.starting_value;
(self.prim_fn)(&mut state_val, val);
state.append_value(state_val);
}
(_, _) => state.append_null(),
})
}
None => values.iter().for_each(|val| match val {
Some(val) => {
let mut state_val = self.starting_value;
(self.prim_fn)(&mut state_val, val);
state.append_value(state_val);
}
None => state.append_null(),
}),
};

Ok(vec![Arc::new(state.finish())])
}

fn convert_to_state_supported(&self) -> bool {
true
}

fn size(&self) -> usize {
self.values.capacity() * std::mem::size_of::<T::Native>() + self.null_state.size()
}
Expand Down
185 changes: 185 additions & 0 deletions datafusion/physical-plan/src/aggregates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2398,4 +2398,189 @@ mod tests {

Ok(())
}

#[tokio::test]
async fn test_skip_aggregation_after_first_batch() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("key", DataType::Int32, true),
Field::new("val", DataType::Int32, true),
]));
let df_schema = DFSchema::try_from(Arc::clone(&schema))?;

let group_by =
PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]);

let aggr_expr: Vec<Arc<dyn AggregateExpr>> =
vec![create_aggregate_expr_with_dfschema(
&count_udaf(),
&[col("val", &schema)?],
&[datafusion_expr::col("val")],
&[],
&[],
&df_schema,
"COUNT(val)",
false,
false,
false,
)?];

let input_data = vec![
RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(Int32Array::from(vec![1, 2, 3])),
Arc::new(Int32Array::from(vec![0, 0, 0])),
],
)
.unwrap(),
RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(Int32Array::from(vec![2, 3, 4])),
Arc::new(Int32Array::from(vec![0, 0, 0])),
],
)
.unwrap(),
];

let input = Arc::new(MemoryExec::try_new(
&[input_data],
Arc::clone(&schema),
None,
)?);
let aggregate_exec = Arc::new(AggregateExec::try_new(
AggregateMode::Partial,
group_by,
aggr_expr,
vec![None],
Arc::clone(&input) as Arc<dyn ExecutionPlan>,
schema,
)?);

let mut session_config = SessionConfig::default();
session_config = session_config.set(
"datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
ScalarValue::Int64(Some(2)),
);
session_config = session_config.set(
"datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
ScalarValue::Float64(Some(0.1)),
);

let ctx = TaskContext::default().with_session_config(session_config);
let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;

let expected = [
"+-----+-------------------+",
"| key | COUNT(val)[count] |",
"+-----+-------------------+",
"| 1 | 1 |",
"| 2 | 1 |",
"| 3 | 1 |",
"| 2 | 1 |",
"| 3 | 1 |",
"| 4 | 1 |",
"+-----+-------------------+",
];
assert_batches_eq!(expected, &output);

Ok(())
}

#[tokio::test]
async fn test_skip_aggregation_after_threshold() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("key", DataType::Int32, true),
Field::new("val", DataType::Int32, true),
]));
let df_schema = DFSchema::try_from(Arc::clone(&schema))?;

let group_by =
PhysicalGroupBy::new_single(vec![(col("key", &schema)?, "key".to_string())]);

let aggr_expr: Vec<Arc<dyn AggregateExpr>> =
vec![create_aggregate_expr_with_dfschema(
&count_udaf(),
&[col("val", &schema)?],
&[datafusion_expr::col("val")],
&[],
&[],
&df_schema,
"COUNT(val)",
false,
false,
false,
)?];

let input_data = vec![
RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(Int32Array::from(vec![1, 2, 3])),
Arc::new(Int32Array::from(vec![0, 0, 0])),
],
)
.unwrap(),
RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(Int32Array::from(vec![2, 3, 4])),
Arc::new(Int32Array::from(vec![0, 0, 0])),
],
)
.unwrap(),
RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(Int32Array::from(vec![2, 3, 4])),
Arc::new(Int32Array::from(vec![0, 0, 0])),
],
)
.unwrap(),
];

let input = Arc::new(MemoryExec::try_new(
&[input_data],
Arc::clone(&schema),
None,
)?);
let aggregate_exec = Arc::new(AggregateExec::try_new(
AggregateMode::Partial,
group_by,
aggr_expr,
vec![None],
Arc::clone(&input) as Arc<dyn ExecutionPlan>,
schema,
)?);

let mut session_config = SessionConfig::default();
session_config = session_config.set(
"datafusion.execution.skip_partial_aggregation_probe_rows_threshold",
ScalarValue::Int64(Some(5)),
);
session_config = session_config.set(
"datafusion.execution.skip_partial_aggregation_probe_ratio_threshold",
ScalarValue::Float64(Some(0.1)),
);

let ctx = TaskContext::default().with_session_config(session_config);
let output = collect(aggregate_exec.execute(0, Arc::new(ctx))?).await?;

let expected = [
"+-----+-------------------+",
"| key | COUNT(val)[count] |",
"+-----+-------------------+",
"| 1 | 1 |",
"| 2 | 2 |",
"| 3 | 2 |",
"| 4 | 1 |",
"| 2 | 1 |",
"| 3 | 1 |",
"| 4 | 1 |",
"+-----+-------------------+",
];
assert_batches_eq!(expected, &output);

Ok(())
}
}
Loading

0 comments on commit 9b7e4c8

Please sign in to comment.