Skip to content

Commit

Permalink
fix: fix order by contains count in scalar subquery (databendlabs#13782)
Browse files Browse the repository at this point in the history
* fix: fix order by contains count in scalar subquery

* Update src/query/sql/src/planner/optimizer/heuristic/decorrelate/flatten_plan.rs
  • Loading branch information
xudong963 authored and andylokandy committed Nov 27, 2023
1 parent f7fec7a commit e7cd5c3
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ use crate::plans::RelOperator;
use crate::plans::ScalarExpr;
use crate::plans::ScalarItem;
use crate::plans::Scan;
use crate::plans::Sort;
use crate::plans::UnionAll;
use crate::BaseTableColumn;
use crate::ColumnEntry;
Expand Down Expand Up @@ -165,9 +166,13 @@ impl SubqueryRewriter {
flatten_info,
need_cross_join,
),
RelOperator::Sort(_) => {
self.flatten_sort(plan, correlated_columns, flatten_info, need_cross_join)
}
RelOperator::Sort(sort) => self.flatten_sort(
plan,
sort,
correlated_columns,
flatten_info,
need_cross_join,
),

RelOperator::Limit(_) => {
self.flatten_limit(plan, correlated_columns, flatten_info, need_cross_join)
Expand Down Expand Up @@ -487,6 +492,7 @@ impl SubqueryRewriter {
fn flatten_sort(
&mut self,
plan: &SExpr,
sort: &Sort,
correlated_columns: &ColumnSet,
flatten_info: &mut FlattenInfo,
need_cross_join: bool,
Expand All @@ -498,6 +504,19 @@ impl SubqueryRewriter {
flatten_info,
need_cross_join,
)?;
// Check if sort contains `count() or distinct count()`.
if sort.items.iter().any(|item| {
let metadata = self.metadata.read();
let col = metadata.column(item.index);
if let ColumnEntry::DerivedColumn(derived_col) = col {
// A little tricky here, we'll check if a sort item is a count aggregation function later.
derived_col.alias.to_lowercase().starts_with("count")
} else {
false
}
}) {
flatten_info.from_count_func = false;
}
Ok(SExpr::create_unary(
Arc::new(plan.plan().clone()),
Arc::new(flatten_plan),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,8 @@ impl SubqueryRewriter {
.build(),
});

let scalar = if flatten_info.from_count_func {
let scalar = if flatten_info.from_count_func && subquery.typ == SubqueryType::Scalar
{
// convert count aggregate function to `if(count() is not null, count(), 0)`
let is_not_null = ScalarExpr::FunctionCall(FunctionCall {
span: subquery.span,
Expand Down
123 changes: 123 additions & 0 deletions tests/sqllogictests/suites/mode/standalone/explain/explain.test
Original file line number Diff line number Diff line change
Expand Up @@ -1591,3 +1591,126 @@ drop table t1;

statement ok
drop table t2;

statement ok
create table t1(a int, b int, c varchar(20));

statement ok
create table t2(a int, b int, c varchar(20));

# scalar subquery and sort plan contains count() agg function.
query T
explain select * from t2 where c > (select c from t1 where t1.a = t2.a group by c order by count(a));
----
Filter
├── output columns: [t2.a (#0), t2.b (#1), t2.c (#2)]
├── filters: [is_true(t2.c (#2) > scalar_subquery_5 (#5))]
├── estimated rows: 0.00
└── HashJoin
├── output columns: [t2.a (#0), t2.b (#1), t2.c (#2), t1.c (#5)]
├── join type: LEFT SINGLE
├── build keys: [a (#3)]
├── probe keys: [a (#0)]
├── filters: []
├── estimated rows: 0.00
├── Sort(Build)
│ ├── output columns: [t1.c (#5), t1.a (#3), count(a) (#7)]
│ ├── sort keys: [count(a) ASC NULLS LAST]
│ ├── estimated rows: 0.00
│ └── EvalScalar
│ ├── output columns: [t1.c (#5), t1.a (#3), count(a) (#7)]
│ ├── expressions: [count(a) (#6)]
│ ├── estimated rows: 0.00
│ └── AggregateFinal
│ ├── output columns: [count(a) (#6), t1.c (#5), t1.a (#3)]
│ ├── group by: [c, a]
│ ├── aggregate functions: [count(a)]
│ ├── estimated rows: 0.00
│ └── AggregatePartial
│ ├── output columns: [count(a) (#6), #_group_by_key]
│ ├── group by: [c, a]
│ ├── aggregate functions: [count(a)]
│ ├── estimated rows: 0.00
│ └── Filter
│ ├── output columns: [t1.a (#3), t1.c (#5)]
│ ├── filters: [is_true(t1.a (#3) = a (#3))]
│ ├── estimated rows: 0.00
│ └── TableScan
│ ├── table: default.default.t1
│ ├── output columns: [a (#3), c (#5)]
│ ├── read rows: 0
│ ├── read bytes: 0
│ ├── partitions total: 0
│ ├── partitions scanned: 0
│ ├── push downs: [filters: [is_true(t1.a (#3) = a (#3))], limit: NONE]
│ └── estimated rows: 0.00
└── TableScan(Probe)
├── table: default.default.t2
├── output columns: [a (#0), b (#1), c (#2)]
├── read rows: 0
├── read bytes: 0
├── partitions total: 0
├── partitions scanned: 0
├── push downs: [filters: [], limit: NONE]
└── estimated rows: 0.00

query T
explain select * from t2 where c > (select c from t1 where t1.a = t2.a group by c order by count(*));
----
Filter
├── output columns: [t2.a (#0), t2.b (#1), t2.c (#2)]
├── filters: [is_true(t2.c (#2) > scalar_subquery_5 (#5))]
├── estimated rows: 0.00
└── HashJoin
├── output columns: [t2.a (#0), t2.b (#1), t2.c (#2), t1.c (#5)]
├── join type: LEFT SINGLE
├── build keys: [a (#3)]
├── probe keys: [a (#0)]
├── filters: []
├── estimated rows: 0.00
├── Sort(Build)
│ ├── output columns: [t1.c (#5), t1.a (#3), COUNT(*) (#7)]
│ ├── sort keys: [COUNT(*) ASC NULLS LAST]
│ ├── estimated rows: 0.00
│ └── EvalScalar
│ ├── output columns: [t1.c (#5), t1.a (#3), COUNT(*) (#7)]
│ ├── expressions: [COUNT(*) (#6)]
│ ├── estimated rows: 0.00
│ └── AggregateFinal
│ ├── output columns: [COUNT(*) (#6), t1.c (#5), t1.a (#3)]
│ ├── group by: [c, a]
│ ├── aggregate functions: [count()]
│ ├── estimated rows: 0.00
│ └── AggregatePartial
│ ├── output columns: [COUNT(*) (#6), #_group_by_key]
│ ├── group by: [c, a]
│ ├── aggregate functions: [count()]
│ ├── estimated rows: 0.00
│ └── Filter
│ ├── output columns: [t1.a (#3), t1.c (#5)]
│ ├── filters: [is_true(t1.a (#3) = a (#3))]
│ ├── estimated rows: 0.00
│ └── TableScan
│ ├── table: default.default.t1
│ ├── output columns: [a (#3), c (#5)]
│ ├── read rows: 0
│ ├── read bytes: 0
│ ├── partitions total: 0
│ ├── partitions scanned: 0
│ ├── push downs: [filters: [is_true(t1.a (#3) = a (#3))], limit: NONE]
│ └── estimated rows: 0.00
└── TableScan(Probe)
├── table: default.default.t2
├── output columns: [a (#0), b (#1), c (#2)]
├── read rows: 0
├── read bytes: 0
├── partitions total: 0
├── partitions scanned: 0
├── push downs: [filters: [], limit: NONE]
└── estimated rows: 0.00

statement ok
drop table t1;

statement ok
drop table t2;

0 comments on commit e7cd5c3

Please sign in to comment.