Skip to content

Commit

Permalink
feat(frontend): support filter clause in aggregation (risingwavelabs#…
Browse files Browse the repository at this point in the history
…3626)

* parse filter in agg

* bind selective agg

* fix filter clause input ref index

* fix logical agg column prune

* modify parser test, add filter

* fmt code

* remove unused struct

* add more tests and error reporting

* add test for non-agg filter clause

* add group by test

* fix `rewrite_with_input`

* fix 2 phase agg

* report error for agg calls in the filter clause

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and nasnoisaac committed Aug 9, 2022
1 parent 4df1469 commit 654f0b1
Show file tree
Hide file tree
Showing 12 changed files with 293 additions and 23 deletions.
36 changes: 35 additions & 1 deletion src/frontend/src/binder/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use risingwave_sqlparser::ast::{Function, FunctionArg, FunctionArgExpr};
use crate::binder::bind_context::Clause;
use crate::binder::Binder;
use crate::expr::{AggCall, Expr, ExprImpl, ExprType, FunctionCall, Literal};
use crate::utils::Condition;

impl Binder {
pub(super) fn bind_function(&mut self, f: Function) -> Result<ExprImpl> {
Expand All @@ -49,9 +50,42 @@ impl Binder {
};
if let Some(kind) = agg_kind {
self.ensure_aggregate_allowed()?;
let filter = match f.filter {
Some(filter) => {
let expr = self.bind_expr(*filter)?;
if expr.return_type() != DataType::Boolean {
return Err(ErrorCode::InvalidInputSyntax(format!(
"the type of filter clause should be boolean, but found {:?}",
expr.return_type()
))
.into());
}
if expr.has_subquery() {
return Err(ErrorCode::InvalidInputSyntax(
"subquery in filter clause is not supported".to_string(),
)
.into());
}
if expr.has_agg_call() {
return Err(ErrorCode::InvalidInputSyntax(
"aggregation function in filter clause is not supported"
.to_string(),
)
.into());
}
Condition::with_expr(expr)
}
None => Condition::true_cond(),
};
return Ok(ExprImpl::AggCall(Box::new(AggCall::new(
kind, inputs, f.distinct,
kind, inputs, f.distinct, filter,
)?)));
} else if f.filter.is_some() {
return Err(ErrorCode::InvalidInputSyntax(format!(
"filter clause is only allowed in aggregation functions, but `{}` is not an aggregation function", function_name
)
)
.into());
}
let function_type = match function_name.as_str() {
// comparison
Expand Down
15 changes: 12 additions & 3 deletions src/frontend/src/expr/agg_call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ use risingwave_common::types::DataType;
use risingwave_expr::expr::AggKind;

use super::{Expr, ExprImpl};
use crate::utils::Condition;

/// An aggregate function call expression.
#[derive(Clone, Eq, PartialEq, Hash)]
pub struct AggCall {
/// Which aggregate function is being called.
agg_kind: AggKind,
/// Result type of the call; `new` infers it from `agg_kind` and the input types.
return_type: DataType,
/// Argument expressions of the call.
inputs: Vec<ExprImpl>,
/// NOTE(review): presumably the `DISTINCT` modifier — the binder forwards the parsed
/// function's `distinct` flag here; confirm against the parser.
distinct: bool,
/// Condition taken from the call's filter clause. Call sites without a filter clause
/// pass `Condition::true_cond()`.
filter: Condition,
}

impl std::fmt::Debug for AggCall {
Expand All @@ -34,6 +36,7 @@ impl std::fmt::Debug for AggCall {
.field("agg_kind", &self.agg_kind)
.field("return_type", &self.return_type)
.field("inputs", &self.inputs)
.field("filter", &self.filter)
.finish()
} else {
let mut builder = f.debug_tuple(&format!("{}", self.agg_kind));
Expand Down Expand Up @@ -114,19 +117,25 @@ impl AggCall {

/// Returns error if the function name matches with an existing function
/// but with illegal arguments.
pub fn new(agg_kind: AggKind, inputs: Vec<ExprImpl>, distinct: bool) -> Result<Self> {
pub fn new(
agg_kind: AggKind,
inputs: Vec<ExprImpl>,
distinct: bool,
filter: Condition,
) -> Result<Self> {
// The return type is derived from the aggregate kind and the argument types only;
// `?` propagates the error when `infer_return_type` rejects the combination.
let data_types = inputs.iter().map(ExprImpl::return_type).collect_vec();
let return_type = Self::infer_return_type(&agg_kind, &data_types)?;
// `distinct` and `filter` are stored as given; neither takes part in type inference.
Ok(AggCall {
agg_kind,
return_type,
inputs,
distinct,
filter,
})
}

pub fn decompose(self) -> (AggKind, Vec<ExprImpl>, bool) {
(self.agg_kind, self.inputs, self.distinct)
/// Consumes the call and returns its parts as `(agg_kind, inputs, distinct, filter)`.
///
/// The stored `return_type` is not part of the tuple.
pub fn decompose(self) -> (AggKind, Vec<ExprImpl>, bool, Condition) {
(self.agg_kind, self.inputs, self.distinct, self.filter)
}

pub fn agg_kind(&self) -> AggKind {
Expand Down
7 changes: 5 additions & 2 deletions src/frontend/src/expr/expr_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,15 @@ pub trait ExprRewriter {
FunctionCall::new_unchecked(func_type, inputs, ret).into()
}
fn rewrite_agg_call(&mut self, agg_call: AggCall) -> ExprImpl {
let (func_type, inputs, distinct) = agg_call.decompose();
let (func_type, inputs, distinct, filter) = agg_call.decompose();
let inputs = inputs
.into_iter()
.map(|expr| self.rewrite_expr(expr))
.collect();
AggCall::new(func_type, inputs, distinct).unwrap().into()
let filter = filter.rewrite_expr(self);
AggCall::new(func_type, inputs, distinct, filter)
.unwrap()
.into()
}
fn rewrite_literal(&mut self, literal: Literal) -> ExprImpl {
literal.into()
Expand Down
4 changes: 3 additions & 1 deletion src/frontend/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ impl ExprImpl {
/// A `count(*)` aggregate function.
#[inline(always)]
pub fn count_star() -> Self {
AggCall::new(AggKind::Count, vec![], false).unwrap().into()
AggCall::new(AggKind::Count, vec![], false, Condition::true_cond())
.unwrap()
.into()
}

/// Collect all `InputRef`s' indexes in the expression.
Expand Down
84 changes: 72 additions & 12 deletions src/frontend/src/optimizer/plan_node/logical_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ pub struct PlanAggCall {
pub inputs: Vec<InputRef>,

pub distinct: bool,
/// Selective aggregation: only the input rows for which
/// the `filter` condition evaluates to true will be fed to the aggregate function.
/// Other rows are discarded.
pub filter: Condition,
}

impl fmt::Debug for PlanAggCall {
Expand All @@ -70,6 +74,13 @@ impl fmt::Debug for PlanAggCall {
}
write!(f, ")")?;
}
if !self.filter.always_true() {
write!(
f,
" filter({:?})",
self.filter.as_expr_unless_true().unwrap()
)?;
}
Ok(())
}
}
Expand Down Expand Up @@ -99,6 +110,7 @@ impl PlanAggCall {
PlanAggCall {
agg_kind: total_agg_kind,
inputs: vec![InputRef::new(partial_output_idx, self.return_type.clone())],
filter: Condition::true_cond(),
..self.clone()
}
}
Expand All @@ -109,6 +121,7 @@ impl PlanAggCall {
return_type: DataType::Int64,
inputs: vec![],
distinct: false,
filter: Condition::true_cond(),
}
}
}
Expand Down Expand Up @@ -232,6 +245,12 @@ struct LogicalAggBuilder {
agg_calls: Vec<PlanAggCall>,
/// the error during the expression rewriting
error: Option<ErrorCode>,
/// `true` while the filter clause of an aggregate call is being rewritten.
/// This field is needed because input refs inside a filter clause
/// are allowed to refer to any column, while input refs outside a filter
/// clause are only allowed to refer to group keys.
}

impl LogicalAggBuilder {
Expand All @@ -257,6 +276,7 @@ impl LogicalAggBuilder {
agg_calls: vec![],
error: None,
input_proj_builder,
is_in_filter_clause: false,
})
}

Expand Down Expand Up @@ -300,14 +320,18 @@ impl ExprRewriter for LogicalAggBuilder {
/// Note that the rewriter does not traverse into inputs of agg calls.
fn rewrite_agg_call(&mut self, agg_call: AggCall) -> ExprImpl {
let return_type = agg_call.return_type();
let (agg_kind, inputs, distinct) = agg_call.decompose();

let (agg_kind, inputs, distinct, filter) = agg_call.decompose();
self.is_in_filter_clause = true;
let filter = filter.rewrite_expr(self);
self.is_in_filter_clause = false;
for i in &inputs {
if i.has_agg_call() {
self.error = Some(ErrorCode::InvalidInputSyntax(
"Aggregation calls should not be nested".into(),
));
return AggCall::new(agg_kind, inputs, distinct).unwrap().into();
return AggCall::new(agg_kind, inputs, distinct, filter)
.unwrap()
.into();
}
}

Expand All @@ -331,6 +355,7 @@ impl ExprRewriter for LogicalAggBuilder {
return_type: left_return_type.clone(),
inputs: inputs.clone(),
distinct,
filter: filter.clone(),
});
let left = ExprImpl::from(InputRef::new(
self.group_keys.len() + self.agg_calls.len() - 1,
Expand All @@ -347,6 +372,7 @@ impl ExprRewriter for LogicalAggBuilder {
return_type: right_return_type.clone(),
inputs,
distinct,
filter,
});

let right = InputRef::new(
Expand All @@ -361,6 +387,7 @@ impl ExprRewriter for LogicalAggBuilder {
return_type: return_type.clone(),
inputs,
distinct,
filter,
});
ExprImpl::from(InputRef::new(
self.group_keys.len() + self.agg_calls.len() - 1,
Expand Down Expand Up @@ -390,6 +417,8 @@ impl ExprRewriter for LogicalAggBuilder {
let expr = input_ref.into();
if let Some(group_key) = self.try_as_group_expr(&expr) {
InputRef::new(group_key, expr.return_type()).into()
} else if self.is_in_filter_clause {
InputRef::new(self.input_proj_builder.add_expr(&expr), expr.return_type()).into()
} else {
self.error = Some(ErrorCode::InvalidInputSyntax(
"column must appear in the GROUP BY clause or be used in an aggregate function"
Expand Down Expand Up @@ -540,7 +569,7 @@ impl PlanTreeNodeUnary for LogicalAgg {
fn rewrite_with_input(
&self,
input: PlanRef,
input_col_change: ColIndexMapping,
mut input_col_change: ColIndexMapping,
) -> (Self, ColIndexMapping) {
let agg_calls = self
.agg_calls
Expand All @@ -550,6 +579,7 @@ impl PlanTreeNodeUnary for LogicalAgg {
agg_call.inputs.iter_mut().for_each(|i| {
*i = InputRef::new(input_col_change.map(i.index()), i.return_type())
});
agg_call.filter = agg_call.filter.rewrite_expr(&mut input_col_change);
agg_call
})
.collect();
Expand Down Expand Up @@ -590,14 +620,19 @@ impl ColPrunable for LogicalAgg {
let group_key_required_cols = FixedBitSet::from_iter(self.group_keys.iter().copied());

let (agg_call_required_cols, agg_calls) = {
let mut tmp = FixedBitSet::with_capacity(self.input().schema().fields().len());
let input_cnt = self.input().schema().fields().len();
let mut tmp = FixedBitSet::with_capacity(input_cnt);
let new_agg_calls = required_cols
.iter()
.filter(|&&index| index >= self.group_keys.len())
.map(|&index| {
let index = index - self.group_keys.len();
let agg_call = self.agg_calls[index].clone();
tmp.extend(agg_call.inputs.iter().map(|x| x.index()));
// collect columns used in aggregate filter expressions
for i in &agg_call.filter.conjunctions {
tmp.union_with(&i.collect_input_refs(input_cnt));
}
agg_call
})
.collect_vec();
Expand Down Expand Up @@ -845,8 +880,13 @@ mod tests {

// Test case: select v1, min(v2) from test group by v1;
{
let min_v2 =
AggCall::new(AggKind::Min, vec![input_ref_2.clone().into()], false).unwrap();
let min_v2 = AggCall::new(
AggKind::Min,
vec![input_ref_2.clone().into()],
false,
Condition::true_cond(),
)
.unwrap();
let select_exprs = vec![input_ref_1.clone().into(), min_v2.into()];
let group_exprs = vec![input_ref_1.clone().into()];

Expand All @@ -864,10 +904,20 @@ mod tests {

// Test case: select v1, min(v2) + max(v3) from t group by v1;
{
let min_v2 =
AggCall::new(AggKind::Min, vec![input_ref_2.clone().into()], false).unwrap();
let max_v3 =
AggCall::new(AggKind::Max, vec![input_ref_3.clone().into()], false).unwrap();
let min_v2 = AggCall::new(
AggKind::Min,
vec![input_ref_2.clone().into()],
false,
Condition::true_cond(),
)
.unwrap();
let max_v3 = AggCall::new(
AggKind::Max,
vec![input_ref_3.clone().into()],
false,
Condition::true_cond(),
)
.unwrap();
let func_call =
FunctionCall::new(ExprType::Add, vec![min_v2.into(), max_v3.into()]).unwrap();
let select_exprs = vec![input_ref_1.clone().into(), ExprImpl::from(func_call)];
Expand Down Expand Up @@ -900,7 +950,13 @@ mod tests {
vec![input_ref_1.into(), input_ref_3.into()],
)
.unwrap();
let agg_call = AggCall::new(AggKind::Min, vec![v1_mult_v3.into()], false).unwrap();
let agg_call = AggCall::new(
AggKind::Min,
vec![v1_mult_v3.into()],
false,
Condition::true_cond(),
)
.unwrap();
let select_exprs = vec![input_ref_2.clone().into(), agg_call.into()];
let group_exprs = vec![input_ref_2.into()];

Expand Down Expand Up @@ -930,6 +986,7 @@ mod tests {
return_type: ty.clone(),
inputs: vec![InputRef::new(2, ty.clone())],
distinct: false,
filter: Condition::true_cond(),
};
LogicalAgg::new(vec![agg_call], vec![1], values.into())
}
Expand Down Expand Up @@ -1047,6 +1104,7 @@ mod tests {
return_type: ty.clone(),
inputs: vec![InputRef::new(2, ty.clone())],
distinct: false,
filter: Condition::true_cond(),
};
let agg = LogicalAgg::new(vec![agg_call], vec![1], values.into());

Expand Down Expand Up @@ -1110,12 +1168,14 @@ mod tests {
return_type: ty.clone(),
inputs: vec![InputRef::new(2, ty.clone())],
distinct: false,
filter: Condition::true_cond(),
},
PlanAggCall {
agg_kind: AggKind::Max,
return_type: ty.clone(),
inputs: vec![InputRef::new(1, ty.clone())],
distinct: false,
filter: Condition::true_cond(),
},
];
let agg = LogicalAgg::new(agg_calls, vec![1, 2], values.into());
Expand Down
4 changes: 2 additions & 2 deletions src/frontend/src/utils/condition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use crate::expr::{
try_get_bool_constant, ExprImpl, ExprRewriter, ExprType, ExprVisitor, InputRef,
};

#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Condition {
/// Condition expressions in conjunction form (combined with `AND`)
pub conjunctions: Vec<ExprImpl>,
Expand Down Expand Up @@ -379,7 +379,7 @@ impl Condition {
}

#[must_use]
pub fn rewrite_expr(self, rewriter: &mut impl ExprRewriter) -> Self {
pub fn rewrite_expr(self, rewriter: &mut (impl ExprRewriter + ?Sized)) -> Self {
Self {
conjunctions: self
.conjunctions
Expand Down
Loading

0 comments on commit 654f0b1

Please sign in to comment.