Skip to content

Support no distinct aggregate sum/min/max in single_distinct_to_group_by rule #8266

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Nov 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
280 changes: 248 additions & 32 deletions datafusion/optimizer/src/single_distinct_to_groupby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use crate::{OptimizerConfig, OptimizerRule};

use datafusion_common::{DFSchema, Result};
use datafusion_expr::{
aggregate_function::AggregateFunction::{Max, Min, Sum},
col,
expr::AggregateFunction,
logical_plan::{Aggregate, LogicalPlan, Projection},
Expand All @@ -35,17 +36,19 @@ use hashbrown::HashSet;

/// single distinct to group by optimizer rule
/// ```text
/// SELECT F1(DISTINCT s),F2(DISTINCT s)
/// ...
/// GROUP BY k
/// Before:
/// SELECT a, COUNT(DINSTINCT b), SUM(c)
/// FROM t
/// GROUP BY a
///
/// Into
///
/// SELECT F1(alias1),F2(alias1)
/// After:
/// SELECT a, COUNT(alias1), SUM(alias2)
/// FROM (
/// SELECT s as alias1, k ... GROUP BY s, k
/// SELECT a, b as alias1, SUM(c) as alias2
/// FROM t
/// GROUP BY a, b
/// )
/// GROUP BY k
/// GROUP BY a
/// ```
#[derive(Default)]
pub struct SingleDistinctToGroupBy {}
Expand All @@ -64,22 +67,30 @@ fn is_single_distinct_agg(plan: &LogicalPlan) -> Result<bool> {
match plan {
LogicalPlan::Aggregate(Aggregate { aggr_expr, .. }) => {
let mut fields_set = HashSet::new();
let mut distinct_count = 0;
let mut aggregate_count = 0;
for expr in aggr_expr {
if let Expr::AggregateFunction(AggregateFunction {
distinct, args, ..
fun,
distinct,
args,
filter,
order_by,
}) = expr
{
if *distinct {
distinct_count += 1;
if filter.is_some() || order_by.is_some() {
return Ok(false);
}
for e in args {
fields_set.insert(e.canonical_name());
aggregate_count += 1;
if *distinct {
for e in args {
fields_set.insert(e.canonical_name());
}
} else if !matches!(fun, Sum | Min | Max) {
return Ok(false);
}
}
}
let res = distinct_count == aggr_expr.len() && fields_set.len() == 1;
Ok(res)
Ok(aggregate_count == aggr_expr.len() && fields_set.len() == 1)
}
_ => Ok(false),
}
Expand Down Expand Up @@ -152,30 +163,57 @@ impl OptimizerRule for SingleDistinctToGroupBy {
.collect::<Vec<_>>();

// replace the distinct arg with alias
let mut index = 1;
let mut group_fields_set = HashSet::new();
let new_aggr_exprs = aggr_expr
let mut inner_aggr_exprs = vec![];
let outer_aggr_exprs = aggr_expr
.iter()
.map(|aggr_expr| match aggr_expr {
Expr::AggregateFunction(AggregateFunction {
fun,
args,
filter,
order_by,
distinct,
..
}) => {
// is_single_distinct_agg ensure args.len=1
if group_fields_set.insert(args[0].display_name()?) {
if *distinct
&& group_fields_set.insert(args[0].display_name()?)
{
inner_group_exprs.push(
args[0].clone().alias(SINGLE_DISTINCT_ALIAS),
);
}
Ok(Expr::AggregateFunction(AggregateFunction::new(
fun.clone(),
vec![col(SINGLE_DISTINCT_ALIAS)],
false, // intentional to remove distinct here
filter.clone(),
order_by.clone(),
)))

// if the aggregate function is not distinct, we need to rewrite it like two phase aggregation
if !(*distinct) {
index += 1;
let alias_str = format!("alias{}", index);
inner_aggr_exprs.push(
Expr::AggregateFunction(AggregateFunction::new(
fun.clone(),
args.clone(),
false,
None,
None,
))
.alias(&alias_str),
);
Ok(Expr::AggregateFunction(AggregateFunction::new(
fun.clone(),
vec![col(&alias_str)],
false,
None,
None,
)))
} else {
Ok(Expr::AggregateFunction(AggregateFunction::new(
fun.clone(),
vec![col(SINGLE_DISTINCT_ALIAS)],
false, // intentional to remove distinct here
None,
None,
)))
}
}
_ => Ok(aggr_expr.clone()),
})
Expand All @@ -184,6 +222,7 @@ impl OptimizerRule for SingleDistinctToGroupBy {
// construct the inner AggrPlan
let inner_fields = inner_group_exprs
.iter()
.chain(inner_aggr_exprs.iter())
.map(|expr| expr.to_field(input.schema()))
.collect::<Result<Vec<_>>>()?;
let inner_schema = DFSchema::new_with_metadata(
Expand All @@ -193,12 +232,12 @@ impl OptimizerRule for SingleDistinctToGroupBy {
let inner_agg = LogicalPlan::Aggregate(Aggregate::try_new(
input.clone(),
inner_group_exprs,
Vec::new(),
inner_aggr_exprs,
)?);

let outer_fields = outer_group_exprs
.iter()
.chain(new_aggr_exprs.iter())
.chain(outer_aggr_exprs.iter())
.map(|expr| expr.to_field(&inner_schema))
.collect::<Result<Vec<_>>>()?;
let outer_aggr_schema = Arc::new(DFSchema::new_with_metadata(
Expand All @@ -220,7 +259,7 @@ impl OptimizerRule for SingleDistinctToGroupBy {
group_expr
}
})
.chain(new_aggr_exprs.iter().enumerate().map(|(idx, expr)| {
.chain(outer_aggr_exprs.iter().enumerate().map(|(idx, expr)| {
let idx = idx + group_size;
let name = fields[idx].qualified_name();
columnize_expr(expr.clone().alias(name), &outer_aggr_schema)
Expand All @@ -230,7 +269,7 @@ impl OptimizerRule for SingleDistinctToGroupBy {
let outer_aggr = LogicalPlan::Aggregate(Aggregate::try_new(
Arc::new(inner_agg),
outer_group_exprs,
new_aggr_exprs,
outer_aggr_exprs,
)?);

Ok(Some(LogicalPlan::Projection(Projection::try_new(
Expand Down Expand Up @@ -262,7 +301,7 @@ mod tests {
use datafusion_expr::expr::GroupingSet;
use datafusion_expr::{
col, count, count_distinct, lit, logical_plan::builder::LogicalPlanBuilder, max,
AggregateFunction,
min, sum, AggregateFunction,
};

fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> {
Expand Down Expand Up @@ -478,4 +517,181 @@ mod tests {

assert_optimized_plan_equal(&plan, expected)
}

#[test]
fn two_distinct_and_one_common() -> Result<()> {
let table_scan = test_table_scan()?;

let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(
vec![col("a")],
vec![
sum(col("c")),
count_distinct(col("b")),
Expr::AggregateFunction(expr::AggregateFunction::new(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

@haohuaijin haohuaijin Nov 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is because we only have count_distinct interface. Should we also add sum_distinct, max_distinct (maybe no sense, because the result of max_distinct is equal to max) and more?

AggregateFunction::Max,
vec![col("b")],
true,
None,
None,
)),
],
)?
.build()?;
// Should work
let expected = "Projection: test.a, SUM(alias2) AS SUM(test.c), COUNT(alias1) AS COUNT(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, SUM(test.c):UInt64;N, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\
\n Aggregate: groupBy=[[test.a]], aggr=[[SUM(alias2), COUNT(alias1), MAX(alias1)]] [a:UInt32, SUM(alias2):UInt64;N, COUNT(alias1):Int64;N, MAX(alias1):UInt32;N]\
\n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[SUM(test.c) AS alias2]] [a:UInt32, alias1:UInt32, alias2:UInt64;N]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(&plan, expected)
}

#[test]
fn one_distinctand_and_two_common() -> Result<()> {
let table_scan = test_table_scan()?;

let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(
vec![col("a")],
vec![sum(col("c")), max(col("c")), count_distinct(col("b"))],
)?
.build()?;
// Should work
let expected = "Projection: test.a, SUM(alias2) AS SUM(test.c), MAX(alias3) AS MAX(test.c), COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, SUM(test.c):UInt64;N, MAX(test.c):UInt32;N, COUNT(DISTINCT test.b):Int64;N]\
\n Aggregate: groupBy=[[test.a]], aggr=[[SUM(alias2), MAX(alias3), COUNT(alias1)]] [a:UInt32, SUM(alias2):UInt64;N, MAX(alias3):UInt32;N, COUNT(alias1):Int64;N]\
\n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[SUM(test.c) AS alias2, MAX(test.c) AS alias3]] [a:UInt32, alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(&plan, expected)
}

#[test]
fn one_distinct_and_one_common() -> Result<()> {
let table_scan = test_table_scan()?;

let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(
vec![col("c")],
vec![min(col("a")), count_distinct(col("b"))],
)?
.build()?;
// Should work
let expected = "Projection: test.c, MIN(alias2) AS MIN(test.a), COUNT(alias1) AS COUNT(DISTINCT test.b) [c:UInt32, MIN(test.a):UInt32;N, COUNT(DISTINCT test.b):Int64;N]\
\n Aggregate: groupBy=[[test.c]], aggr=[[MIN(alias2), COUNT(alias1)]] [c:UInt32, MIN(alias2):UInt32;N, COUNT(alias1):Int64;N]\
\n Aggregate: groupBy=[[test.c, test.b AS alias1]], aggr=[[MIN(test.a) AS alias2]] [c:UInt32, alias1:UInt32, alias2:UInt32;N]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(&plan, expected)
}

#[test]
fn common_with_filter() -> Result<()> {
let table_scan = test_table_scan()?;

// SUM(a) FILTER (WHERE a > 5)
let expr = Expr::AggregateFunction(expr::AggregateFunction::new(
AggregateFunction::Sum,
vec![col("a")],
false,
Some(Box::new(col("a").gt(lit(5)))),
None,
));
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])?
.build()?;
// Do nothing
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a) FILTER (WHERE test.a > Int32(5)), COUNT(DISTINCT test.b)]] [c:UInt32, SUM(test.a) FILTER (WHERE test.a > Int32(5)):UInt64;N, COUNT(DISTINCT test.b):Int64;N]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(&plan, expected)
}

#[test]
fn distinct_with_filter() -> Result<()> {
let table_scan = test_table_scan()?;

// COUNT(DISTINCT a) FILTER (WHERE a > 5)
let expr = Expr::AggregateFunction(expr::AggregateFunction::new(
AggregateFunction::Count,
vec![col("a")],
true,
Some(Box::new(col("a").gt(lit(5)))),
None,
));
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(vec![col("c")], vec![sum(col("a")), expr])?
.build()?;
// Do nothing
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5))]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)):Int64;N]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(&plan, expected)
}

#[test]
fn common_with_order_by() -> Result<()> {
let table_scan = test_table_scan()?;

// SUM(a ORDER BY a)
let expr = Expr::AggregateFunction(expr::AggregateFunction::new(
AggregateFunction::Sum,
vec![col("a")],
false,
None,
Some(vec![col("a")]),
));
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(vec![col("c")], vec![expr, count_distinct(col("b"))])?
.build()?;
// Do nothing
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a) ORDER BY [test.a], COUNT(DISTINCT test.b)]] [c:UInt32, SUM(test.a) ORDER BY [test.a]:UInt64;N, COUNT(DISTINCT test.b):Int64;N]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(&plan, expected)
}

#[test]
fn distinct_with_order_by() -> Result<()> {
let table_scan = test_table_scan()?;

// COUNT(DISTINCT a ORDER BY a)
let expr = Expr::AggregateFunction(expr::AggregateFunction::new(
AggregateFunction::Count,
vec![col("a")],
true,
None,
Some(vec![col("a")]),
));
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(vec![col("c")], vec![sum(col("a")), expr])?
.build()?;
// Do nothing
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) ORDER BY [test.a]]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) ORDER BY [test.a]:Int64;N]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(&plan, expected)
}

#[test]
fn aggregate_with_filter_and_order_by() -> Result<()> {
let table_scan = test_table_scan()?;

// COUNT(DISTINCT a ORDER BY a) FILTER (WHERE a > 5)
let expr = Expr::AggregateFunction(expr::AggregateFunction::new(
AggregateFunction::Count,
vec![col("a")],
true,
Some(Box::new(col("a").gt(lit(5)))),
Some(vec![col("a")]),
));
let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(vec![col("c")], vec![sum(col("a")), expr])?
.build()?;
// Do nothing
let expected = "Aggregate: groupBy=[[test.c]], aggr=[[SUM(test.a), COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]]] [c:UInt32, SUM(test.a):UInt64;N, COUNT(DISTINCT test.a) FILTER (WHERE test.a > Int32(5)) ORDER BY [test.a]:Int64;N]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(&plan, expected)
}
}
Loading