Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support no distinct aggregate sum/min/max in single_distinct_to_group_by rule #8124

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 123 additions & 34 deletions datafusion/optimizer/src/single_distinct_to_groupby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use crate::{OptimizerConfig, OptimizerRule};

use datafusion_common::Result;
use datafusion_expr::{
aggregate_function::AggregateFunction::{Max, Min, Sum},
col,
expr::AggregateFunction,
logical_plan::{Aggregate, LogicalPlan},
Expand All @@ -34,17 +35,19 @@ use hashbrown::HashSet;

/// single distinct to group by optimizer rule
/// ```text
/// SELECT F1(DISTINCT s),F2(DISTINCT s)
/// ...
/// GROUP BY k
/// Before:
/// SELECT a, COUNT(DISTINCT b), SUM(c)
/// FROM t
/// GROUP BY a
///
/// Into
///
/// SELECT F1(alias1),F2(alias1)
/// After:
/// SELECT a, COUNT(alias1), SUM(alias2)
/// FROM (
/// SELECT s as alias1, k ... GROUP BY s, k
/// SELECT a, b as alias1, SUM(c) as alias2
/// FROM t
/// GROUP BY a, b
/// )
/// GROUP BY k
/// GROUP BY a
/// ```
#[derive(Default)]
pub struct SingleDistinctToGroupBy {}
Expand All @@ -58,27 +61,37 @@ impl SingleDistinctToGroupBy {
}
}

