diff --git a/src/frontend/src/binder/expr/function.rs b/src/frontend/src/binder/expr/function.rs index 4930c5a367f90..a6a69934e4134 100644 --- a/src/frontend/src/binder/expr/function.rs +++ b/src/frontend/src/binder/expr/function.rs @@ -23,6 +23,7 @@ use risingwave_sqlparser::ast::{Function, FunctionArg, FunctionArgExpr}; use crate::binder::bind_context::Clause; use crate::binder::Binder; use crate::expr::{AggCall, Expr, ExprImpl, ExprType, FunctionCall, Literal}; +use crate::utils::Condition; impl Binder { pub(super) fn bind_function(&mut self, f: Function) -> Result { @@ -49,9 +50,42 @@ impl Binder { }; if let Some(kind) = agg_kind { self.ensure_aggregate_allowed()?; + let filter = match f.filter { + Some(filter) => { + let expr = self.bind_expr(*filter)?; + if expr.return_type() != DataType::Boolean { + return Err(ErrorCode::InvalidInputSyntax(format!( + "the type of filter clause should be boolean, but found {:?}", + expr.return_type() + )) + .into()); + } + if expr.has_subquery() { + return Err(ErrorCode::InvalidInputSyntax( + "subquery in filter clause is not supported".to_string(), + ) + .into()); + } + if expr.has_agg_call() { + return Err(ErrorCode::InvalidInputSyntax( + "aggregation function in filter clause is not supported" + .to_string(), + ) + .into()); + } + Condition::with_expr(expr) + } + None => Condition::true_cond(), + }; return Ok(ExprImpl::AggCall(Box::new(AggCall::new( - kind, inputs, f.distinct, + kind, inputs, f.distinct, filter, )?))); + } else if f.filter.is_some() { + return Err(ErrorCode::InvalidInputSyntax(format!( + "filter clause is only allowed in aggregation functions, but `{}` is not an aggregation function", function_name + ) + ) + .into()); } let function_type = match function_name.as_str() { // comparison diff --git a/src/frontend/src/expr/agg_call.rs b/src/frontend/src/expr/agg_call.rs index dc1838cfe79d4..0df7aaf0b84e0 100644 --- a/src/frontend/src/expr/agg_call.rs +++ b/src/frontend/src/expr/agg_call.rs @@ -18,6 +18,7 @@ use risingwave_common::types::DataType; use risingwave_expr::expr::AggKind; use super::{Expr, ExprImpl}; +use crate::utils::Condition; #[derive(Clone, Eq, PartialEq, Hash)] pub struct AggCall { @@ -25,6 +26,7 @@ pub struct AggCall { return_type: DataType, inputs: Vec, distinct: bool, + filter: Condition, } impl std::fmt::Debug for AggCall { @@ -34,6 +36,7 @@ impl std::fmt::Debug for AggCall { .field("agg_kind", &self.agg_kind) .field("return_type", &self.return_type) .field("inputs", &self.inputs) + .field("filter", &self.filter) .finish() } else { let mut builder = f.debug_tuple(&format!("{}", self.agg_kind)); @@ -114,7 +117,12 @@ impl AggCall { /// Returns error if the function name matches with an existing function /// but with illegal arguments. - pub fn new(agg_kind: AggKind, inputs: Vec, distinct: bool) -> Result { + pub fn new( + agg_kind: AggKind, + inputs: Vec, + distinct: bool, + filter: Condition, + ) -> Result { let data_types = inputs.iter().map(ExprImpl::return_type).collect_vec(); let return_type = Self::infer_return_type(&agg_kind, &data_types)?; Ok(AggCall { @@ -122,11 +130,12 @@ impl AggCall { return_type, inputs, distinct, + filter, }) } - pub fn decompose(self) -> (AggKind, Vec, bool) { - (self.agg_kind, self.inputs, self.distinct) + pub fn decompose(self) -> (AggKind, Vec, bool, Condition) { + (self.agg_kind, self.inputs, self.distinct, self.filter) } pub fn agg_kind(&self) -> AggKind { diff --git a/src/frontend/src/expr/expr_rewriter.rs b/src/frontend/src/expr/expr_rewriter.rs index 8a4be1d164442..4ba33cbe1dd6c 100644 --- a/src/frontend/src/expr/expr_rewriter.rs +++ b/src/frontend/src/expr/expr_rewriter.rs @@ -37,12 +37,15 @@ pub trait ExprRewriter { FunctionCall::new_unchecked(func_type, inputs, ret).into() } fn rewrite_agg_call(&mut self, agg_call: AggCall) -> ExprImpl { - let (func_type, inputs, distinct) = agg_call.decompose(); + let (func_type, inputs, distinct, filter) = agg_call.decompose(); let inputs = inputs .into_iter() .map(|expr| self.rewrite_expr(expr)) .collect(); - AggCall::new(func_type, inputs, distinct).unwrap().into() + let filter = filter.rewrite_expr(self); + AggCall::new(func_type, inputs, distinct, filter) + .unwrap() + .into() } fn rewrite_literal(&mut self, literal: Literal) -> ExprImpl { literal.into() diff --git a/src/frontend/src/expr/mod.rs b/src/frontend/src/expr/mod.rs index 4931d14d1f48d..665ee3610204e 100644 --- a/src/frontend/src/expr/mod.rs +++ b/src/frontend/src/expr/mod.rs @@ -91,7 +91,9 @@ impl ExprImpl { /// A `count(*)` aggregate function. #[inline(always)] pub fn count_star() -> Self { - AggCall::new(AggKind::Count, vec![], false).unwrap().into() + AggCall::new(AggKind::Count, vec![], false, Condition::true_cond()) + .unwrap() + .into() } /// Collect all `InputRef`s' indexes in the expression. diff --git a/src/frontend/src/optimizer/plan_node/logical_agg.rs b/src/frontend/src/optimizer/plan_node/logical_agg.rs index 142f0763c0a33..cc9f4463870de 100644 --- a/src/frontend/src/optimizer/plan_node/logical_agg.rs +++ b/src/frontend/src/optimizer/plan_node/logical_agg.rs @@ -52,6 +52,10 @@ pub struct PlanAggCall { pub inputs: Vec, pub distinct: bool, + /// Selective aggregation: only the input rows for which + /// the filter_clause evaluates to true will be fed to aggregate function. + /// Other rows are discarded. + pub filter: Condition, } impl fmt::Debug for PlanAggCall { @@ -70,6 +74,13 @@ impl fmt::Debug for PlanAggCall { } write!(f, ")")?; } + if !self.filter.always_true() { + write!( + f, + " filter({:?})", + self.filter.as_expr_unless_true().unwrap() + )?; + } Ok(()) } } @@ -99,6 +110,7 @@ impl PlanAggCall { PlanAggCall { agg_kind: total_agg_kind, inputs: vec![InputRef::new(partial_output_idx, self.return_type.clone())], + filter: Condition::true_cond(), ..self.clone() } } @@ -109,6 +121,7 @@ impl PlanAggCall { return_type: DataType::Int64, inputs: vec![], distinct: false, + filter: Condition::true_cond(), } } } @@ -232,6 +245,12 @@ struct LogicalAggBuilder { agg_calls: Vec, /// the error during the expression rewriting error: Option, + /// If `is_in_filter_clause` is true, it means that + /// we are processing a filter clause. + /// This field is needed because input refs in filter clause + /// are allowed to refer to any columns, while those not in filter + /// clause are only allowed to refer to group keys. + is_in_filter_clause: bool, } impl LogicalAggBuilder { @@ -257,6 +276,7 @@ impl LogicalAggBuilder { agg_calls: vec![], error: None, input_proj_builder, + is_in_filter_clause: false, }) } @@ -300,14 +320,18 @@ impl ExprRewriter for LogicalAggBuilder { /// Note that the rewriter does not traverse into inputs of agg calls. fn rewrite_agg_call(&mut self, agg_call: AggCall) -> ExprImpl { let return_type = agg_call.return_type(); - let (agg_kind, inputs, distinct) = agg_call.decompose(); - + let (agg_kind, inputs, distinct, filter) = agg_call.decompose(); + self.is_in_filter_clause = true; + let filter = filter.rewrite_expr(self); + self.is_in_filter_clause = false; for i in &inputs { if i.has_agg_call() { self.error = Some(ErrorCode::InvalidInputSyntax( "Aggregation calls should not be nested".into(), )); - return AggCall::new(agg_kind, inputs, distinct).unwrap().into(); + return AggCall::new(agg_kind, inputs, distinct, filter) + .unwrap() + .into(); } } @@ -331,6 +355,7 @@ impl ExprRewriter for LogicalAggBuilder { return_type: left_return_type.clone(), inputs: inputs.clone(), distinct, + filter: filter.clone(), }); let left = ExprImpl::from(InputRef::new( self.group_keys.len() + self.agg_calls.len() - 1, @@ -347,6 +372,7 @@ impl ExprRewriter for LogicalAggBuilder { return_type: right_return_type.clone(), inputs, distinct, + filter, }); let right = InputRef::new( @@ -361,6 +387,7 @@ impl ExprRewriter for LogicalAggBuilder { return_type: return_type.clone(), inputs, distinct, + filter, }); ExprImpl::from(InputRef::new( self.group_keys.len() + self.agg_calls.len() - 1, @@ -390,6 +417,8 @@ impl ExprRewriter for LogicalAggBuilder { let expr = input_ref.into(); if let Some(group_key) = self.try_as_group_expr(&expr) { InputRef::new(group_key, expr.return_type()).into() + } else if self.is_in_filter_clause { + InputRef::new(self.input_proj_builder.add_expr(&expr), expr.return_type()).into() } else { self.error = Some(ErrorCode::InvalidInputSyntax( "column must appear in the GROUP BY clause or be used in an aggregate function" @@ -540,7 +569,7 @@ impl PlanTreeNodeUnary for LogicalAgg { fn rewrite_with_input( &self, input: PlanRef, - input_col_change: ColIndexMapping, + mut input_col_change: ColIndexMapping, ) -> (Self, ColIndexMapping) { let agg_calls = self .agg_calls @@ -550,6 +579,7 @@ impl PlanTreeNodeUnary for LogicalAgg { agg_call.inputs.iter_mut().for_each(|i| { *i = InputRef::new(input_col_change.map(i.index()), i.return_type()) }); + agg_call.filter = agg_call.filter.rewrite_expr(&mut input_col_change); agg_call }) .collect(); @@ -590,7 +620,8 @@ impl ColPrunable for LogicalAgg { let group_key_required_cols = FixedBitSet::from_iter(self.group_keys.iter().copied()); let (agg_call_required_cols, agg_calls) = { - let mut tmp = FixedBitSet::with_capacity(self.input().schema().fields().len()); + let input_cnt = self.input().schema().fields().len(); + let mut tmp = FixedBitSet::with_capacity(input_cnt); let new_agg_calls = required_cols .iter() .filter(|&&index| index >= self.group_keys.len()) @@ -598,6 +629,10 @@ impl ColPrunable for LogicalAgg { let index = index - self.group_keys.len(); let agg_call = self.agg_calls[index].clone(); tmp.extend(agg_call.inputs.iter().map(|x| x.index())); + // collect columns used in aggregate filter expressions + for i in &agg_call.filter.conjunctions { + tmp.union_with(&i.collect_input_refs(input_cnt)); + } agg_call }) .collect_vec(); @@ -845,8 +880,13 @@ mod tests { // Test case: select v1, min(v2) from test group by v1; { - let min_v2 = - AggCall::new(AggKind::Min, vec![input_ref_2.clone().into()], false).unwrap(); + let min_v2 = AggCall::new( + AggKind::Min, + vec![input_ref_2.clone().into()], + false, + Condition::true_cond(), + ) + .unwrap(); let select_exprs = vec![input_ref_1.clone().into(), min_v2.into()]; let group_exprs = vec![input_ref_1.clone().into()]; @@ -864,10 +904,20 @@ mod tests { // Test case: select v1, min(v2) + max(v3) from t group by v1; { - let min_v2 = - AggCall::new(AggKind::Min, vec![input_ref_2.clone().into()], false).unwrap(); - let max_v3 = - AggCall::new(AggKind::Max, vec![input_ref_3.clone().into()], false).unwrap(); + let min_v2 = AggCall::new( + AggKind::Min, + vec![input_ref_2.clone().into()], + false, + Condition::true_cond(), + ) + .unwrap(); + let max_v3 = AggCall::new( + AggKind::Max, + vec![input_ref_3.clone().into()], + false, + Condition::true_cond(), + ) + .unwrap(); let func_call = FunctionCall::new(ExprType::Add, vec![min_v2.into(), max_v3.into()]).unwrap(); let select_exprs = vec![input_ref_1.clone().into(), ExprImpl::from(func_call)]; @@ -900,7 +950,13 @@ mod tests { vec![input_ref_1.into(), input_ref_3.into()], ) .unwrap(); - let agg_call = AggCall::new(AggKind::Min, vec![v1_mult_v3.into()], false).unwrap(); + let agg_call = AggCall::new( + AggKind::Min, + vec![v1_mult_v3.into()], + false, + Condition::true_cond(), + ) + .unwrap(); let select_exprs = vec![input_ref_2.clone().into(), agg_call.into()]; let group_exprs = vec![input_ref_2.into()]; @@ -930,6 +986,7 @@ mod tests { return_type: ty.clone(), inputs: vec![InputRef::new(2, ty.clone())], distinct: false, + filter: Condition::true_cond(), }; LogicalAgg::new(vec![agg_call], vec![1], values.into()) } @@ -1047,6 +1104,7 @@ mod tests { return_type: ty.clone(), inputs: vec![InputRef::new(2, ty.clone())], distinct: false, + filter: Condition::true_cond(), }; let agg = LogicalAgg::new(vec![agg_call], vec![1], values.into()); @@ -1110,12 +1168,14 @@ mod tests { return_type: ty.clone(), inputs: vec![InputRef::new(2, ty.clone())], distinct: false, + filter: Condition::true_cond(), }, PlanAggCall { agg_kind: AggKind::Max, return_type: ty.clone(), inputs: vec![InputRef::new(1, ty.clone())], distinct: false, + filter: Condition::true_cond(), }, ]; let agg = LogicalAgg::new(agg_calls, vec![1, 2], values.into()); diff --git a/src/frontend/src/utils/condition.rs b/src/frontend/src/utils/condition.rs index 2878476284c31..d39401f822d7c 100644 --- a/src/frontend/src/utils/condition.rs +++ b/src/frontend/src/utils/condition.rs @@ -25,7 +25,7 @@ use crate::expr::{ try_get_bool_constant, ExprImpl, ExprRewriter, ExprType, ExprVisitor, InputRef, }; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Condition { /// Condition expressions in conjunction form (combined with `AND`) pub conjunctions: Vec, @@ -379,7 +379,7 @@ impl Condition { } #[must_use] - pub fn rewrite_expr(self, rewriter: &mut impl ExprRewriter) -> Self { + pub fn rewrite_expr(self, rewriter: &mut (impl ExprRewriter + ?Sized)) -> Self { Self { conjunctions: self .conjunctions diff --git a/src/frontend/test_runner/tests/testdata/agg.yaml b/src/frontend/test_runner/tests/testdata/agg.yaml index c2945cdaf8319..d9b39864ee624 100644 --- a/src/frontend/test_runner/tests/testdata/agg.yaml +++ b/src/frontend/test_runner/tests/testdata/agg.yaml @@ -371,3 +371,113 @@ StreamExchange { dist: HashShard([0, 1]) } StreamProject { exprs: [$2, $1, $0, $3] } StreamTableScan { table: t, columns: [v1, v2, v3, _row_id], pk_indices: [3] } +- sql: | + /* filter clause */ + create table t(v1 int); + select sum(v1) FILTER (WHERE v1 > 0) AS sa from t; + logical_plan: | + LogicalProject { exprs: [$0] } + LogicalAgg { group_keys: [], agg_calls: [sum($0) filter(($0 > 0:Int32))] } + LogicalProject { exprs: [$1] } + LogicalScan { table: t, columns: [_row_id, v1] } + optimized_logical_plan: | + LogicalAgg { group_keys: [], agg_calls: [sum($0) filter(($0 > 0:Int32))] } + LogicalScan { table: t, columns: [v1] } + stream_plan: | + StreamMaterialize { columns: [agg#0(hidden), sa], pk_columns: [] } + StreamGlobalSimpleAgg { aggs: [sum($0), sum($1)] } + StreamExchange { dist: Single } + StreamLocalSimpleAgg { aggs: [count, sum($0) filter(($0 > 0:Int32))] } + StreamTableScan { table: t, columns: [v1, _row_id], pk_indices: [1] } +- sql: | + /* filter clause */ + /* extra calculation, should reuse result from project */ + create table t(a int, b int); + select sum(a * b) filter (where a * b > 0) as sab from t; + logical_plan: | + LogicalProject { exprs: [$0] } + LogicalAgg { group_keys: [], agg_calls: [sum($2) filter((($0 * $1) > 0:Int32))] } + LogicalProject { exprs: [$1, $2, ($1 * $2)] } + LogicalScan { table: t, columns: [_row_id, a, b] } + optimized_logical_plan: | + LogicalAgg { group_keys: [], agg_calls: [sum($2) filter((($0 * $1) > 0:Int32))] } + LogicalProject { exprs: [$0, $1, ($0 * $1)] } + LogicalScan { table: t, columns: [a, b] } +- sql: | + /* complex filter clause */ + create table t(a int, b int); + select max(a * b) FILTER (WHERE a < b AND a + b < 100 AND a * b != a + b - 1) AS sab from t; + logical_plan: | + LogicalProject { exprs: [$0] } + LogicalAgg { group_keys: [], agg_calls: [max($2) filter(((($0 < $1) AND (($0 + $1) < 100:Int32)) AND (($0 * $1) <> (($0 + $1) - 1:Int32))))] } + LogicalProject { exprs: [$1, $2, ($1 * $2)] } + LogicalScan { table: t, columns: [_row_id, a, b] } + optimized_logical_plan: | + LogicalAgg { group_keys: [], agg_calls: [max($2) filter(((($0 < $1) AND (($0 + $1) < 100:Int32)) AND (($0 * $1) <> (($0 + $1) - 1:Int32))))] } + LogicalProject { exprs: [$0, $1, ($0 * $1)] } + LogicalScan { table: t, columns: [a, b] } + stream_plan: | + StreamMaterialize { columns: [agg#0(hidden), sab], pk_columns: [] } + StreamGlobalSimpleAgg { aggs: [count, max($2) filter(((($0 < $1) AND (($0 + $1) < 100:Int32)) AND (($0 * $1) <> (($0 + $1) - 1:Int32))))] } + StreamExchange { dist: Single } + StreamProject { exprs: [$0, $1, ($0 * $1), $2] } + StreamTableScan { table: t, columns: [a, b, _row_id], pk_indices: [2] } +- sql: | + /* avg filter clause + group by */ + create table t(a int, b int); + select avg(a) FILTER (WHERE a > b) AS avga from t group by b ; + logical_plan: | + LogicalProject { exprs: [($1::Decimal / $2)] } + LogicalAgg { group_keys: [0], agg_calls: [sum($1) filter(($1 > $0)), count($1) filter(($1 > $0))] } + LogicalProject { exprs: [$2, $1] } + LogicalScan { table: t, columns: [_row_id, a, b] } + optimized_logical_plan: | + LogicalProject { exprs: [($1::Decimal / $2)] } + LogicalAgg { group_keys: [0], agg_calls: [sum($1) filter(($1 > $0)), count($1) filter(($1 > $0))] } + LogicalProject { exprs: [$1, $0] } + LogicalScan { table: t, columns: [a, b] } + stream_plan: | + StreamMaterialize { columns: [avga, b(hidden)], pk_columns: [b] } + StreamProject { exprs: [($2::Decimal / $3), $0] } + StreamHashAgg { group_keys: [$0], aggs: [count, sum($1) filter(($1 > $0)), count($1) filter(($1 > $0))] } + StreamExchange { dist: HashShard([0]) } + StreamProject { exprs: [$1, $0, $2] } + StreamTableScan { table: t, columns: [a, b, _row_id], pk_indices: [2] } +- sql: | + /* count filter clause */ + create table t(a int, b int); + select count(*) FILTER (WHERE a > b) AS cnt_agb from t; + logical_plan: | + LogicalProject { exprs: [$0] } + LogicalAgg { group_keys: [], agg_calls: [count filter(($0 > $1))] } + LogicalProject { exprs: [$1, $2] } + LogicalScan { table: t, columns: [_row_id, a, b] } + optimized_logical_plan: | + LogicalAgg { group_keys: [], agg_calls: [count filter(($0 > $1))] } + LogicalScan { table: t, columns: [a, b] } + stream_plan: | + StreamMaterialize { columns: [agg#0(hidden), cnt_agb], pk_columns: [] } + StreamGlobalSimpleAgg { aggs: [sum($0), sum($1)] } + StreamExchange { dist: Single } + StreamLocalSimpleAgg { aggs: [count, count filter(($0 > $1))] } + StreamTableScan { table: t, columns: [a, b, _row_id], pk_indices: [2] } +- sql: | + /* filter clause + non-boolean function */ + create table t(a int, b int); + select avg(a) FILTER (WHERE abs(a)) AS avga from t; + binder_error: 'Invalid input syntax: the type of filter clause should be boolean, but found Int32' +- sql: | + /* filter clause + subquery */ + create table t(a int, b int); + select avg(a) FILTER (WHERE 0 < (select max(a) from t)) AS avga from t; + binder_error: 'Invalid input syntax: subquery in filter clause is not supported' +- sql: | + /* aggregation in filter clause */ + create table t(a int, b int); + select avg(a) FILTER (WHERE a < avg(b)) AS avga from t; + binder_error: 'Invalid input syntax: aggregation function in filter clause is not supported' +- sql: | + /* filter clause + non-boolean function */ + create table t(a int, b int); + select abs(a) FILTER (WHERE a > 0) AS avga from t; + binder_error: 'Invalid input syntax: filter clause is only allowed in aggregation functions, but `abs` is not an aggregation function' diff --git a/src/sqlparser/src/ast/mod.rs b/src/sqlparser/src/ast/mod.rs index b5b066322e822..90b22bb558965 100644 --- a/src/sqlparser/src/ast/mod.rs +++ b/src/sqlparser/src/ast/mod.rs @@ -1540,6 +1540,7 @@ pub struct Function { pub distinct: bool, // string_agg and array_agg both support ORDER BY pub order_by: Vec, + pub filter: Option>, } impl fmt::Display for Function { @@ -1560,6 +1561,9 @@ impl fmt::Display for Function { if let Some(o) = &self.over { write!(f, " OVER ({})", o)?; } + if let Some(filter) = &self.filter { + write!(f, " FILTER(WHERE {})", filter)?; + } Ok(()) } } diff --git a/src/sqlparser/src/parser.rs b/src/sqlparser/src/parser.rs index 377994a3c3ea6..7f3d9c3bc9e42 100644 --- a/src/sqlparser/src/parser.rs +++ b/src/sqlparser/src/parser.rs @@ -593,12 +593,23 @@ impl Parser { None }; + let filter = if self.parse_keyword(Keyword::FILTER) { + self.expect_token(&Token::LParen)?; + self.expect_keyword(Keyword::WHERE)?; + let filter_expr = self.parse_expr()?; + self.expect_token(&Token::RParen)?; + Some(Box::new(filter_expr)) + } else { + None + }; + Ok(Expr::Function(Function { name, args, over, distinct, order_by, + filter, })) } diff --git a/src/sqlparser/tests/sqlparser_common.rs b/src/sqlparser/tests/sqlparser_common.rs index d2579452f14f7..13b8b10de61e5 100644 --- a/src/sqlparser/tests/sqlparser_common.rs +++ b/src/sqlparser/tests/sqlparser_common.rs @@ -378,6 +378,7 @@ fn parse_select_count_wildcard() { over: None, distinct: false, order_by: vec![], + filter: None }), expr_from_projection(only(&select.projection)) ); @@ -397,6 +398,7 @@ fn parse_select_count_distinct() { over: None, distinct: true, order_by: vec![], + filter: None }), expr_from_projection(only(&select.projection)) ); @@ -1093,9 +1095,10 @@ fn parse_select_having() { over: None, distinct: false, order_by: vec![], + filter: None })), op: BinaryOperator::Gt, - right: Box::new(Expr::Value(number("1"))) + right: Box::new(Expr::Value(number("1"))), }), select.having ); @@ -1786,6 +1789,7 @@ fn parse_named_argument_function() { over: None, distinct: false, order_by: vec![], + filter: None, }), expr_from_projection(only(&select.projection)) ); @@ -1820,6 +1824,7 @@ fn parse_window_functions() { }), distinct: false, order_by: vec![], + filter: None, }), expr_from_projection(&select.projection[0]) ); @@ -1857,11 +1862,41 @@ fn parse_aggregate_with_order_by() { nulls_first: None, } ], + filter: None, }), expr_from_projection(only(&select.projection)) ); } +#[test] +fn parse_aggregate_with_filter() { + let sql = "SELECT sum(a) FILTER(WHERE (a > 0) AND (a IS NOT NULL)) FROM foo"; + let select = verified_only_select(sql); + assert_eq!( + &Expr::Function(Function { + name: ObjectName(vec![Ident::new("sum")]), + args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr( + Expr::Identifier(Ident::new("a")) + )),], + over: None, + distinct: false, + order_by: vec![], + filter: Some(Box::new(Expr::BinaryOp { + left: Box::new(Expr::Nested(Box::new(Expr::BinaryOp { + left: Box::new(Expr::Identifier(Ident::new("a"))), + op: BinaryOperator::Gt, + right: Box::new(Expr::Value(Value::Number("0".to_string(), false))) + }))), + op: BinaryOperator::And, + right: Box::new(Expr::Nested(Box::new(Expr::IsNotNull(Box::new( + Expr::Identifier(Ident::new("a")) + ))))) + })), + }), + expr_from_projection(only(&select.projection)), + ); +} + #[test] fn parse_literal_decimal() { // These numbers were explicitly chosen to not roundtrip if represented as @@ -2095,6 +2130,7 @@ fn parse_delimited_identifiers() { over: None, distinct: false, order_by: vec![], + filter: None, }), expr_from_projection(&select.projection[1]), ); diff --git a/src/sqlparser/tests/testdata/select.yaml b/src/sqlparser/tests/testdata/select.yaml index d35057893d14b..c48fb5b9d165f 100644 --- a/src/sqlparser/tests/testdata/select.yaml +++ b/src/sqlparser/tests/testdata/select.yaml @@ -1,7 +1,7 @@ - input: SELECT sqrt(id) FROM foo formatted_sql: SELECT sqrt(id) FROM foo formatted_ast: | - Query(Query { with: None, body: Select(Select { distinct: false, projection: [UnnamedExpr(Function(Function { name: ObjectName([Ident { value: "sqrt", quote_style: None }]), args: [Unnamed(Expr(Identifier(Ident { value: "id", quote_style: None })))], over: None, distinct: false, order_by: [] }))], from: [TableWithJoins { relation: Table { name: ObjectName([Ident { value: "foo", quote_style: None }]), alias: None, args: [] }, joins: [] }], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None }) + Query(Query { with: None, body: Select(Select { distinct: false, projection: [UnnamedExpr(Function(Function { name: ObjectName([Ident { value: "sqrt", quote_style: None }]), args: [Unnamed(Expr(Identifier(Ident { value: "id", quote_style: None })))], over: None, distinct: false, order_by: [], filter: None }))], from: [TableWithJoins { relation: Table { name: ObjectName([Ident { value: "foo", quote_style: None }]), alias: None, args: [] }, joins: [] }], lateral_views: [], selection: None, group_by: [], having: None }), order_by: [], limit: None, offset: None, fetch: None }) # Typed string literal - input: SELECT INT '1' diff --git a/src/tests/sqlsmith/src/expr.rs b/src/tests/sqlsmith/src/expr.rs index fdcf596e240a2..cd11eaf64ab8a 100644 --- a/src/tests/sqlsmith/src/expr.rs +++ b/src/tests/sqlsmith/src/expr.rs @@ -165,6 +165,7 @@ fn make_func(func_name: &str, exprs: &[Expr]) -> Function { over: None, distinct: false, order_by: vec![], + filter: None, } }