Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix order by contains count in scalar subquery #13782

Merged
merged 2 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ use crate::plans::RelOperator;
use crate::plans::ScalarExpr;
use crate::plans::ScalarItem;
use crate::plans::Scan;
use crate::plans::Sort;
use crate::plans::UnionAll;
use crate::BaseTableColumn;
use crate::ColumnEntry;
Expand Down Expand Up @@ -165,9 +166,13 @@ impl SubqueryRewriter {
flatten_info,
need_cross_join,
),
RelOperator::Sort(_) => {
self.flatten_sort(plan, correlated_columns, flatten_info, need_cross_join)
}
RelOperator::Sort(sort) => self.flatten_sort(
plan,
sort,
correlated_columns,
flatten_info,
need_cross_join,
),

RelOperator::Limit(_) => {
self.flatten_limit(plan, correlated_columns, flatten_info, need_cross_join)
Expand Down Expand Up @@ -487,6 +492,7 @@ impl SubqueryRewriter {
fn flatten_sort(
&mut self,
plan: &SExpr,
sort: &Sort,
correlated_columns: &ColumnSet,
flatten_info: &mut FlattenInfo,
need_cross_join: bool,
Expand All @@ -498,6 +504,19 @@ impl SubqueryRewriter {
flatten_info,
need_cross_join,
)?;
// Check whether the sort items contain `count()` or distinct `count()`.
if sort.items.iter().any(|item| {
let metadata = self.metadata.read();
let col = metadata.column(item.index);
if let ColumnEntry::DerivedColumn(derived_col) = col {
// A little tricky here: we'll check whether the sort item is a count aggregation function later.
xudong963 marked this conversation as resolved.
Show resolved Hide resolved
xudong963 marked this conversation as resolved.
Show resolved Hide resolved
derived_col.alias.to_lowercase().starts_with("count")
} else {
false
}
}) {
flatten_info.from_count_func = false;
}
Ok(SExpr::create_unary(
Arc::new(plan.plan().clone()),
Arc::new(flatten_plan),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,8 @@ impl SubqueryRewriter {
.build(),
});

let scalar = if flatten_info.from_count_func {
let scalar = if flatten_info.from_count_func && subquery.typ == SubqueryType::Scalar
{
// convert count aggregate function to `if(count() is not null, count(), 0)`
let is_not_null = ScalarExpr::FunctionCall(FunctionCall {
span: subquery.span,
Expand Down
123 changes: 123 additions & 0 deletions tests/sqllogictests/suites/mode/standalone/explain/explain.test
Original file line number Diff line number Diff line change
Expand Up @@ -1591,3 +1591,126 @@ drop table t1;

statement ok
drop table t2;

statement ok
create table t1(a int, b int, c varchar(20));

statement ok
create table t2(a int, b int, c varchar(20));

# Scalar subquery whose sort plan contains a count() aggregate function.
query T
explain select * from t2 where c > (select c from t1 where t1.a = t2.a group by c order by count(a));
----
Filter
├── output columns: [t2.a (#0), t2.b (#1), t2.c (#2)]
├── filters: [is_true(t2.c (#2) > scalar_subquery_5 (#5))]
├── estimated rows: 0.00
└── HashJoin
├── output columns: [t2.a (#0), t2.b (#1), t2.c (#2), t1.c (#5)]
├── join type: LEFT SINGLE
├── build keys: [a (#3)]
├── probe keys: [a (#0)]
├── filters: []
├── estimated rows: 0.00
├── Sort(Build)
│ ├── output columns: [t1.c (#5), t1.a (#3), count(a) (#7)]
│ ├── sort keys: [count(a) ASC NULLS LAST]
│ ├── estimated rows: 0.00
│ └── EvalScalar
│ ├── output columns: [t1.c (#5), t1.a (#3), count(a) (#7)]
│ ├── expressions: [count(a) (#6)]
│ ├── estimated rows: 0.00
│ └── AggregateFinal
│ ├── output columns: [count(a) (#6), t1.c (#5), t1.a (#3)]
│ ├── group by: [c, a]
│ ├── aggregate functions: [count(a)]
│ ├── estimated rows: 0.00
│ └── AggregatePartial
│ ├── output columns: [count(a) (#6), #_group_by_key]
│ ├── group by: [c, a]
│ ├── aggregate functions: [count(a)]
│ ├── estimated rows: 0.00
│ └── Filter
│ ├── output columns: [t1.a (#3), t1.c (#5)]
│ ├── filters: [is_true(t1.a (#3) = a (#3))]
│ ├── estimated rows: 0.00
│ └── TableScan
│ ├── table: default.default.t1
│ ├── output columns: [a (#3), c (#5)]
│ ├── read rows: 0
│ ├── read bytes: 0
│ ├── partitions total: 0
│ ├── partitions scanned: 0
│ ├── push downs: [filters: [is_true(t1.a (#3) = a (#3))], limit: NONE]
│ └── estimated rows: 0.00
└── TableScan(Probe)
├── table: default.default.t2
├── output columns: [a (#0), b (#1), c (#2)]
├── read rows: 0
├── read bytes: 0
├── partitions total: 0
├── partitions scanned: 0
├── push downs: [filters: [], limit: NONE]
└── estimated rows: 0.00

query T
explain select * from t2 where c > (select c from t1 where t1.a = t2.a group by c order by count(*));
----
Filter
├── output columns: [t2.a (#0), t2.b (#1), t2.c (#2)]
├── filters: [is_true(t2.c (#2) > scalar_subquery_5 (#5))]
├── estimated rows: 0.00
└── HashJoin
├── output columns: [t2.a (#0), t2.b (#1), t2.c (#2), t1.c (#5)]
├── join type: LEFT SINGLE
├── build keys: [a (#3)]
├── probe keys: [a (#0)]
├── filters: []
├── estimated rows: 0.00
├── Sort(Build)
│ ├── output columns: [t1.c (#5), t1.a (#3), COUNT(*) (#7)]
│ ├── sort keys: [COUNT(*) ASC NULLS LAST]
│ ├── estimated rows: 0.00
│ └── EvalScalar
│ ├── output columns: [t1.c (#5), t1.a (#3), COUNT(*) (#7)]
│ ├── expressions: [COUNT(*) (#6)]
│ ├── estimated rows: 0.00
│ └── AggregateFinal
│ ├── output columns: [COUNT(*) (#6), t1.c (#5), t1.a (#3)]
│ ├── group by: [c, a]
│ ├── aggregate functions: [count()]
│ ├── estimated rows: 0.00
│ └── AggregatePartial
│ ├── output columns: [COUNT(*) (#6), #_group_by_key]
│ ├── group by: [c, a]
│ ├── aggregate functions: [count()]
│ ├── estimated rows: 0.00
│ └── Filter
│ ├── output columns: [t1.a (#3), t1.c (#5)]
│ ├── filters: [is_true(t1.a (#3) = a (#3))]
│ ├── estimated rows: 0.00
│ └── TableScan
│ ├── table: default.default.t1
│ ├── output columns: [a (#3), c (#5)]
│ ├── read rows: 0
│ ├── read bytes: 0
│ ├── partitions total: 0
│ ├── partitions scanned: 0
│ ├── push downs: [filters: [is_true(t1.a (#3) = a (#3))], limit: NONE]
│ └── estimated rows: 0.00
└── TableScan(Probe)
├── table: default.default.t2
├── output columns: [a (#0), b (#1), c (#2)]
├── read rows: 0
├── read bytes: 0
├── partitions total: 0
├── partitions scanned: 0
├── push downs: [filters: [], limit: NONE]
└── estimated rows: 0.00

statement ok
drop table t1;

statement ok
drop table t2;
Loading