/// Check whether all aggregate exprs are distinct on a single field.
/// Check whether all distinct aggregate exprs are distinct on a single field.
fn is_single_distinct_agg(plan: &LogicalPlan) -> Result<bool> {
match plan {
LogicalPlan::Aggregate(Aggregate { aggr_expr, .. }) => {
let mut fields_set = HashSet::new();
let mut distinct_count = 0;
let mut aggregate_count = 0;
for expr in aggr_expr {
if let Expr::AggregateFunction(AggregateFunction {
distinct, args, ..
fun,
distinct,
args,
filter,
..
}) = expr
{
if *distinct {
distinct_count += 1;
}
for e in args {
fields_set.insert(e.canonical_name());
match filter {
Some(_) => return Ok(false),
Copy link
Contributor Author

@haohuaijin haohuaijin Nov 11, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before this PR, we also didn't support `filter` in the single_distinct_to_group_by rule, but we forgot to check for it.

None => {
aggregate_count += 1;
if *distinct {
for e in args {
fields_set.insert(e.canonical_name());
}
} else if !matches!(fun, Sum | Min | Max) {
return Ok(false);
}
}
}
}
}
let res = distinct_count == aggr_expr.len() && fields_set.len() == 1;
Ok(res)
Ok(fields_set.len() == 1 && aggregate_count == aggr_expr.len())
}
_ => Ok(false),
}
Expand Down Expand Up @@ -151,31 +164,60 @@ impl OptimizerRule for SingleDistinctToGroupBy {
.collect::<Vec<_>>();

// replace the distinct arg with alias
let mut index = 1;
let mut group_fields_set = HashSet::new();
let mut inner_aggr_exprs = vec![];
let outer_aggr_exprs = aggr_expr
.iter()
.map(|aggr_expr| match aggr_expr {
Expr::AggregateFunction(AggregateFunction {
fun,
args,
filter,
order_by,
distinct,
..
}) => {
// is_single_distinct_agg ensures args.len() == 1
if group_fields_set.insert(args[0].display_name()?) {
if *distinct
&& group_fields_set.insert(args[0].display_name()?)
{
inner_group_exprs.push(
args[0].clone().alias(SINGLE_DISTINCT_ALIAS),
);
}
Ok(Expr::AggregateFunction(AggregateFunction::new(
fun.clone(),
vec![col(SINGLE_DISTINCT_ALIAS)],
false, // intentional to remove distinct here
filter.clone(),
order_by.clone(),
))
.alias(aggr_expr.display_name()?))

// if the aggregate function is not distinct, we need to rewrite it like a two-phase aggregation
if !(*distinct) {
index += 1;
let alias_str = format!("alias{}", index);
let inner_expr =
Expr::AggregateFunction(AggregateFunction::new(
fun.clone(),
args.clone(),
false,
None,
order_by.clone(),
))
.alias(&alias_str);
inner_aggr_exprs.push(inner_expr);
Ok(Expr::AggregateFunction(AggregateFunction::new(
fun.clone(),
vec![col(&alias_str)],
false,
None,
order_by.clone(),
))
.alias(aggr_expr.display_name()?))
} else {
Ok(Expr::AggregateFunction(AggregateFunction::new(
fun.clone(),
vec![col(SINGLE_DISTINCT_ALIAS)],
false, // intentional to remove distinct here
None,
order_by.clone(),
))
.alias(aggr_expr.display_name()?))
}
}
_ => Ok(aggr_expr.clone()),
})
Expand All @@ -185,7 +227,7 @@ impl OptimizerRule for SingleDistinctToGroupBy {
let inner_agg = LogicalPlan::Aggregate(Aggregate::try_new(
input.clone(),
inner_group_exprs,
Vec::new(),
inner_aggr_exprs,
)?);

Ok(Some(LogicalPlan::Aggregate(Aggregate::try_new(
Expand Down Expand Up @@ -217,7 +259,7 @@ mod tests {
use datafusion_expr::expr;
use datafusion_expr::expr::GroupingSet;
use datafusion_expr::{
col, count, count_distinct, lit, logical_plan::builder::LogicalPlanBuilder, max,
col, count_distinct, lit, logical_plan::builder::LogicalPlanBuilder, max, sum,
AggregateFunction,
};

Expand Down Expand Up @@ -396,20 +438,67 @@ mod tests {
assert_optimized_plan_equal(&plan, expected)
}

#[test]
fn two_distinct_and_one_common() -> Result<()> {
let table_scan = test_table_scan()?;

let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(
vec![col("a")],
vec![
sum(col("c")),
count_distinct(col("b")),
Expr::AggregateFunction(expr::AggregateFunction::new(
AggregateFunction::Max,
vec![col("b")],
true,
None,
None,
)),
],
)?
.build()?;
// Should work
let expected = "Aggregate: groupBy=[[test.a]], aggr=[[SUM(alias2) AS SUM(test.c), COUNT(alias1) AS COUNT(DISTINCT test.b), MAX(alias1) AS MAX(DISTINCT test.b)]] [a:UInt32, SUM(test.c):UInt64;N, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\
\n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[SUM(test.c) AS alias2]] [a:UInt32, alias1:UInt32, alias2:UInt64;N]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(&plan, expected)
}

#[test]
fn one_distinctand_two() -> Result<()> {
let table_scan = test_table_scan()?;

let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(
vec![col("a")],
vec![sum(col("c")), max(col("c")), count_distinct(col("b"))],
)?
.build()?;
// Should work
let expected = "Aggregate: groupBy=[[test.a]], aggr=[[SUM(alias2) AS SUM(test.c), MAX(alias3) AS MAX(test.c), COUNT(alias1) AS COUNT(DISTINCT test.b)]] [a:UInt32, SUM(test.c):UInt64;N, MAX(test.c):UInt32;N, COUNT(DISTINCT test.b):Int64;N]\
\n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[SUM(test.c) AS alias2, MAX(test.c) AS alias3]] [a:UInt32, alias1:UInt32, alias2:UInt64;N, alias3:UInt32;N]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(&plan, expected)
}

#[test]
fn distinct_and_common() -> Result<()> {
let table_scan = test_table_scan()?;

let plan = LogicalPlanBuilder::from(table_scan)
.aggregate(
vec![col("a")],
vec![count_distinct(col("b")), count(col("c"))],
vec![count_distinct(col("b")), sum(col("c"))],
)?
.build()?;

// Do nothing
let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(DISTINCT test.b), COUNT(test.c)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N, COUNT(test.c):Int64;N]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
// Should work
let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(alias1) AS COUNT(DISTINCT test.b), SUM(alias2) AS SUM(test.c)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N, SUM(test.c):UInt64;N]\
\n Aggregate: groupBy=[[test.a, test.b AS alias1]], aggr=[[SUM(test.c) AS alias2]] [a:UInt32, alias1:UInt32, alias2:UInt64;N]\
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";

assert_optimized_plan_equal(&plan, expected)
}
Expand Down
51 changes: 51 additions & 0 deletions datafusion/sqllogictest/test_files/groupby.slt
Original file line number Diff line number Diff line change
Expand Up @@ -3841,3 +3841,54 @@ ProjectionExec: expr=[SUM(DISTINCT t1.x)@1 as SUM(DISTINCT t1.x), MAX(DISTINCT t
------------------AggregateExec: mode=Partial, gby=[y@1 as y, CAST(t1.x AS Float64)t1.x@0 as alias1], aggr=[]
--------------------ProjectionExec: expr=[CAST(x@0 AS Float64) as CAST(t1.x AS Float64)t1.x, y@1 as y]
----------------------MemoryExec: partitions=1, partition_sizes=[1]

statement ok
CREATE EXTERNAL TABLE aggregate_test_100 (
c1 VARCHAR NOT NULL,
c2 TINYINT NOT NULL,
c3 SMALLINT NOT NULL,
c4 SMALLINT,
c5 INT,
c6 BIGINT NOT NULL,
c7 SMALLINT NOT NULL,
c8 INT NOT NULL,
c9 INT UNSIGNED NOT NULL,
c10 BIGINT UNSIGNED NOT NULL,
c11 FLOAT NOT NULL,
c12 DOUBLE NOT NULL,
c13 VARCHAR NOT NULL
)
STORED AS CSV
WITH HEADER ROW
LOCATION '../../testing/data/csv/aggregate_test_100.csv'

query TIIII
SELECT c1, count(distinct c2), min(distinct c2), min(c3), max(c4) FROM aggregate_test_100 GROUP BY c1 ORDER BY c1;
----
a 5 1 -101 32064
b 5 1 -117 25286
c 5 1 -117 29106
d 5 1 -99 31106
e 5 1 -95 32514

query TT
EXPLAIN SELECT c1, count(distinct c2), min(distinct c2), sum(c3), max(c4) FROM aggregate_test_100 GROUP BY c1 ORDER BY c1;
----
logical_plan
Sort: aggregate_test_100.c1 ASC NULLS LAST
--Aggregate: groupBy=[[aggregate_test_100.c1]], aggr=[[COUNT(alias1) AS COUNT(DISTINCT aggregate_test_100.c2), MIN(alias1) AS MIN(DISTINCT aggregate_test_100.c2), SUM(alias2) AS SUM(aggregate_test_100.c3), MAX(alias3) AS MAX(aggregate_test_100.c4)]]
----Aggregate: groupBy=[[aggregate_test_100.c1, aggregate_test_100.c2 AS alias1]], aggr=[[SUM(CAST(aggregate_test_100.c3 AS Int64)) AS alias2, MAX(aggregate_test_100.c4) AS alias3]]
------TableScan: aggregate_test_100 projection=[c1, c2, c3, c4]
physical_plan
SortPreservingMergeExec: [c1@0 ASC NULLS LAST]
--SortExec: expr=[c1@0 ASC NULLS LAST]
----AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[COUNT(DISTINCT aggregate_test_100.c2), MIN(DISTINCT aggregate_test_100.c2), SUM(aggregate_test_100.c3), MAX(aggregate_test_100.c4)]
------CoalesceBatchesExec: target_batch_size=2
--------RepartitionExec: partitioning=Hash([c1@0], 8), input_partitions=8
----------AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[COUNT(DISTINCT aggregate_test_100.c2), MIN(DISTINCT aggregate_test_100.c2), SUM(aggregate_test_100.c3), MAX(aggregate_test_100.c4)]
------------AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1, alias1@1 as alias1], aggr=[alias2, alias3]
--------------CoalesceBatchesExec: target_batch_size=2
----------------RepartitionExec: partitioning=Hash([c1@0, alias1@1], 8), input_partitions=8
------------------AggregateExec: mode=Partial, gby=[c1@0 as c1, c2@1 as alias1], aggr=[alias2, alias3]
--------------------RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1
----------------------CsvExec: file_groups={1 group: [[WORKSPACE_ROOT/testing/data/csv/aggregate_test_100.csv]]}, projection=[c1, c2, c3, c4], has_header=true
29 changes: 28 additions & 1 deletion datafusion/substrait/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ async fn simple_aggregate() -> Result<()> {

#[tokio::test]
async fn aggregate_distinct_with_having() -> Result<()> {
roundtrip("SELECT a, count(distinct b) FROM data GROUP BY a, c HAVING count(b) > 100")
roundtrip("SELECT a, count(distinct b), sum(distinct e) FROM data GROUP BY a, c HAVING count(b) > 100")
.await
}

Expand Down Expand Up @@ -267,6 +267,33 @@ async fn select_distinct_two_fields() -> Result<()> {
.await
}

#[tokio::test]
async fn simple_distinct_aggregate() -> Result<()> {
test_alias(
"SELECT a, COUNT(DISTINCT b) FROM data GROUP BY a",
"SELECT a, COUNT(b) FROM (SELECT a, b FROM data GROUP BY a, b) GROUP BY a",
)
.await
}

#[tokio::test]
async fn select_distinct_aggregate_two_fields() -> Result<()> {
test_alias(
"SELECT a, COUNT(DISTINCT b), MAX(DISTINCT b) FROM data GROUP BY a",
"SELECT a, COUNT(b), MAX(b) FROM (SELECT a, b FROM data GROUP BY a, b) GROUP BY a",
)
.await
}

#[tokio::test]
async fn select_distinct_aggregate_and_no_distinct_aggregate() -> Result<()> {
test_alias(
"SELECT a, COUNT(DISTINCT b), SUM(e) FROM data GROUP by a",
"SELECT a, COUNT(b), SUM(\"SUM(data.e)\") FROM (SELECT a, b, SUM(e) FROM data GROUP BY a, b) GROUP BY a",
)
.await
}

#[tokio::test]
async fn simple_alias() -> Result<()> {
test_alias("SELECT d1.a, d1.b FROM data d1", "SELECT a, b FROM data").await
Expand Down