Skip to content

Commit

Permalink
fix(binder): ExprVisitor and ExprMutator should visit agg_call order_…
Browse files Browse the repository at this point in the history
…by and filter (#5664)

* fix(binder): ExprVisitor and ExprMutator should visit agg_call order_by and filter

* remove redundant agg-table-function check

* more planner tests

* test case in original issue

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
xiangjinwu and mergify[bot] authored Sep 30, 2022
1 parent e6a0ac7 commit 1178dbe
Show file tree
Hide file tree
Showing 11 changed files with 126 additions and 33 deletions.
8 changes: 8 additions & 0 deletions e2e_test/batch/aggregate/string_agg.slt.part
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,13 @@ select string_agg(v1, ',' order by v3) filter (where v2 > 0) from t
----
aaa,ddd

query T
SELECT (SELECT STRING_AGG(v1, ',' ORDER BY strings.v1, v1 desc) FROM t) FROM t AS strings;
----
ddd,ccc,bbb,aaa
ddd,ccc,bbb,aaa
ddd,ccc,bbb,aaa
ddd,ccc,bbb,aaa

statement ok
drop table t
47 changes: 42 additions & 5 deletions src/frontend/planner_test/tests/testdata/agg.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -749,14 +749,25 @@
planner_error: |-
Feature is not yet implemented: subquery inside aggregation calls order by
No tracking issue yet. Feel free to submit a feature request at https://github.com/risingwavelabs/risingwave/issues/new?labels=type%2Ffeature&template=feature_request.yml
- name: agg order by - agg
- name: agg order by - agg (correlated in having)
sql: |
create table a (a1 int, a2 int);
create table sb (b1 varchar, b2 varchar);
select 1 from a having exists(
select string_agg(b1, '' order by min(a1)) from sb -- valid in PostgreSQL
-- select string_agg('', '' order by min(a1)) from sb -- NOT valid in PostgreSQL
);
planner_error: |-
Feature is not yet implemented: correlated subquery in HAVING or SELECT with agg
Tracking issue: https://github.com/risingwavelabs/risingwave/issues/2275
- name: agg order by - agg (correlated in where)
sql: |
/* This case is NOT valid in PostgreSQL */
create table a (a1 int, a2 int);
create table sb (b1 varchar, b2 varchar);
select 1 from a where exists(
select string_agg(b1, '' order by min(a1)) from sb
);
planner_error: |-
Feature is not yet implemented: aggregate function inside aggregation calls order by
No tracking issue yet. Feel free to submit a feature request at https://github.com/risingwavelabs/risingwave/issues/new?labels=type%2Ffeature&template=feature_request.yml
Expand Down Expand Up @@ -793,8 +804,8 @@
create table a (a1 int, a2 int);
select count(a1 + unnest(array[1])) from a;
planner_error: |-
Feature is not yet implemented: Table functions in agg call or group by is not supported yet
Tracking issue: https://github.com/risingwavelabs/risingwave/issues/3814
Feature is not yet implemented: table function inside aggregation calls
No tracking issue yet. Feel free to submit a feature request at https://github.com/risingwavelabs/risingwave/issues/new?labels=type%2Ffeature&template=feature_request.yml
- name: group by - subquery
sql: |
/* This case is valid in PostgreSQL */
Expand All @@ -820,5 +831,31 @@
create table a (a1 int, a2 int);
select count(a1) from a group by unnest(array[1]);
planner_error: |-
Feature is not yet implemented: Table functions in agg call or group by is not supported yet
Tracking issue: https://github.com/risingwavelabs/risingwave/issues/3814
Feature is not yet implemented: table function inside GROUP BY
No tracking issue yet. Feel free to submit a feature request at https://github.com/risingwavelabs/risingwave/issues/new?labels=type%2Ffeature&template=feature_request.yml
- name: post-agg project set - ok
sql: |
create table t (v1 int, v2 int);
select min(v1), unnest(array[2, max(v2)]) from t;
logical_plan: |
LogicalProject { exprs: [min(t.v1), Unnest(Array(2:Int32, $1))] }
└─LogicalProjectSet { select_list: [$0, Unnest(Array(2:Int32, $1))] }
└─LogicalAgg { aggs: [min(t.v1), max(t.v2)] }
└─LogicalProject { exprs: [t.v1, t.v2] }
└─LogicalScan { table: t, columns: [t.v1, t.v2, t._row_id] }
- name: post-agg project set - error
sql: |
create table t (v1 int, v2 int);
select min(v1), unnest(array[2, v2]) from t;
planner_error: 'Invalid input syntax: column must appear in the GROUP BY clause
or be used in an aggregate function'
- name: post-agg project set - grouped
sql: |
create table t (v1 int, v2 int);
select min(v1), unnest(array[2, v2]) from t group by v2;
logical_plan: |
LogicalProject { exprs: [min(t.v1), Unnest(Array(2:Int32, $0))] }
└─LogicalProjectSet { select_list: [$1, Unnest(Array(2:Int32, $0))] }
└─LogicalAgg { group_key: [t.v2], aggs: [min(t.v1)] }
└─LogicalProject { exprs: [t.v2, t.v1] }
└─LogicalScan { table: t, columns: [t.v1, t.v2, t._row_id] }
16 changes: 8 additions & 8 deletions src/frontend/planner_test/tests/testdata/array.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -72,13 +72,14 @@
└─BatchValues { rows: [[]] }
- sql: |
select array_cat(array[233], array[array[array[66]]]);
binder_error: "Bind error: unable to find least restrictive type between integer[] and integer[][][]"
binder_error: 'Bind error: unable to find least restrictive type between integer[]
and integer[][][]'
- sql: |
select array_cat(array[233], 123);
binder_error: "Bind error: Cannot concatenate integer[] and integer"
binder_error: 'Bind error: Cannot concatenate integer[] and integer'
- sql: |
select array_cat(123, array[233]);
binder_error: "Bind error: Cannot concatenate integer and integer[]"
binder_error: 'Bind error: Cannot concatenate integer and integer[]'
- sql: |
select array_append(array[66], 123);
logical_plan: |
Expand All @@ -89,7 +90,7 @@
└─BatchValues { rows: [[]] }
- sql: |
select array_append(123, 234);
binder_error: "Bind error: Cannot append integer to integer"
binder_error: 'Bind error: Cannot append integer to integer'
- sql: |
/* Combining multidimensional arrays as such is supported beyond what PostgresSQL allows */
select array_append(array[array[66]], array[233]);
Expand All @@ -106,7 +107,7 @@
└─BatchValues { rows: [[]] }
- sql: |
select array_prepend(123, 234);
binder_error: "Bind error: Cannot prepend integer to integer"
binder_error: 'Bind error: Cannot prepend integer to integer'
- sql: |
select array_prepend(array[233], array[array[66]]);
logical_plan: |
Expand All @@ -115,9 +116,8 @@
- name: string from/to varchar[] in implicit context
sql: |
values (array['a', 'b']), ('{c,' || 'd}');
binder_error:
"Bind error: types List { datatype: Varchar } and Varchar cannot be
matched"
binder_error: 'Bind error: types List { datatype: Varchar } and Varchar cannot be
matched'
- name: string to varchar[] in assign context
sql: |
create table t (v1 varchar[]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -684,3 +684,19 @@
LogicalJoin { type: LeftOuter, on: (t2._row_id = t1._row_id), output: [t1.x, t2.y] }
├─LogicalScan { table: t1, columns: [t1.x, t1._row_id] }
└─LogicalScan { table: t2, columns: [t2.y, t2._row_id] }
- name: issue 4762 correlated input in agg order by
sql: |
CREATE TABLE strings(v1 VARCHAR);
SELECT (SELECT STRING_AGG(v1, ',' ORDER BY t.v1) FROM strings) FROM strings AS t;
optimized_logical_plan: |
LogicalJoin { type: LeftOuter, on: IsNotDistinctFrom(strings.v1, strings.v1), output: [string_agg(strings.v1, ',':Varchar order_by(strings.v1 ASC NULLS LAST))] }
├─LogicalScan { table: strings, columns: [strings.v1] }
└─LogicalAgg { group_key: [strings.v1], aggs: [string_agg(strings.v1, ',':Varchar order_by(strings.v1 ASC NULLS LAST))] }
└─LogicalJoin { type: LeftOuter, on: IsNotDistinctFrom(strings.v1, strings.v1), output: [strings.v1, strings.v1, ',':Varchar, strings.v1] }
├─LogicalAgg { group_key: [strings.v1], aggs: [] }
| └─LogicalScan { table: strings, columns: [strings.v1] }
└─LogicalProject { exprs: [strings.v1, strings.v1, ',':Varchar, strings.v1] }
└─LogicalJoin { type: Inner, on: true, output: all }
├─LogicalAgg { group_key: [strings.v1], aggs: [] }
| └─LogicalScan { table: strings, columns: [strings.v1] }
└─LogicalScan { table: strings, columns: [strings.v1] }
16 changes: 16 additions & 0 deletions src/frontend/src/expr/agg_call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,22 @@ impl AggCall {
/// Returns the input expressions of this aggregate call as a mutable slice,
/// so that callers can modify the expressions in place.
pub fn inputs_mut(&mut self) -> &mut [ExprImpl] {
self.inputs.as_mut()
}

/// Returns a shared reference to the [`OrderBy`] stored on this aggregate call.
pub fn order_by(&self) -> &OrderBy {
&self.order_by
}

/// Returns a mutable reference to the [`OrderBy`] stored on this aggregate call,
/// allowing its sort expressions to be modified in place.
pub fn order_by_mut(&mut self) -> &mut OrderBy {
&mut self.order_by
}

/// Returns a shared reference to the filter [`Condition`] stored on this aggregate call.
pub fn filter(&self) -> &Condition {
&self.filter
}

/// Returns a mutable reference to the filter [`Condition`] stored on this aggregate call,
/// allowing its expressions to be modified in place.
pub fn filter_mut(&mut self) -> &mut Condition {
&mut self.filter
}
}

impl Expr for AggCall {
Expand Down
4 changes: 3 additions & 1 deletion src/frontend/src/expr/expr_mutator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ pub trait ExprMutator {
agg_call
.inputs_mut()
.iter_mut()
.for_each(|expr| self.visit_expr(expr))
.for_each(|expr| self.visit_expr(expr));
agg_call.order_by_mut().visit_expr_mut(self);
agg_call.filter_mut().visit_expr_mut(self);
}
fn visit_literal(&mut self, _: &mut Literal) {}
fn visit_input_ref(&mut self, _: &mut InputRef) {}
Expand Down
7 changes: 5 additions & 2 deletions src/frontend/src/expr/expr_visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,15 @@ pub trait ExprVisitor<R: Default> {
.unwrap_or_default()
}
/// Visits every expression reachable from an aggregate call and merges the results.
///
/// The call's inputs are visited first (yielding `R::default()` when there are
/// none), then the expressions of its `order_by`, then those of its `filter`.
/// The three partial results are combined left to right with `Self::merge`, so
/// expressions that appear only in the ORDER BY or FILTER part of the call are
/// no longer overlooked by visitors.
fn visit_agg_call(&mut self, agg_call: &AggCall) -> R {
    let inputs_result = agg_call
        .inputs()
        .iter()
        .map(|expr| self.visit_expr(expr))
        .reduce(Self::merge)
        .unwrap_or_default();
    let with_order_by = Self::merge(inputs_result, agg_call.order_by().visit_expr(self));
    Self::merge(with_order_by, agg_call.filter().visit_expr(self))
}
fn visit_literal(&mut self, _: &Literal) -> R {
R::default()
Expand Down
2 changes: 1 addition & 1 deletion src/frontend/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ pub use window_function::{WindowFunction, WindowFunctionType};

pub type ExprType = risingwave_pb::expr::expr_node::Type;

pub use expr_mutator::ExprMutator;
pub use expr_rewriter::ExprRewriter;
pub use expr_visitor::ExprVisitor;
pub use type_inference::{
Expand Down Expand Up @@ -716,7 +717,6 @@ macro_rules! assert_eq_input_ref {
pub(crate) use assert_eq_input_ref;
use risingwave_common::catalog::Schema;

use crate::expr::expr_mutator::ExprMutator;
use crate::utils::Condition;

#[cfg(test)]
Expand Down
16 changes: 15 additions & 1 deletion src/frontend/src/expr/order_by_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use std::fmt::Display;

use itertools::Itertools;

use crate::expr::{ExprImpl, ExprRewriter};
use crate::expr::{ExprImpl, ExprMutator, ExprRewriter, ExprVisitor};
use crate::optimizer::property::Direction;

/// A sort expression in the `ORDER BY` clause.
Expand Down Expand Up @@ -78,4 +78,18 @@ impl OrderBy {
.collect(),
}
}

/// Runs `visitor` over the expression of every sort item, in order.
///
/// The per-item results are combined left to right with `V::merge`. When there
/// are no sort expressions, `R::default()` is returned.
pub fn visit_expr<R: Default, V: ExprVisitor<R> + ?Sized>(&self, visitor: &mut V) -> R {
    let mut merged: Option<R> = None;
    for sort_expr in self.sort_exprs.iter() {
        let current = visitor.visit_expr(&sort_expr.expr);
        merged = Some(match merged {
            Some(so_far) => V::merge(so_far, current),
            // The first result seeds the accumulator without a merge.
            None => current,
        });
    }
    merged.unwrap_or_default()
}

/// Hands the expression of every sort item, in order, to `mutator` so it can
/// be modified in place.
pub fn visit_expr_mut(&mut self, mutator: &mut (impl ExprMutator + ?Sized)) {
    for sort_expr in self.sort_exprs.iter_mut() {
        mutator.visit_expr(&mut sort_expr.expr);
    }
}
}
12 changes: 0 additions & 12 deletions src/frontend/src/optimizer/plan_node/logical_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1042,18 +1042,6 @@ impl LogicalAgg {
having: Option<ExprImpl>,
input: PlanRef,
) -> Result<(PlanRef, Vec<ExprImpl>, Option<ExprImpl>)> {
if select_exprs
.iter()
.chain(group_exprs.iter())
.any(|e| e.has_table_function())
{
return Err(ErrorCode::NotImplemented(
"Table functions in agg call or group by is not supported yet".to_string(),
3814.into(),
)
.into());
}

let mut agg_builder = LogicalAggBuilder::new(group_exprs)?;

let rewritten_select_exprs = select_exprs
Expand Down
15 changes: 12 additions & 3 deletions src/frontend/src/utils/condition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ use risingwave_common::util::scan_range::{is_full_range, ScanRange};

use crate::expr::{
factorization_expr, fold_boolean_constant, push_down_not, to_conjunctions,
try_get_bool_constant, ExprDisplay, ExprImpl, ExprRewriter, ExprType, ExprVisitor, InputRef,
try_get_bool_constant, ExprDisplay, ExprImpl, ExprMutator, ExprRewriter, ExprType, ExprVisitor,
InputRef,
};

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -562,10 +563,18 @@ impl Condition {
}
}

pub fn visit_expr(&self, visitor: &mut impl ExprVisitor<()>) {
pub fn visit_expr<R: Default, V: ExprVisitor<R> + ?Sized>(&self, visitor: &mut V) -> R {
self.conjunctions
.iter()
.for_each(|expr| visitor.visit_expr(expr))
.map(|expr| visitor.visit_expr(expr))
.reduce(V::merge)
.unwrap_or_default()
}

/// Hands every conjunction of this condition, in order, to `mutator` so it can
/// be modified in place.
pub fn visit_expr_mut(&mut self, mutator: &mut (impl ExprMutator + ?Sized)) {
    for conjunction in self.conjunctions.iter_mut() {
        mutator.visit_expr(conjunction);
    }
}

/// Simplify conditions
Expand Down

0 comments on commit 1178dbe

Please sign in to comment.