Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(frontend): support filter clause in aggregation #3626

Merged
merged 16 commits into from
Jul 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion src/frontend/src/binder/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ use risingwave_sqlparser::ast::{Function, FunctionArg, FunctionArgExpr};
use crate::binder::bind_context::Clause;
use crate::binder::Binder;
use crate::expr::{AggCall, Expr, ExprImpl, ExprType, FunctionCall, Literal};
use crate::utils::Condition;

impl Binder {
pub(super) fn bind_function(&mut self, f: Function) -> Result<ExprImpl> {
Expand All @@ -49,9 +50,42 @@ impl Binder {
};
if let Some(kind) = agg_kind {
self.ensure_aggregate_allowed()?;
let filter = match f.filter {
Some(filter) => {
let expr = self.bind_expr(*filter)?;
if expr.return_type() != DataType::Boolean {
return Err(ErrorCode::InvalidInputSyntax(format!(
"the type of filter clause should be boolean, but found {:?}",
expr.return_type()
))
.into());
}
if expr.has_subquery() {
return Err(ErrorCode::InvalidInputSyntax(
"subquery in filter clause is not supported".to_string(),
)
.into());
}
if expr.has_agg_call() {
return Err(ErrorCode::InvalidInputSyntax(
"aggregation function in filter clause is not supported"
.to_string(),
)
.into());
}
Condition::with_expr(expr)
}
None => Condition::true_cond(),
};
return Ok(ExprImpl::AggCall(Box::new(AggCall::new(
kind, inputs, f.distinct,
kind, inputs, f.distinct, filter,
)?)));
} else if f.filter.is_some() {
return Err(ErrorCode::InvalidInputSyntax(format!(
"filter clause is only allowed in aggregation functions, but `{}` is not an aggregation function", function_name
)
)
.into());
}
let function_type = match function_name.as_str() {
// comparison
Expand Down
15 changes: 12 additions & 3 deletions src/frontend/src/expr/agg_call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@ use risingwave_common::types::DataType;
use risingwave_expr::expr::AggKind;

use super::{Expr, ExprImpl};
use crate::utils::Condition;

/// An aggregate function call expression (wrapped as `ExprImpl::AggCall` by the binder).
///
/// Instances are built through `AggCall::new`, which infers `return_type` from the
/// aggregate kind and the types of `inputs`.
#[derive(Clone, Eq, PartialEq, Hash)]
pub struct AggCall {
/// The aggregate function being invoked.
agg_kind: AggKind,
/// Result type of the call, inferred in `AggCall::new` from `agg_kind` and the input types.
return_type: DataType,
/// Argument expressions passed to the aggregate.
inputs: Vec<ExprImpl>,
/// The `distinct` flag carried over from the parsed function call.
distinct: bool,
/// Condition bound from the call's filter clause; the binder passes
/// `Condition::true_cond()` when the call has no filter clause.
filter: Condition,
}

impl std::fmt::Debug for AggCall {
Expand All @@ -34,6 +36,7 @@ impl std::fmt::Debug for AggCall {
.field("agg_kind", &self.agg_kind)
.field("return_type", &self.return_type)
.field("inputs", &self.inputs)
.field("filter", &self.filter)
.finish()
} else {
let mut builder = f.debug_tuple(&format!("{}", self.agg_kind));
Expand Down Expand Up @@ -114,19 +117,25 @@ impl AggCall {

/// Returns error if the function name matches with an existing function
/// but with illegal arguments.
pub fn new(agg_kind: AggKind, inputs: Vec<ExprImpl>, distinct: bool) -> Result<Self> {
pub fn new(
agg_kind: AggKind,
inputs: Vec<ExprImpl>,
distinct: bool,
filter: Condition,
) -> Result<Self> {
let data_types = inputs.iter().map(ExprImpl::return_type).collect_vec();
let return_type = Self::infer_return_type(&agg_kind, &data_types)?;
Ok(AggCall {
agg_kind,
return_type,
inputs,
distinct,
filter,
})
}

pub fn decompose(self) -> (AggKind, Vec<ExprImpl>, bool) {
(self.agg_kind, self.inputs, self.distinct)
pub fn decompose(self) -> (AggKind, Vec<ExprImpl>, bool, Condition) {
(self.agg_kind, self.inputs, self.distinct, self.filter)
}

pub fn agg_kind(&self) -> AggKind {
Expand Down
7 changes: 5 additions & 2 deletions src/frontend/src/expr/expr_rewriter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,15 @@ pub trait ExprRewriter {
FunctionCall::new_unchecked(func_type, inputs, ret).into()
}
fn rewrite_agg_call(&mut self, agg_call: AggCall) -> ExprImpl {
let (func_type, inputs, distinct) = agg_call.decompose();
let (func_type, inputs, distinct, filter) = agg_call.decompose();
let inputs = inputs
.into_iter()
.map(|expr| self.rewrite_expr(expr))
.collect();
AggCall::new(func_type, inputs, distinct).unwrap().into()
let filter = filter.rewrite_expr(self);
AggCall::new(func_type, inputs, distinct, filter)
.unwrap()
.into()
}
fn rewrite_literal(&mut self, literal: Literal) -> ExprImpl {
literal.into()
Expand Down
4 changes: 3 additions & 1 deletion src/frontend/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ impl ExprImpl {
/// A `count(*)` aggregate function.
#[inline(always)]
pub fn count_star() -> Self {
AggCall::new(AggKind::Count, vec![], false).unwrap().into()
AggCall::new(AggKind::Count, vec![], false, Condition::true_cond())
.unwrap()
.into()
}

/// Collect all `InputRef`s' indexes in the expression.
Expand Down
84 changes: 72 additions & 12 deletions src/frontend/src/optimizer/plan_node/logical_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ pub struct PlanAggCall {
pub inputs: Vec<InputRef>,

pub distinct: bool,
/// Selective aggregation: only the input rows for which
/// the filter clause evaluates to true will be fed to the aggregate function.
/// Other rows are discarded.
pub filter: Condition,
}

impl fmt::Debug for PlanAggCall {
Expand All @@ -70,6 +74,13 @@ impl fmt::Debug for PlanAggCall {
}
write!(f, ")")?;
}
if !self.filter.always_true() {
write!(
f,
" filter({:?})",
self.filter.as_expr_unless_true().unwrap()
)?;
}
Ok(())
}
}
Expand Down Expand Up @@ -99,6 +110,7 @@ impl PlanAggCall {
PlanAggCall {
agg_kind: total_agg_kind,
inputs: vec![InputRef::new(partial_output_idx, self.return_type.clone())],
filter: Condition::true_cond(),
..self.clone()
}
}
Expand All @@ -109,6 +121,7 @@ impl PlanAggCall {
return_type: DataType::Int64,
inputs: vec![],
distinct: false,
filter: Condition::true_cond(),
}
}
}
Expand Down Expand Up @@ -232,6 +245,12 @@ struct LogicalAggBuilder {
agg_calls: Vec<PlanAggCall>,
/// the error during the expression rewriting
error: Option<ErrorCode>,
/// If `is_in_filter_clause` is true, it means that
/// we are processing a filter clause.
/// This field is needed because input refs in filter clause
/// are allowed to refer to any columns, while those not in filter
/// clause are only allowed to refer to group keys.
is_in_filter_clause: bool,
}

impl LogicalAggBuilder {
Expand All @@ -257,6 +276,7 @@ impl LogicalAggBuilder {
agg_calls: vec![],
error: None,
input_proj_builder,
is_in_filter_clause: false,
})
}

Expand Down Expand Up @@ -300,14 +320,18 @@ impl ExprRewriter for LogicalAggBuilder {
/// Note that the rewriter does not traverse into inputs of agg calls.
fn rewrite_agg_call(&mut self, agg_call: AggCall) -> ExprImpl {
let return_type = agg_call.return_type();
let (agg_kind, inputs, distinct) = agg_call.decompose();

let (agg_kind, inputs, distinct, filter) = agg_call.decompose();
self.is_in_filter_clause = true;
let filter = filter.rewrite_expr(self);
self.is_in_filter_clause = false;
for i in &inputs {
if i.has_agg_call() {
self.error = Some(ErrorCode::InvalidInputSyntax(
"Aggregation calls should not be nested".into(),
));
return AggCall::new(agg_kind, inputs, distinct).unwrap().into();
return AggCall::new(agg_kind, inputs, distinct, filter)
.unwrap()
.into();
}
}

Expand All @@ -331,6 +355,7 @@ impl ExprRewriter for LogicalAggBuilder {
return_type: left_return_type.clone(),
inputs: inputs.clone(),
distinct,
filter: filter.clone(),
});
let left = ExprImpl::from(InputRef::new(
self.group_keys.len() + self.agg_calls.len() - 1,
Expand All @@ -347,6 +372,7 @@ impl ExprRewriter for LogicalAggBuilder {
return_type: right_return_type.clone(),
inputs,
distinct,
filter,
});

let right = InputRef::new(
Expand All @@ -361,6 +387,7 @@ impl ExprRewriter for LogicalAggBuilder {
return_type: return_type.clone(),
inputs,
distinct,
filter,
});
ExprImpl::from(InputRef::new(
self.group_keys.len() + self.agg_calls.len() - 1,
Expand Down Expand Up @@ -390,6 +417,8 @@ impl ExprRewriter for LogicalAggBuilder {
let expr = input_ref.into();
if let Some(group_key) = self.try_as_group_expr(&expr) {
InputRef::new(group_key, expr.return_type()).into()
} else if self.is_in_filter_clause {
InputRef::new(self.input_proj_builder.add_expr(&expr), expr.return_type()).into()
} else {
self.error = Some(ErrorCode::InvalidInputSyntax(
"column must appear in the GROUP BY clause or be used in an aggregate function"
Expand Down Expand Up @@ -540,7 +569,7 @@ impl PlanTreeNodeUnary for LogicalAgg {
fn rewrite_with_input(
&self,
input: PlanRef,
input_col_change: ColIndexMapping,
mut input_col_change: ColIndexMapping,
) -> (Self, ColIndexMapping) {
let agg_calls = self
.agg_calls
Expand All @@ -550,6 +579,7 @@ impl PlanTreeNodeUnary for LogicalAgg {
agg_call.inputs.iter_mut().for_each(|i| {
*i = InputRef::new(input_col_change.map(i.index()), i.return_type())
});
agg_call.filter = agg_call.filter.rewrite_expr(&mut input_col_change);
agg_call
})
.collect();
Expand Down Expand Up @@ -590,14 +620,19 @@ impl ColPrunable for LogicalAgg {
let group_key_required_cols = FixedBitSet::from_iter(self.group_keys.iter().copied());

let (agg_call_required_cols, agg_calls) = {
let mut tmp = FixedBitSet::with_capacity(self.input().schema().fields().len());
let input_cnt = self.input().schema().fields().len();
let mut tmp = FixedBitSet::with_capacity(input_cnt);
let new_agg_calls = required_cols
.iter()
.filter(|&&index| index >= self.group_keys.len())
.map(|&index| {
let index = index - self.group_keys.len();
let agg_call = self.agg_calls[index].clone();
tmp.extend(agg_call.inputs.iter().map(|x| x.index()));
// collect columns used in aggregate filter expressions
for i in &agg_call.filter.conjunctions {
tmp.union_with(&i.collect_input_refs(input_cnt));
}
agg_call
})
.collect_vec();
Expand Down Expand Up @@ -845,8 +880,13 @@ mod tests {

// Test case: select v1, min(v2) from test group by v1;
{
let min_v2 =
AggCall::new(AggKind::Min, vec![input_ref_2.clone().into()], false).unwrap();
let min_v2 = AggCall::new(
AggKind::Min,
vec![input_ref_2.clone().into()],
false,
Condition::true_cond(),
)
.unwrap();
let select_exprs = vec![input_ref_1.clone().into(), min_v2.into()];
let group_exprs = vec![input_ref_1.clone().into()];

Expand All @@ -864,10 +904,20 @@ mod tests {

// Test case: select v1, min(v2) + max(v3) from t group by v1;
{
let min_v2 =
AggCall::new(AggKind::Min, vec![input_ref_2.clone().into()], false).unwrap();
let max_v3 =
AggCall::new(AggKind::Max, vec![input_ref_3.clone().into()], false).unwrap();
let min_v2 = AggCall::new(
AggKind::Min,
vec![input_ref_2.clone().into()],
false,
Condition::true_cond(),
)
.unwrap();
let max_v3 = AggCall::new(
AggKind::Max,
vec![input_ref_3.clone().into()],
false,
Condition::true_cond(),
)
.unwrap();
let func_call =
FunctionCall::new(ExprType::Add, vec![min_v2.into(), max_v3.into()]).unwrap();
let select_exprs = vec![input_ref_1.clone().into(), ExprImpl::from(func_call)];
Expand Down Expand Up @@ -900,7 +950,13 @@ mod tests {
vec![input_ref_1.into(), input_ref_3.into()],
)
.unwrap();
let agg_call = AggCall::new(AggKind::Min, vec![v1_mult_v3.into()], false).unwrap();
let agg_call = AggCall::new(
AggKind::Min,
vec![v1_mult_v3.into()],
false,
Condition::true_cond(),
)
.unwrap();
let select_exprs = vec![input_ref_2.clone().into(), agg_call.into()];
let group_exprs = vec![input_ref_2.into()];

Expand Down Expand Up @@ -930,6 +986,7 @@ mod tests {
return_type: ty.clone(),
inputs: vec![InputRef::new(2, ty.clone())],
distinct: false,
filter: Condition::true_cond(),
};
LogicalAgg::new(vec![agg_call], vec![1], values.into())
}
Expand Down Expand Up @@ -1047,6 +1104,7 @@ mod tests {
return_type: ty.clone(),
inputs: vec![InputRef::new(2, ty.clone())],
distinct: false,
filter: Condition::true_cond(),
};
let agg = LogicalAgg::new(vec![agg_call], vec![1], values.into());

Expand Down Expand Up @@ -1110,12 +1168,14 @@ mod tests {
return_type: ty.clone(),
inputs: vec![InputRef::new(2, ty.clone())],
distinct: false,
filter: Condition::true_cond(),
},
PlanAggCall {
agg_kind: AggKind::Max,
return_type: ty.clone(),
inputs: vec![InputRef::new(1, ty.clone())],
distinct: false,
filter: Condition::true_cond(),
},
];
let agg = LogicalAgg::new(agg_calls, vec![1, 2], values.into());
Expand Down
4 changes: 2 additions & 2 deletions src/frontend/src/utils/condition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use crate::expr::{
try_get_bool_constant, ExprImpl, ExprRewriter, ExprType, ExprVisitor, InputRef,
};

#[derive(Debug, Clone)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Condition {
/// Condition expressions in conjunction form (combined with `AND`)
pub conjunctions: Vec<ExprImpl>,
Expand Down Expand Up @@ -379,7 +379,7 @@ impl Condition {
}

#[must_use]
pub fn rewrite_expr(self, rewriter: &mut impl ExprRewriter) -> Self {
pub fn rewrite_expr(self, rewriter: &mut (impl ExprRewriter + ?Sized)) -> Self {
Self {
conjunctions: self
.conjunctions
Expand Down
Loading