Skip to content

Commit

Permalink
fix(optimizer): fix DistinctAggRule to correctly handle queries with intersected arguments between distinct and non-distinct agg calls (#7688)
Browse files Browse the repository at this point in the history

Fixes #7680.

Approved-By: st1page
  • Loading branch information
stdrc authored Feb 7, 2023
1 parent 5ffba37 commit d785ab2
Show file tree
Hide file tree
Showing 6 changed files with 191 additions and 20 deletions.
43 changes: 29 additions & 14 deletions e2e_test/batch/aggregate/distinct.slt.part
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,40 @@ statement ok
create table t (v1 int, v2 int, v3 int);

statement ok
insert into t values (1, 2, 3), (4, 3, 2), (4, 2, 3), (1, 3, 2);
insert into t values (1,2,3), (1,2,4), (5,3,8), (2,4,4);

query I rowsort
select distinct v1 from t;
query I
select count(distinct v1) from t;
----
1
4
3

query I
select distinct sum(v1) from t group by v2;
query II rowsort
select v2, count(distinct v1) from t group by v2;
----
5
2 1
3 1
4 1

# v2, v3 can be either 2, 3 or 3, 2
query I
select distinct on(v1) v2 + v3 from t order by v1;
query III rowsort
select v2, count(distinct v1), max(v3) from t group by v2;
----
2 1 4
3 1 8
4 1 4

query IIII rowsort
select v1, count(distinct v2), count(distinct v3), max(v2) from t group by v1;
----
1 1 2 2
2 1 1 4
5 1 1 3

query IIIII rowsort
select v1, count(distinct v2), min(distinct v2), count(distinct v3), max(v3) from t group by v1;
----
5
5
1 1 2 2 4
2 1 4 1 4
5 1 3 1 8

statement ok
drop table t
drop table t;
29 changes: 29 additions & 0 deletions e2e_test/batch/basic/distinct.slt.part
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Batch tests for SELECT DISTINCT / DISTINCT ON (not distinct aggregates).
statement ok
SET RW_IMPLICIT_FLUSH TO true;

statement ok
create table t (v1 int, v2 int, v3 int);

statement ok
insert into t values (1, 2, 3), (4, 3, 2), (4, 2, 3), (1, 3, 2);

# Plain DISTINCT de-duplicates the projected column.
query I rowsort
select distinct v1 from t;
----
1
4

# DISTINCT applied on top of an aggregate result: both groups sum to 5,
# so only one row survives.
query I
select distinct sum(v1) from t group by v2;
----
5

# v2, v3 can be either 2, 3 or 3, 2
query I
select distinct on(v1) v2 + v3 from t order by v1;
----
5
5

statement ok
drop table t
74 changes: 74 additions & 0 deletions e2e_test/streaming/distinct_agg.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Streaming regression tests for DistinctAggRule: distinct aggregates,
# including the case where a distinct and a non-distinct agg call share
# an argument column (see the batch counterpart in
# e2e_test/batch/aggregate/distinct.slt.part).
statement ok
SET RW_IMPLICIT_FLUSH TO true;

statement ok
create table t (v1 int, v2 int, v3 int);

# Single distinct agg, no group key.
statement ok
create materialized view mv1 as select count(distinct v1) as c_d_v1 from t;

# Single distinct agg with a group key.
statement ok
create materialized view mv2 as select v2, count(distinct v1) as c_d_v1 from t group by v2;

# Distinct agg mixed with a non-distinct agg on a different column.
statement ok
create materialized view mv3 as select v2, count(distinct v1) as c_d_v1, max(v3) as max_v3 from t group by v2;

# Two distinct aggs plus a non-distinct agg (max(v2)) whose argument
# intersects a distinct agg's argument — the shape this commit fixes.
statement ok
create materialized view mv4 as select v1, count(distinct v2) as c_d_v2, count(distinct v3) as c_d_v3, max(v2) as max_v2 from t group by v1;

# Multiple distinct aggs over the same column (count/min on v2) together
# with a non-distinct agg.
statement ok
create materialized view mv5 as select v1, count(distinct v2) as c_d_v2, min(distinct v2) as min_d_v2, count(distinct v3) as c_d_v3, max(v3) as max_v3 from t group by v1;

# Note: (1,2,3) and (1,2,4) share v1=1 and v2=2, exercising actual
# de-duplication inside the distinct aggregates.
statement ok
insert into t values (1,2,3), (1,2,4), (5,3,8), (2,4,4);

query I
select * from mv1;
----
3

query II rowsort
select * from mv2;
----
2 1
3 1
4 1

query III rowsort
select * from mv3;
----
2 1 4
3 1 8
4 1 4

query IIII rowsort
select * from mv4;
----
1 1 2 2
2 1 1 4
5 1 1 3

query IIIII rowsort
select * from mv5;
----
1 1 2 2 4
2 1 4 1 4
5 1 3 1 8

statement ok
drop materialized view mv1;

statement ok
drop materialized view mv2;

statement ok
drop materialized view mv3;

statement ok
drop materialized view mv4;

statement ok
drop materialized view mv5;

statement ok
drop table t;
52 changes: 51 additions & 1 deletion src/frontend/planner_test/tests/testdata/agg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -705,14 +705,21 @@
└─LogicalAgg { group_key: [t.a, t.b, t.c, flag], aggs: [count filter((t.c < 100:Int32))] }
└─LogicalExpand { column_subsets: [[t.a, t.b], [t.a, t.c]] }
└─LogicalScan { table: t, columns: [t.a, t.b, t.c] }
- name: distinct agg and non-distinct agg
- name: single distinct agg and non-distinct agg
sql: |
create table t(a int, b int, c int);
select a, count(distinct b) as distinct_b_num, sum(c) as sum_c from t group by a;
optimized_logical_plan: |
LogicalAgg { group_key: [t.a], aggs: [count(t.b), sum(sum(t.c))] }
└─LogicalAgg { group_key: [t.a, t.b], aggs: [sum(t.c)] }
└─LogicalScan { table: t, columns: [t.a, t.b, t.c] }
batch_plan: |
BatchExchange { order: [], dist: Single }
└─BatchHashAgg { group_key: [t.a], aggs: [count(t.b), sum(sum(t.c))] }
└─BatchExchange { order: [], dist: HashShard(t.a) }
└─BatchHashAgg { group_key: [t.a, t.b], aggs: [sum(t.c)] }
└─BatchExchange { order: [], dist: HashShard(t.a, t.b) }
└─BatchScan { table: t, columns: [t.a, t.b, t.c], distribution: SomeShard }
stream_plan: |
StreamMaterialize { columns: [a, distinct_b_num, sum_c], pk_columns: [a] }
└─StreamProject { exprs: [t.a, count(t.b), sum(sum(t.c))] }
Expand All @@ -722,6 +729,33 @@
└─StreamHashAgg { group_key: [t.a, t.b], aggs: [count, sum(t.c)] }
└─StreamExchange { dist: HashShard(t.a, t.b) }
└─StreamTableScan { table: t, columns: [t.a, t.b, t.c, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
- name: distinct agg and non-distinct agg with intersected argument
sql: |
create table t(a int, b int, c int);
select a, count(distinct b) as distinct_b_num, count(distinct c) as distinct_c_sum, sum(c) as sum_c from t group by a;
optimized_logical_plan: |
LogicalAgg { group_key: [t.a], aggs: [count(t.b) filter((flag = 1:Int64)), count(t.c) filter((flag = 0:Int64)), sum(sum(t.c)) filter((flag = 0:Int64))] }
└─LogicalAgg { group_key: [t.a, t.b, t.c, flag], aggs: [sum(t.c)] }
└─LogicalExpand { column_subsets: [[t.a, t.c], [t.a, t.b]] }
└─LogicalScan { table: t, columns: [t.a, t.b, t.c] }
batch_plan: |
BatchExchange { order: [], dist: Single }
└─BatchHashAgg { group_key: [t.a], aggs: [count(t.b) filter((flag = 1:Int64)), count(t.c) filter((flag = 0:Int64)), sum(sum(t.c)) filter((flag = 0:Int64))] }
└─BatchExchange { order: [], dist: HashShard(t.a) }
└─BatchHashAgg { group_key: [t.a, t.b, t.c, flag], aggs: [sum(t.c)] }
└─BatchExchange { order: [], dist: HashShard(t.a, t.b, t.c, flag) }
└─BatchExpand { column_subsets: [[t.a, t.c], [t.a, t.b]] }
└─BatchScan { table: t, columns: [t.a, t.b, t.c], distribution: SomeShard }
stream_plan: |
StreamMaterialize { columns: [a, distinct_b_num, distinct_c_sum, sum_c], pk_columns: [a] }
└─StreamProject { exprs: [t.a, count(t.b) filter((flag = 1:Int64)), count(t.c) filter((flag = 0:Int64)), sum(sum(t.c)) filter((flag = 0:Int64))] }
└─StreamHashAgg { group_key: [t.a], aggs: [count, count(t.b) filter((flag = 1:Int64)), count(t.c) filter((flag = 0:Int64)), sum(sum(t.c)) filter((flag = 0:Int64))] }
└─StreamExchange { dist: HashShard(t.a) }
└─StreamProject { exprs: [t.a, t.b, t.c, flag, sum(t.c)] }
└─StreamHashAgg { group_key: [t.a, t.b, t.c, flag], aggs: [count, sum(t.c)] }
└─StreamExchange { dist: HashShard(t.a, t.b, t.c, flag) }
└─StreamExpand { column_subsets: [[t.a, t.c], [t.a, t.b]] }
└─StreamTableScan { table: t, columns: [t.a, t.b, t.c, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
- name: distinct agg with filter
sql: |
create table t(a int, b int, c int);
Expand All @@ -730,6 +764,22 @@
LogicalAgg { group_key: [t.a], aggs: [count(t.b) filter((count filter((t.b < 100:Int32)) > 0:Int64)), sum(sum(t.c))] }
└─LogicalAgg { group_key: [t.a, t.b], aggs: [count filter((t.b < 100:Int32)), sum(t.c)] }
└─LogicalScan { table: t, columns: [t.a, t.b, t.c] }
batch_plan: |
BatchExchange { order: [], dist: Single }
└─BatchHashAgg { group_key: [t.a], aggs: [count(t.b) filter((count filter((t.b < 100:Int32)) > 0:Int64)), sum(sum(t.c))] }
└─BatchExchange { order: [], dist: HashShard(t.a) }
└─BatchHashAgg { group_key: [t.a, t.b], aggs: [count filter((t.b < 100:Int32)), sum(t.c)] }
└─BatchExchange { order: [], dist: HashShard(t.a, t.b) }
└─BatchScan { table: t, columns: [t.a, t.b, t.c], distribution: SomeShard }
stream_plan: |
StreamMaterialize { columns: [a, count, sum], pk_columns: [a] }
└─StreamProject { exprs: [t.a, count(t.b) filter((count filter((t.b < 100:Int32)) > 0:Int64)), sum(sum(t.c))] }
└─StreamHashAgg { group_key: [t.a], aggs: [count, count(t.b) filter((count filter((t.b < 100:Int32)) > 0:Int64)), sum(sum(t.c))] }
└─StreamExchange { dist: HashShard(t.a) }
└─StreamProject { exprs: [t.a, t.b, count filter((t.b < 100:Int32)), sum(t.c)] }
└─StreamHashAgg { group_key: [t.a, t.b], aggs: [count, count filter((t.b < 100:Int32)), sum(t.c)] }
└─StreamExchange { dist: HashShard(t.a, t.b) }
└─StreamTableScan { table: t, columns: [t.a, t.b, t.c, t._row_id], pk: [t._row_id], dist: UpstreamHashShard(t._row_id) }
- name: non-distinct agg with filter
sql: |
create table t(a int, b int, c int);
Expand Down
3 changes: 1 addition & 2 deletions src/frontend/planner_test/tests/testdata/expr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -551,5 +551,4 @@
└─BatchProject { exprs: [((1:Int32 + 2:Int32) + t.v1)] }
└─BatchScan { table: t, columns: [t.v1], distribution: SomeShard }
- name: const_eval of division by 0 error
sql: |
select 1 / 0 t1;
sql: select 1 / 0 t1;
10 changes: 7 additions & 3 deletions src/frontend/src/optimizer/rule/distinct_agg_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ impl DistinctAggRule {
column_subsets.push(subset);
}

let mut num_of_subsets_for_distinct_agg = 0;
distinct_aggs.iter().for_each(|agg_call| {
let subset = {
let mut subset = FixedBitSet::from_iter(group_keys.iter().cloned());
Expand All @@ -106,14 +105,19 @@ impl DistinctAggRule {
flag_values.push(flag_value);
hash_map.insert(subset.clone(), flag_value);
column_subsets.push(subset);
num_of_subsets_for_distinct_agg += 1;
}
});

if num_of_subsets_for_distinct_agg <= 1 {
let n_different_distinct = distinct_aggs
.iter()
.unique_by(|agg_call| agg_call.input_indices())
.count();
assert_ne!(n_different_distinct, 0); // since `distinct_aggs` is not empty here
if n_different_distinct == 1 {
            // No need to have Expand if there is only one distinct aggregate.
return Some((input, flag_values, false));
}

let expand = LogicalExpand::create(input, column_subsets);
// manual version of column pruning for expand.
let project = Self::build_project(input_schema_len, expand, group_keys, agg_calls);
Expand Down

0 comments on commit d785ab2

Please sign in to comment.