Skip to content

Commit

Permalink
Support no distinct aggregate sum/min/max in `single_distinct_to_grou…
Browse files Browse the repository at this point in the history
…p_by` rule (#8266)

* init impl

* add some tests

* add filter tests

* minor

* add more tests

* update test
  • Loading branch information
haohuaijin authored Nov 26, 2023
1 parent f8dcc64 commit f29bcf3
Show file tree
Hide file tree
Showing 2 changed files with 330 additions and 32 deletions.
280 changes: 248 additions & 32 deletions datafusion/optimizer/src/single_distinct_to_groupby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use crate::{OptimizerConfig, OptimizerRule};

use datafusion_common::{DFSchema, Result};
use datafusion_expr::{
aggregate_function::AggregateFunction::{Max, Min, Sum},
col,
expr::AggregateFunction,
logical_plan::{Aggregate, LogicalPlan, Projection},
Expand All @@ -35,17 +36,19 @@ use hashbrown::HashSet;

/// single distinct to group by optimizer rule
/// ```text
/// SELECT F1(DISTINCT s),F2(DISTINCT s)
/// ...
/// GROUP BY k
/// Before:
/// SELECT a, COUNT(DISTINCT b), SUM(c)
/// FROM t
/// GROUP BY a
///
/// Into
///
/// SELECT F1(alias1),F2(alias1)
/// After:
/// SELECT a, COUNT(alias1), SUM(alias2)
/// FROM (
/// SELECT s as alias1, k ... GROUP BY s, k
/// SELECT a, b as alias1, SUM(c) as alias2
/// FROM t
/// GROUP BY a, b
/// )
/// GROUP BY k
/// GROUP BY a
/// ```
#[derive(Default)]
pub struct SingleDistinctToGroupBy {}
Expand All @@ -64,22 +67,30 @@ fn is_single_distinct_agg(plan: &LogicalPlan) -> Result<bool> {
match plan {
LogicalPlan::Aggregate(Aggregate { aggr_expr, .. }) => {
let mut fields_set = HashSet::new();
let mut distinct_count = 0;
let mut aggregate_count = 0;
for expr in aggr_expr {
if let Expr::AggregateFunction(AggregateFunction {
distinct, args, ..
fun,
distinct,
args,
filter,
order_by,
}) = expr
{
if *distinct {
distinct_count += 1;
if filter.is_some() || order_by.is_some() {
return Ok(false);
}
for e in args {
fields_set.insert(e.canonical_name());
aggregate_count += 1;
if *distinct {
for e in args {
fields_set.insert(e.canonical_name());
}
} else if !matches!(fun, Sum | Min | Max) {
return Ok(false);
}
}
}
let res = distinct_count == aggr_expr.len() && fields_set.len() == 1;
Ok(res)
Ok(aggregate_count == aggr_expr.len() && fields_set.len() == 1)
}
_ => Ok(false),
}
Expand Down Expand Up @@ -152,30 +163,57 @@ impl OptimizerRule for SingleDistinctToGroupBy {
.collect::<Vec<_>>();

// replace the distinct arg with alias
let mut index = 1;
let mut group_fields_set = HashSet::new();
let new_aggr_exprs = aggr_expr
let mut inner_aggr_exprs = vec![];
let outer_aggr_exprs = aggr_expr
.iter()
.map(|aggr_expr| match aggr_expr {
Expr::AggregateFunction(AggregateFunction {
fun,
args,
filter,
order_by,
distinct,
..
}) => {
// is_single_distinct_agg ensure args.len=1
if group_fields_set.insert(args[0].display_name()?) {
if *distinct
&& group_fields_set.insert(args[0].display_name()?)
{
inner_group_exprs.push(
args[0].clone().alias(SINGLE_DISTINCT_ALIAS),
);
}
Ok(Expr::AggregateFunction(AggregateFunction::new(
fun.clone(),
vec![col(SINGLE_DISTINCT_ALIAS)],
false, // intentional to remove distinct here
filter.clone(),
order_by.clone(),
)))

// if the aggregate function is not distinct, we need to rewrite it like two phase aggregation
if !(*distinct) {
index += 1;
let alias_str = format!("alias{}", index);
inner_aggr_exprs.push(
Expr::AggregateFunction(AggregateFunction::new(
fun.clone(),
args.clone(),
false,
None,
None,
))
.alias(&alias_str),
);
Ok(Expr::AggregateFunction(AggregateFunction::new(
fun.clone(),
vec![col(&alias_str)],
false,
None,
None,
)))
} else {
Ok(Expr::AggregateFunction(AggregateFunction::new(
fun.clone(),
vec![col(SINGLE_DISTINCT_ALIAS)],
false, // intentional to remove distinct here
None,
None,
)))
}
}
_ => Ok(aggr_expr.clone()),
})
Expand All @@ -184,6 +222,7 @@ impl OptimizerRule for SingleDistinctToGroupBy {
// construct the inner AggrPlan
let inner_fields = inner_group_exprs
.iter()
.chain(inner_aggr_exprs.iter())
.map(|expr| expr.to_field(input.schema()))
.collect::<Result<Vec<_>>>()?;
let inner_schema = DFSchema::new_with_metadata(
Expand All @@ -193,12 +232,12 @@ impl OptimizerRule for SingleDistinctToGroupBy {
let inner_agg = LogicalPlan::Aggregate(Aggregate::try_new(
input.clone(),
inner_group_exprs,
Vec::new(),
inner_aggr_exprs,
)?);

let outer_fields = outer_group_exprs
.iter()
.chain(new_aggr_exprs.iter())
.chain(outer_aggr_exprs.iter())
.map(|expr| expr.to_field(&inner_schema))
.collect::<Result<Vec<_>>>()?;
let outer_aggr_schema = Arc::new(DFSchema::new_with_metadata(
Expand All @@ -220,7 +259,7 @@ impl OptimizerRule for SingleDistinctToGroupBy {
group_expr
}
})
.chain(new_aggr_exprs.iter().enumerate().map(|(idx, expr)| {
.chain(outer_aggr_exprs.iter().enumerate().map(|(idx, expr)| {
let idx = idx + group_size;
let name = fields[idx].qualified_name();
columnize_expr(expr.clone().alias(name), &outer_aggr_schema)
Expand All @@ -230,7 +269,7 @@ impl OptimizerRule for SingleDistinctToGroupBy {
let outer_aggr = LogicalPlan::Aggregate(Aggregate::try_new(
Arc::new(inner_agg),
outer_group_exprs,
new_aggr_exprs,
outer_aggr_exprs,
)?);

Ok(Some(LogicalPlan::Projection(Projection::try_new(
Expand Down Expand Up @@ -262,7 +301,7 @@ mod tests {
use datafusion_expr::expr::GroupingSet;
use datafusion_expr::{
col, count, count_distinct, lit, logical_plan::builder::LogicalPlanBuilder, max,
AggregateFunction,
min, sum, AggregateFunction,
};

fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> {
Expand Down Expand Up @@ -478,4 +517,181 @@ mod tests {

assert_optimized_plan_equal(&plan, expected)
}

#[test]
fn two_distinct_and_one_common() -> Result<()> {
    // Two DISTINCT aggregates over the same column `b` plus a plain SUM(c),
    // all grouped by `a`.
    let scan = test_table_scan()?;

    // MAX(DISTINCT b), spelled out by hand so the `distinct` flag can be set.
    let max_distinct_b = Expr::AggregateFunction(expr::AggregateFunction::new(
        AggregateFunction::Max,
        vec![col("b")],
        true,
        None,
        None,
    ));
    let aggregates = vec![sum(col("c")), count_distinct(col("b")), max_distinct_b];

    let plan = LogicalPlanBuilder::from(scan)
        .aggregate(vec![col("a")], aggregates)?
        .build()?;

    // The rewrite is expected to fire: the inner aggregate groups by `a` and
    // `b AS alias1` while computing `SUM(c) AS alias2`; the outer aggregate
    // groups by `a` and aggregates over the aliases; a projection restores the
    // original output names.
    let expected_plan = "Projection: test.a, SUM(alias2) AS SUM(test.c), COUNT(alias1) AS COUNT(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, SUM(test.c):UInt64;N, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\
        \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(alias2), COUNT(alias1), MAX(alias1)]] [a:UInt32, SUM(alias2):UInt64;N, COUNT(alias1):Int64;N, MAX(alias1):UInt32;N]\
        \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[SUM(test.c) AS alias2]] [a:UInt32, alias1:UInt32, alias2:UInt64;N]\
        \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

    assert_optimized_plan_equal(&plan, expected_plan)
}

#[test]
fn one_distinctand_and_two_common() -> Result<()> {
    // One DISTINCT aggregate (COUNT(DISTINCT b)) next to two plain aggregates
    // over `c` (SUM and MAX), grouped by `a`.
    let scan = test_table_scan()?;

    let group_by = vec![col("a")];
    let aggregates = vec![sum(col("c")), max(col("c")), count_distinct(col("b"))];
    let plan = LogicalPlanBuilder::from(scan)
        .aggregate(group_by, aggregates)?
        .build()?;

    // The rewrite is expected to fire: each plain aggregate gets its own alias
    // (`alias2`, `alias3`) in the inner aggregate, and the distinct argument
    // becomes the extra group key `alias1`.
    let expected_plan = "Projection: test.a, SUM(alias2) AS SUM(test.c), MAX(alias3) AS MAX(test.c), COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, SUM(test.c):UInt64;N, MAX(test.c):UInt32;N, COUNT(DISTINCT test.b):Int64;N]\
        \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(alias2), MAX(alias3), COUNT(alias1)]] [a:UInt32, SUM(alias2):UInt64;N, MAX(alias3):UInt32;N, COUNT(alias1):Int64;N]\
        \n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[SUM(test.c) AS alias2, MAX(test.c) AS alias3]] [a:UInt32, alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N]\
        \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

    assert_optimized_plan_equal(&plan, expected_plan)
}

#[test]
fn one_distinct_and_one_common() -> Result<()> {
    // MIN(a) together with COUNT(DISTINCT b), grouped by `c`.
    let scan = test_table_scan()?;

    let aggregates = vec![min(col("a")), count_distinct(col("b"))];
    let plan = LogicalPlanBuilder::from(scan)
        .aggregate(vec![col("c")], aggregates)?
        .build()?;

    // The rewrite is expected to fire: MIN(a) is computed as `alias2` in the
    // inner aggregate (grouped by `c` and `b AS alias1`) and re-aggregated with
    // MIN in the outer one.
    let expected_plan = "Projection: test.c, MIN(alias2) AS MIN(test.a), COUNT(alias1) AS COUNT(DISTINCT test.b) [c:UInt32, MIN(test.a):UInt32;N, COUNT(DISTINCT test.b):Int64;N]\
        \n Aggregate: groupBy=[[test.c]], aggr=[[MIN(alias2), COUNT(alias1)]] [c:UInt32, MIN(alias2):UInt32;N, COUNT(alias1):Int64;N]\
        \n Aggregate: groupBy=[[test.c, test.b AS alias1]], aggr=[[MIN(test.a) AS alias2]] [c:UInt32, alias1:UInt32, alias2:UInt32;N]\
        \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

    assert_optimized_plan_equal(&plan, expected_plan)
}

#[test]
fn common_with_filter() -> Result<()> {
    let scan = test_table_scan()?;

    // SUM(a) FILTER (WHERE a > 5): a non-distinct aggregate carrying a filter.
    let filter = Some(Box::new(col("a").gt(lit(5))));
    let filtered_sum = Expr::AggregateFunction(expr::AggregateFunction::new(
        AggregateFunction::Sum,
        vec![col("a")],
        false,
        filter,
        None,
    ));

    let aggregates = vec![filtered_sum, count_distinct(col("b"))];
    let plan = LogicalPlanBuilder::from(scan)
        .aggregate(vec![col("c")], aggregates)?
        .build()?;

    // The plan is expected to come back untouched: a single aggregate directly
    // over the table scan, with the FILTER clause intact.
    let expected_plan = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a) FILTER (WHERE test.a > Int32(5)), COUNT(DISTINCT test.b)]] [c:UInt32, SUM(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, COUNT(DISTINCT test.b):Int64;N]\
        \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

    assert_optimized_plan_equal(&plan, expected_plan)
}

#[test]
fn distinct_with_filter() -> Result<()> {
    let scan = test_table_scan()?;

    // COUNT(DISTINCT a) FILTER (WHERE a > 5): the distinct aggregate itself
    // carries the filter.
    let filter = Some(Box::new(col("a").gt(lit(5))));
    let filtered_count_distinct = Expr::AggregateFunction(expr::AggregateFunction::new(
        AggregateFunction::Count,
        vec![col("a")],
        true,
        filter,
        None,
    ));

    let aggregates = vec![sum(col("a")), filtered_count_distinct];
    let plan = LogicalPlanBuilder::from(scan)
        .aggregate(vec![col("c")], aggregates)?
        .build()?;

    // The plan is expected to come back untouched: a single aggregate directly
    // over the table scan, still using DISTINCT and FILTER.
    let expected_plan = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64;N]\
        \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

    assert_optimized_plan_equal(&plan, expected_plan)
}

#[test]
fn common_with_order_by() -> Result<()> {
    let scan = test_table_scan()?;

    // SUM(a ORDER BY a): a non-distinct aggregate carrying an ordering.
    let ordering = Some(vec![col("a")]);
    let ordered_sum = Expr::AggregateFunction(expr::AggregateFunction::new(
        AggregateFunction::Sum,
        vec![col("a")],
        false,
        None,
        ordering,
    ));

    let aggregates = vec![ordered_sum, count_distinct(col("b"))];
    let plan = LogicalPlanBuilder::from(scan)
        .aggregate(vec![col("c")], aggregates)?
        .build()?;

    // The plan is expected to come back untouched: a single aggregate directly
    // over the table scan, with the ORDER BY clause intact.
    let expected_plan = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a) ORDER BY [test.a], COUNT(DISTINCT test.b)]] [c:UInt32, SUM(test.a) ORDER BY [test.a]:UInt64;N, COUNT(DISTINCT test.b):Int64;N]\
        \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

    assert_optimized_plan_equal(&plan, expected_plan)
}

#[test]
fn distinct_with_order_by() -> Result<()> {
    let scan = test_table_scan()?;

    // COUNT(DISTINCT a ORDER BY a): the distinct aggregate itself carries the
    // ordering.
    let ordering = Some(vec![col("a")]);
    let ordered_count_distinct = Expr::AggregateFunction(expr::AggregateFunction::new(
        AggregateFunction::Count,
        vec![col("a")],
        true,
        None,
        ordering,
    ));

    let aggregates = vec![sum(col("a")), ordered_count_distinct];
    let plan = LogicalPlanBuilder::from(scan)
        .aggregate(vec![col("c")], aggregates)?
        .build()?;

    // The plan is expected to come back untouched: a single aggregate directly
    // over the table scan, still using DISTINCT and ORDER BY.
    let expected_plan = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) ORDER BY [test.a]]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) ORDER BY [test.a]:Int64;N]\
        \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

    assert_optimized_plan_equal(&plan, expected_plan)
}

#[test]
fn aggregate_with_filter_and_order_by() -> Result<()> {
    let scan = test_table_scan()?;

    // COUNT(DISTINCT a ORDER BY a) FILTER (WHERE a > 5): a distinct aggregate
    // carrying both a filter and an ordering.
    let filter = Some(Box::new(col("a").gt(lit(5))));
    let ordering = Some(vec![col("a")]);
    let filtered_ordered_count = Expr::AggregateFunction(expr::AggregateFunction::new(
        AggregateFunction::Count,
        vec![col("a")],
        true,
        filter,
        ordering,
    ));

    let aggregates = vec![sum(col("a")), filtered_ordered_count];
    let plan = LogicalPlanBuilder::from(scan)
        .aggregate(vec![col("c")], aggregates)?
        .build()?;

    // The plan is expected to come back untouched: a single aggregate directly
    // over the table scan, with DISTINCT, FILTER and ORDER BY all intact.
    let expected_plan = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]:Int64;N]\
        \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

    assert_optimized_plan_equal(&plan, expected_plan)
}
}
Loading

0 comments on commit f29bcf3

Please sign in to comment.