Skip to content

Commit

Permalink
AggregateExec: Take grouping sets into account for InputOrderMode (ap…
Browse files Browse the repository at this point in the history
…ache#11301)

* AggregateExec: Take grouping sets into account for InputOrderMode

* pr comments
# Conflicts:
#	datafusion/physical-plan/src/aggregates/mod.rs
  • Loading branch information
thinkharderdev committed Jul 8, 2024
1 parent 413e08b commit 79c606a
Showing 1 changed file with 117 additions and 9 deletions.
126 changes: 117 additions & 9 deletions datafusion/physical-plan/src/aggregates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -369,14 +369,26 @@ impl AggregateExec {
new_requirement.extend(req);
new_requirement = collapse_lex_req(new_requirement);

let input_order_mode =
if indices.len() == groupby_exprs.len() && !indices.is_empty() {
InputOrderMode::Sorted
} else if !indices.is_empty() {
InputOrderMode::PartiallySorted(indices)
} else {
InputOrderMode::Linear
};
// If our aggregation has grouping sets then our base grouping exprs will
// be expanded based on the flags in `group_by.groups` where for each
// group we swap the grouping expr for `null` if the flag is `true`
// That means that each index in `indices` is valid if and only if
// it is not null in every group
let indices: Vec<usize> = indices
.into_iter()
.filter(|idx| group_by.groups.iter().all(|group| !group[*idx]))
.collect();

let input_order_mode = if indices.len() == groupby_exprs.len()
&& !indices.is_empty()
&& group_by.groups.len() == 1
{
InputOrderMode::Sorted
} else if !indices.is_empty() {
InputOrderMode::PartiallySorted(indices)
} else {
InputOrderMode::Linear
};

// construct a map from the input expression to the output expression of the Aggregation group by
let projection_mapping =
Expand Down Expand Up @@ -1186,6 +1198,7 @@ mod tests {
use arrow::array::{Float64Array, UInt32Array};
use arrow::compute::{concat_batches, SortOptions};
use arrow::datatypes::DataType;
use arrow_array::{Float32Array, Int32Array};
use datafusion_common::{
assert_batches_eq, assert_batches_sorted_eq, internal_err, DataFusionError,
ScalarValue,
Expand All @@ -1196,10 +1209,12 @@ mod tests {
use datafusion_expr::expr::Sort;
use datafusion_functions_aggregate::median::median_udaf;
use datafusion_physical_expr::expressions::{
lit, ApproxDistinct, Count, FirstValue, LastValue, OrderSensitiveArrayAgg,
lit, ApproxDistinct, Count, FirstValue, LastValue, Literal,
OrderSensitiveArrayAgg,
};
use datafusion_physical_expr::PhysicalSortExpr;

use crate::common::collect;
use futures::{FutureExt, Stream};

// Generate a schema which consists of 5 columns (a, b, c, d, e)
Expand Down Expand Up @@ -2245,4 +2260,97 @@ mod tests {
assert_eq!(new_agg.schema(), aggregate_exec.schema());
Ok(())
}

#[tokio::test]
async fn test_agg_exec_group_by_const() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Float32, true),
Field::new("b", DataType::Float32, true),
Field::new("const", DataType::Int32, false),
]));

let col_a = col("a", &schema)?;
let col_b = col("b", &schema)?;
let const_expr = Arc::new(Literal::new(ScalarValue::Int32(Some(1))));

let groups = PhysicalGroupBy::new(
vec![
(col_a, "a".to_string()),
(col_b, "b".to_string()),
(const_expr, "const".to_string()),
],
vec![
(
Arc::new(Literal::new(ScalarValue::Float32(None))),
"a".to_string(),
),
(
Arc::new(Literal::new(ScalarValue::Float32(None))),
"b".to_string(),
),
(
Arc::new(Literal::new(ScalarValue::Int32(None))),
"const".to_string(),
),
],
vec![
vec![false, true, true],
vec![true, false, true],
vec![true, true, false],
],
);

let aggregates: Vec<Arc<dyn AggregateExpr>> =
vec![Arc::new(Count::new(lit(1), "1", DataType::Int64))];

// let aggregates: Vec<Arc<dyn AggregateExpr>> = vec![create_aggregate_expr(
// count_udaf().as_ref(),
// &[lit(1)],
// &[datafusion_expr::lit(1)],
// &[],
// &[],
// schema.as_ref(),
// "1",
// false,
// false,
// )?];

let input_batches = (0..4)
.map(|_| {
let a = Arc::new(Float32Array::from(vec![0.; 8192]));
let b = Arc::new(Float32Array::from(vec![0.; 8192]));
let c = Arc::new(Int32Array::from(vec![1; 8192]));

RecordBatch::try_new(schema.clone(), vec![a, b, c]).unwrap()
})
.collect();

let input =
Arc::new(MemoryExec::try_new(&[input_batches], schema.clone(), None)?);

let aggregate_exec = Arc::new(AggregateExec::try_new(
AggregateMode::Partial,
groups,
aggregates.clone(),
vec![None],
input,
schema,
)?);

let output =
collect(aggregate_exec.execute(0, Arc::new(TaskContext::default()))?).await?;

let expected = [
"+-----+-----+-------+----------+",
"| a | b | const | 1[count] |",
"+-----+-----+-------+----------+",
"| | 0.0 | | 32768 |",
"| 0.0 | | | 32768 |",
"| | | 1 | 32768 |",
"+-----+-----+-------+----------+",
];
assert_batches_sorted_eq!(expected, &output);

Ok(())
}
}

0 comments on commit 79c606a

Please sign in to comment.