diff --git a/e2e_test/batch/aggregate/selective_count.slt.part b/e2e_test/batch/aggregate/selective_count.slt.part
new file mode 100644
index 0000000000000..3accedba35eff
--- /dev/null
+++ b/e2e_test/batch/aggregate/selective_count.slt.part
@@ -0,0 +1,28 @@
+statement ok
+SET RW_IMPLICIT_FLUSH TO true;
+
+statement ok
+create table t (v1 int not null, v2 int not null, v3 int not null)
+
+statement ok
+insert into t values (1,4,2), (2,3,3), (3,4,4), (4,3,5)
+
+query I
+select count(*) filter (where v1 > 2) from t
+----
+2
+
+query I
+select count(*) filter (where v1 <= v2 and v2 <= v3) from t
+----
+2
+
+query II
+select v2, count(*) filter (where v1 > 1) as cnt from t group by v2 order by v2
+----
+3 2
+4 1
+
+
+statement ok
+drop table t
diff --git a/e2e_test/batch/aggregate/selective_distinct_agg.slt.part b/e2e_test/batch/aggregate/selective_distinct_agg.slt.part
new file mode 100644
index 0000000000000..a4c7e84fbb705
--- /dev/null
+++ b/e2e_test/batch/aggregate/selective_distinct_agg.slt.part
@@ -0,0 +1,22 @@
+statement ok
+SET RW_IMPLICIT_FLUSH TO true;
+
+statement ok
+create table t(v1 int, v2 int, v3 int)
+
+statement ok
+insert into t values (2, 200, -1), (1, 200, -2), (3, 100, -2), (3, 300, -1), (2, 100, -1), (1, 500, -2), (6, 500, -1)
+
+query I
+select sum(distinct v1) filter (where v1 * 100 >= v2) from t
+----
+11
+
+query II
+select v3, sum(distinct v1) filter(where v1 * 100 >= v2) from t group by v3 order by v3
+----
+-2 3
+-1 11
+
+statement ok
+drop table t
diff --git a/e2e_test/batch/aggregate/selective_general_agg.slt.part b/e2e_test/batch/aggregate/selective_general_agg.slt.part
new file mode 100644
index 0000000000000..7e283d35302ea
--- /dev/null
+++ b/e2e_test/batch/aggregate/selective_general_agg.slt.part
@@ -0,0 +1,29 @@
+statement ok
+SET RW_IMPLICIT_FLUSH TO true;
+
+statement ok
+create table t(v1 int, v2 int)
+
+statement ok
+insert into t values (2, 300), (1, 200), (3, 200), (3, 500), (2, 100), (1, 500), (6, 300)
+
+query II
+select min(v1) filter (where v1 * 100 >= v2), min(v1) filter (where v1 * 100 < v2) from t
+----
+2 1
+
+query II
+select v2, sum(v1) filter(where v1 >= 2) from t group by v2 order by v2
+----
+100 2
+200 3
+300 8
+500 3
+
+query I
+with sub(a, b) as (select sum(v1), count(v2) filter (where v1 > 5) from t) select b from sub
+----
+1
+
+statement ok
+drop table t
diff --git a/proto/expr.proto b/proto/expr.proto
index 608b396ad8af0..bfa648721b906 100644
--- a/proto/expr.proto
+++ b/proto/expr.proto
@@ -153,4 +153,5 @@ message AggCall {
   repeated Arg args = 2;
   data.DataType return_type = 3;
   bool distinct = 4;
+  ExprNode filter = 5;
 }
diff --git a/src/batch/src/executor/hash_agg.rs b/src/batch/src/executor/hash_agg.rs
index 872353a7f5762..7e7f4bc4d177c 100644
--- a/src/batch/src/executor/hash_agg.rs
+++ b/src/batch/src/executor/hash_agg.rs
@@ -310,6 +310,7 @@ mod tests {
                 ..Default::default()
             }),
             distinct: false,
+            filter: None,
         };
 
         let agg_prost = HashAggNode {
@@ -375,6 +376,7 @@ mod tests {
                 ..Default::default()
             }),
             distinct: false,
+            filter: None,
         };
 
         let agg_prost = HashAggNode {
diff --git a/src/batch/src/executor/sort_agg.rs b/src/batch/src/executor/sort_agg.rs
index 2715600f1a8ad..fb8170dc8a15a 100644
--- a/src/batch/src/executor/sort_agg.rs
+++ b/src/batch/src/executor/sort_agg.rs
@@ -332,6 +332,7 @@ mod tests {
                 ..Default::default()
             }),
             distinct: false,
+            filter: None,
         };
 
         let count_star = AggStateFactory::new(&prost)?.create_agg_state()?;
@@ -424,6 +425,7 @@ mod tests {
                 ..Default::default()
            }),
             distinct: false,
+            filter: None,
         };
 
         let count_star = AggStateFactory::new(&prost)?.create_agg_state()?;
@@ -551,6 +553,7 @@ mod tests {
                 ..Default::default()
             }),
             distinct: false,
+            filter: None,
         };
 
         let sum_agg = AggStateFactory::new(&prost)?.create_agg_state()?;
@@ -634,6 +637,7 @@ mod tests {
                 ..Default::default()
             }),
             distinct: false,
+            filter: None,
         };
 
         let sum_agg = AggStateFactory::new(&prost)?.create_agg_state()?;
@@ -756,6 +760,7 @@ mod tests {
                 ..Default::default()
             }),
             distinct: false,
+            filter: None,
         };
 
         let sum_agg = AggStateFactory::new(&prost)?.create_agg_state()?;
diff --git a/src/expr/src/vector_op/agg/aggregator.rs b/src/expr/src/vector_op/agg/aggregator.rs
index bfcce47f18ee4..cecac7ee98cab 100644
--- a/src/expr/src/vector_op/agg/aggregator.rs
+++ b/src/expr/src/vector_op/agg/aggregator.rs
@@ -12,12 +12,14 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+use std::sync::Arc;
+
 use risingwave_common::array::*;
 use risingwave_common::error::{ErrorCode, Result};
 use risingwave_common::types::*;
 use risingwave_pb::expr::AggCall;
 
-use crate::expr::AggKind;
+use crate::expr::{build_from_prost, AggKind, Expression, ExpressionRef, LiteralExpression};
 use crate::vector_op::agg::approx_count_distinct::ApproxCountDistinct;
 use crate::vector_op::agg::count_star::CountStar;
 use crate::vector_op::agg::functions::*;
@@ -55,6 +57,7 @@ pub struct AggStateFactory {
     agg_kind: AggKind,
     return_type: DataType,
     distinct: bool,
+    filter: ExpressionRef,
 }
 
 impl AggStateFactory {
@@ -62,6 +65,12 @@ impl AggStateFactory {
         let return_type = DataType::from(prost.get_return_type()?);
         let agg_kind = AggKind::try_from(prost.get_type()?)?;
         let distinct = prost.distinct;
+        let filter: ExpressionRef = match prost.filter {
+            Some(ref expr) => Arc::from(build_from_prost(expr)?),
+            None => Arc::from(
+                LiteralExpression::new(DataType::Boolean, Some(ScalarImpl::Bool(true))).boxed(),
+            ),
+        };
         match &prost.get_args()[..] {
             [ref arg] => {
                 let input_type = DataType::from(arg.get_type()?);
@@ -72,6 +81,7 @@ impl AggStateFactory {
                     agg_kind,
                     return_type,
                     distinct,
+                    filter,
                 })
             }
             [] => match (&agg_kind, return_type.clone()) {
@@ -81,6 +91,7 @@ impl AggStateFactory {
                     agg_kind,
                     return_type,
                     distinct,
+                    filter,
                 }),
                 _ => Err(ErrorCode::InternalError(format!(
                     "Agg {:?} without args not supported",
@@ -101,6 +112,7 @@ impl AggStateFactory {
             Ok(Box::new(ApproxCountDistinct::new(
                 self.return_type.clone(),
                 self.input_col_idx,
+                self.filter.clone(),
             )))
         } else if let Some(input_type) = self.input_type.clone() {
             create_agg_state_unary(
@@ -109,9 +121,14 @@ impl AggStateFactory {
                 &self.agg_kind,
                 self.return_type.clone(),
                 self.distinct,
+                self.filter.clone(),
             )
         } else {
-            Ok(Box::new(CountStar::new(self.return_type.clone())))
+            Ok(Box::new(CountStar::new(
+                self.return_type.clone(),
+                0,
+                self.filter.clone(),
+            )))
         }
     }
 
@@ -126,6 +143,7 @@ pub fn create_agg_state_unary(
     agg_type: &AggKind,
     return_type: DataType,
     distinct: bool,
+    filter: ExpressionRef,
 ) -> Result<Box<dyn Aggregator>> {
     use crate::expr::data_types::*;
 
@@ -144,6 +162,7 @@ pub fn create_agg_state_unary(
                         input_col_idx,
                         $fn,
                         $init_result,
+                        filter
                     ))
                 },
                 ($in! { type_match_pattern }, AggKind::$agg, $ret! { type_match_pattern }, true) => {
@@ -151,6 +170,7 @@ pub fn create_agg_state_unary(
                         return_type,
                         input_col_idx,
                         $fn,
+                        filter,
                     ))
                 },
             )*
@@ -244,7 +264,9 @@ mod tests {
         let decimal_type = DataType::Decimal;
         let bool_type = DataType::Boolean;
         let char_type = DataType::Varchar;
-
+        let filter: ExpressionRef = Arc::from(
+            LiteralExpression::new(DataType::Boolean, Some(ScalarImpl::Bool(true))).boxed(),
+        );
         macro_rules! test_create {
             ($input_type:expr, $agg:ident, $return_type:expr, $expected:ident) => {
                 assert!(create_agg_state_unary(
@@ -253,6 +275,7 @@ mod tests {
                     &AggKind::$agg,
                     $return_type.clone(),
                     false,
+                    filter.clone(),
                 )
                 .$expected());
                 assert!(create_agg_state_unary(
@@ -261,6 +284,7 @@ mod tests {
                     &AggKind::$agg,
                     $return_type.clone(),
                     true,
+                    filter.clone(),
                 )
                 .$expected());
             };
diff --git a/src/expr/src/vector_op/agg/approx_count_distinct.rs b/src/expr/src/vector_op/agg/approx_count_distinct.rs
index f5129afbca2b0..1e068a5edb7f2 100644
--- a/src/expr/src/vector_op/agg/approx_count_distinct.rs
+++ b/src/expr/src/vector_op/agg/approx_count_distinct.rs
@@ -19,6 +19,7 @@ use risingwave_common::array::*;
 use risingwave_common::error::{ErrorCode, Result};
 use risingwave_common::types::*;
 
+use crate::expr::ExpressionRef;
 use crate::vector_op::agg::aggregator::Aggregator;
 
 const INDEX_BITS: u8 = 14; // number of bits used for finding the index of each 64-bit hash
@@ -36,14 +37,16 @@ pub struct ApproxCountDistinct {
     return_type: DataType,
     input_col_idx: usize,
     registers: [u8; NUM_OF_REGISTERS],
+    filter: ExpressionRef,
 }
 
 impl ApproxCountDistinct {
-    pub fn new(return_type: DataType, input_col_idx: usize) -> Self {
+    pub fn new(return_type: DataType, input_col_idx: usize, filter: ExpressionRef) -> Self {
         Self {
             return_type,
            input_col_idx,
             registers: [0; NUM_OF_REGISTERS],
+            filter,
         }
     }
 
@@ -114,6 +117,20 @@ impl ApproxCountDistinct {
 
         answer as i64
     }
+
+    /// `apply_filter_on_row` applies the filter on the given row, and returns whether the row
+    /// satisfies the filter.
+    /// # SAFETY: the given row must be visible.
+    fn apply_filter_on_row(&self, input: &DataChunk, row_id: usize) -> Result<bool> {
+        let (row, visible) = input.row_at(row_id)?;
+        assert!(visible);
+        let filter_res = if let Some(ScalarImpl::Bool(v)) = self.filter.eval_row(&Row::from(row))? {
+            v
+        } else {
+            false
+        };
+        Ok(filter_res)
+    }
 }
 
 impl Aggregator for ApproxCountDistinct {
@@ -122,9 +139,12 @@ impl Aggregator for ApproxCountDistinct {
     }
 
     fn update_single(&mut self, input: &DataChunk, row_id: usize) -> Result<()> {
-        let array = input.column_at(self.input_col_idx).array_ref();
-        let datum_ref = array.value_at(row_id);
-        self.add_datum(datum_ref);
+        let filter_res = self.apply_filter_on_row(input, row_id)?;
+        if filter_res {
+            let array = input.column_at(self.input_col_idx).array_ref();
+            let datum_ref = array.value_at(row_id);
+            self.add_datum(datum_ref);
+        }
         Ok(())
     }
 
@@ -135,12 +155,12 @@ impl Aggregator for ApproxCountDistinct {
         end_row_id: usize,
     ) -> Result<()> {
         let array = input.column_at(self.input_col_idx).array_ref();
-        for datum_ref in array
-            .iter()
-            .skip(start_row_id)
-            .take(end_row_id - start_row_id)
-        {
-            self.add_datum(datum_ref);
+        for row_id in start_row_id..end_row_id {
+            let filter_res = self.apply_filter_on_row(input, row_id)?;
+            if filter_res {
+                let datum_ref = array.value_at(row_id);
+                self.add_datum(datum_ref);
+            }
         }
         Ok(())
     }
@@ -168,8 +188,9 @@ mod tests {
     use risingwave_common::array::{
         ArrayBuilder, ArrayBuilderImpl, DataChunk, I32Array, I64ArrayBuilder,
     };
-    use risingwave_common::types::DataType;
+    use risingwave_common::types::{DataType, ScalarImpl};
 
+    use crate::expr::{Expression, LiteralExpression};
     use crate::vector_op::agg::aggregator::Aggregator;
     use crate::vector_op::agg::approx_count_distinct::ApproxCountDistinct;
 
@@ -193,7 +214,13 @@ mod tests {
         let inputs_size: [usize; 3] = [20000, 10000, 5000];
         let inputs_start: [i32; 3] = [0, 20000, 30000];
 
-        let mut agg = ApproxCountDistinct::new(DataType::Int64, 0);
+        let mut agg = ApproxCountDistinct::new(
+            DataType::Int64,
+            0,
+            Arc::from(
+                LiteralExpression::new(DataType::Boolean, Some(ScalarImpl::Bool(true))).boxed(),
+            ),
+        );
         let mut builder = ArrayBuilderImpl::Int64(I64ArrayBuilder::new(3));
 
         for i in 0..3 {
@@ -213,7 +240,13 @@ mod tests {
         let inputs_size: [usize; 3] = [20000, 10000, 5000];
         let inputs_start: [i32; 3] = [0, 20000, 30000];
 
-        let mut agg = ApproxCountDistinct::new(DataType::Int64, 0);
+        let mut agg = ApproxCountDistinct::new(
+            DataType::Int64,
+            0,
+            Arc::from(
+                LiteralExpression::new(DataType::Boolean, Some(ScalarImpl::Bool(true))).boxed(),
+            ),
+        );
         let mut builder = ArrayBuilderImpl::Int64(I64ArrayBuilder::new(3));
 
         for i in 0..3 {
diff --git a/src/expr/src/vector_op/agg/count_star.rs b/src/expr/src/vector_op/agg/count_star.rs
index 11017ba324afe..ac8cbc5f687d8 100644
--- a/src/expr/src/vector_op/agg/count_star.rs
+++ b/src/expr/src/vector_op/agg/count_star.rs
@@ -16,20 +16,37 @@ use risingwave_common::array::*;
 use risingwave_common::error::{ErrorCode, Result};
 use risingwave_common::types::*;
 
+use crate::expr::ExpressionRef;
 use crate::vector_op::agg::aggregator::Aggregator;
 
 pub struct CountStar {
     return_type: DataType,
     result: usize,
+    filter: ExpressionRef,
 }
 
 impl CountStar {
-    pub fn new(return_type: DataType) -> Self {
+    pub fn new(return_type: DataType, result: usize, filter: ExpressionRef) -> Self {
         Self {
             return_type,
-            result: 0,
+            result,
+            filter,
         }
     }
+
+    /// `apply_filter_on_row` applies the filter on the given row, and returns whether the row
+    /// satisfies the filter.
+    /// # SAFETY: the given row must be visible.
+    fn apply_filter_on_row(&self, input: &DataChunk, row_id: usize) -> Result<bool> {
+        let (row, visible) = input.row_at(row_id)?;
+        assert!(visible);
+        let filter_res = if let Some(ScalarImpl::Bool(v)) = self.filter.eval_row(&Row::from(row))? {
+            v
+        } else {
+            false
+        };
+        Ok(filter_res)
+    }
 }
 
 impl Aggregator for CountStar {
@@ -37,17 +54,6 @@ impl Aggregator for CountStar {
         self.return_type.clone()
     }
 
-    fn update_single(&mut self, input: &DataChunk, row_id: usize) -> Result<()> {
-        if let Some(visibility) = input.visibility() {
-            if visibility.is_set(row_id)? {
-                self.result += 1;
-            }
-        } else {
-            self.result += 1;
-        }
-        Ok(())
-    }
-
     fn update_multi(
         &mut self,
         input: &DataChunk,
@@ -56,26 +62,52 @@ impl Aggregator for CountStar {
     ) -> Result<()> {
         if let Some(visibility) = input.visibility() {
             for row_id in start_row_id..end_row_id {
-                if visibility.is_set(row_id)? {
+                if visibility.is_set(row_id)? && self.apply_filter_on_row(input, row_id)? {
                     self.result += 1;
                 }
             }
         } else {
-            self.result += end_row_id - start_row_id;
+            self.result += self
+                .filter
+                .eval(input)?
+                .iter()
+                .skip(start_row_id)
+                .take(end_row_id - start_row_id)
+                .filter(|res| {
+                    res.map(|x| *x.into_scalar_impl().as_bool())
+                        .unwrap_or(false)
+                })
+                .count();
         }
         Ok(())
     }
 
-    fn output_and_reset(&mut self, builder: &mut ArrayBuilderImpl) -> Result<()> {
-        let res = self.output(builder);
-        self.result = 0;
-        res
-    }
-
     fn output(&self, builder: &mut ArrayBuilderImpl) -> Result<()> {
         match builder {
             ArrayBuilderImpl::Int64(b) => b.append(Some(self.result as i64)).map_err(Into::into),
             _ => Err(ErrorCode::InternalError("Unexpected builder for count(*).".into()).into()),
         }
     }
+
+    fn update_single(&mut self, input: &DataChunk, row_id: usize) -> Result<()> {
+        if let (row, true) = input.row_at(row_id)? {
+            let filter_res =
+                if let Some(ScalarImpl::Bool(v)) = self.filter.eval_row(&Row::from(row))? {
+                    v
+                } else {
+                    false
+                };
+
+            if filter_res {
+                self.result += 1;
+            }
+        }
+        Ok(())
+    }
+
+    fn output_and_reset(&mut self, builder: &mut ArrayBuilderImpl) -> Result<()> {
+        let res = self.output(builder);
+        self.result = 0;
+        res
+    }
 }
diff --git a/src/expr/src/vector_op/agg/general_agg.rs b/src/expr/src/vector_op/agg/general_agg.rs
index 404e06647ed21..5f63f40218510 100644
--- a/src/expr/src/vector_op/agg/general_agg.rs
+++ b/src/expr/src/vector_op/agg/general_agg.rs
@@ -18,6 +18,7 @@ use risingwave_common::array::*;
 use risingwave_common::error::{ErrorCode, Result};
 use risingwave_common::types::*;
 
+use crate::expr::ExpressionRef;
 use crate::vector_op::agg::aggregator::Aggregator;
 use crate::vector_op::agg::functions::RTFn;
 
@@ -32,6 +33,7 @@ where
     init_result: Option<R::OwnedItem>,
     result: Option<R::OwnedItem>,
     f: F,
+    filter: ExpressionRef,
     _phantom: PhantomData<T>,
 }
 impl<T, F, R> GeneralAgg<T, F, R>
@@ -45,6 +47,7 @@ where
         input_col_idx: usize,
         f: F,
         init_result: Option<R::OwnedItem>,
+        filter: ExpressionRef,
     ) -> Self {
         Self {
             return_type,
@@ -52,35 +55,45 @@ where
             init_result: init_result.clone(),
             result: init_result,
             f,
+            filter,
             _phantom: PhantomData,
         }
     }
 
-    pub(super) fn update_single_concrete(&mut self, input: &T, row_id: usize) -> Result<()> {
-        let datum = self
-            .f
-            .eval(
-                self.result.as_ref().map(|x| x.as_scalar_ref()),
-                input.value_at(row_id),
-            )?
-            .map(|x| x.to_owned_scalar());
-        self.result = datum;
+    pub(super) fn update_single_concrete(
+        &mut self,
+        array: &T,
+        input: &DataChunk,
+        row_id: usize,
+    ) -> Result<()> {
+        let filter_res = self.apply_filter_on_row(input, row_id)?;
+        if filter_res {
+            let tmp = self
+                .f
+                .eval(
+                    self.result.as_ref().map(|x| x.as_scalar_ref()),
+                    array.value_at(row_id),
+                )?
+                .map(|x| x.to_owned_scalar());
+            self.result = tmp;
+        }
         Ok(())
     }
 
     pub(super) fn update_multi_concrete(
         &mut self,
-        input: &T,
+        array: &T,
+        input: &DataChunk,
         start_row_id: usize,
         end_row_id: usize,
     ) -> Result<()> {
         let mut cur = self.result.as_ref().map(|x| x.as_scalar_ref());
-        for i in input
-            .iter()
-            .skip(start_row_id)
-            .take(end_row_id - start_row_id)
-        {
-            cur = self.f.eval(cur, i)?;
+        for row_id in start_row_id..end_row_id {
+            let filter_res = self.apply_filter_on_row(input, row_id)?;
+            if filter_res {
+                let datum_ref = array.value_at(row_id);
+                cur = self.f.eval(cur, datum_ref)?;
+            }
         }
         self.result = cur.map(|x| x.to_owned_scalar());
         Ok(())
@@ -96,6 +109,21 @@ where
         self.result = self.init_result.clone();
         res
     }
+
+    /// `apply_filter_on_row` applies the filter on the given row, and returns whether the row
+    /// satisfies the filter.
+    /// # SAFETY
+    /// The given row must be visible.
+    fn apply_filter_on_row(&self, input: &DataChunk, row_id: usize) -> Result<bool> {
+        let (row, visible) = input.row_at(row_id)?;
+        assert!(visible);
+        let filter_res = if let Some(ScalarImpl::Bool(v)) = self.filter.eval_row(&Row::from(row))? {
+            v
+        } else {
+            false
+        };
+        Ok(filter_res)
+    }
 }
 
 macro_rules! impl_aggregator {
@@ -112,7 +140,7 @@ macro_rules! impl_aggregator {
                 if let ArrayImpl::$input_variant(i) =
                     input.column_at(self.input_col_idx).array_ref()
                 {
-                    self.update_single_concrete(i, row_id)
+                    self.update_single_concrete(i, input, row_id)
                 } else {
                     Err(ErrorCode::InternalError(format!(
                         "Input fail to match {}.",
@@ -131,7 +159,7 @@ macro_rules! impl_aggregator {
                 if let ArrayImpl::$input_variant(i) =
                     input.column_at(self.input_col_idx).array_ref()
                 {
-                    self.update_multi_concrete(i, start_row_id, end_row_id)
+                    self.update_multi_concrete(i, input, start_row_id, end_row_id)
                 } else {
                     Err(ErrorCode::InternalError(format!(
                         "Input fail to match {} or builder fail to match {}.",
@@ -211,7 +239,7 @@ mod tests {
     use risingwave_common::types::Decimal;
 
     use super::*;
-    use crate::expr::AggKind;
+    use crate::expr::{AggKind, Expression, LiteralExpression};
     use crate::vector_op::agg::aggregator::create_agg_state_unary;
 
     fn eval_agg(
@@ -223,7 +251,11 @@ mod tests {
     ) -> Result<ArrayImpl> {
         let len = input.len();
         let input_chunk = DataChunk::new(vec![Column::new(input)], len);
-        let mut agg_state = create_agg_state_unary(input_type, 0, agg_type, return_type, false)?;
+        let filter: ExpressionRef = Arc::from(
+            LiteralExpression::new(DataType::Boolean, Some(ScalarImpl::Bool(true))).boxed(),
+        );
+        let mut agg_state =
+            create_agg_state_unary(input_type, 0, agg_type, return_type, false, filter)?;
         agg_state.update_multi(&input_chunk, 0, input_chunk.cardinality())?;
         agg_state.output(&mut builder)?;
         builder.finish().map_err(Into::into)
diff --git a/src/expr/src/vector_op/agg/general_distinct_agg.rs b/src/expr/src/vector_op/agg/general_distinct_agg.rs
index 4364aff3bed10..e979c177add5b 100644
--- a/src/expr/src/vector_op/agg/general_distinct_agg.rs
+++ b/src/expr/src/vector_op/agg/general_distinct_agg.rs
@@ -19,6 +19,7 @@ use risingwave_common::array::*;
 use risingwave_common::error::{ErrorCode, Result};
 use risingwave_common::types::*;
 
+use crate::expr::ExpressionRef;
 use crate::vector_op::agg::aggregator::Aggregator;
 use crate::vector_op::agg::functions::RTFn;
 
@@ -39,6 +40,7 @@ where
     result: Option<R::OwnedItem>,
     f: F,
     exists: HashSet<Datum>,
+    filter: ExpressionRef,
     _phantom: PhantomData<T>,
 }
 impl<T, F, R> GeneralDistinctAgg<T, F, R>
@@ -47,7 +49,7 @@ where
     F: for<'a> RTFn<'a, T, R>,
     R: Array,
 {
-    pub fn new(return_type: DataType, input_col_idx: usize, f: F) -> Self {
+    pub fn new(return_type: DataType, input_col_idx: usize, f: F, filter: ExpressionRef) -> Self {
         Self {
             return_type,
             input_col_idx,
@@ -55,6 +57,7 @@ where
             f,
             exists: HashSet::new(),
             _phantom: PhantomData,
+            filter,
         }
     }
 
@@ -77,22 +80,22 @@ where
 
     fn update_multi_concrete(
         &mut self,
-        input: &T,
+        array: &T,
+        input: &DataChunk,
         start_row_id: usize,
         end_row_id: usize,
     ) -> Result<()> {
-        let input = input
-            .iter()
-            .skip(start_row_id)
-            .take(end_row_id - start_row_id)
-            .filter(|scalar_ref| {
-                self.exists.insert(
-                    scalar_ref.map(|scalar_ref| scalar_ref.to_owned_scalar().to_scalar_value()),
-                )
-            });
         let mut cur = self.result.as_ref().map(|x| x.as_scalar_ref());
-        for datum in input {
-            cur = self.f.eval(cur, datum)?;
+        for row_id in start_row_id..end_row_id {
+            if self.apply_filter_on_row(input, row_id)? {
+                let datum = array.value_at(row_id);
+                if self
+                    .exists
+                    .insert(datum.map(|scalar_ref| scalar_ref.to_owned_scalar().to_scalar_value()))
+                {
+                    cur = self.f.eval(cur, datum)?;
+                }
+            }
         }
         let r = cur.map(|x| x.to_owned_scalar());
         self.result = r;
@@ -110,6 +113,19 @@ where
         self.result = None;
         res
     }
+
+    fn apply_filter_on_row(&self, input: &DataChunk, row_id: usize) -> Result<bool> {
+        let (row, visible) = input.row_at(row_id)?;
+        // SAFETY: when performing agg, the data chunk should already be
+        // compacted.
+        assert!(visible);
+        let filter_res = if let Some(ScalarImpl::Bool(v)) = self.filter.eval_row(&Row::from(row))? {
+            v
+        } else {
+            false
+        };
+        Ok(filter_res)
+    }
 }
 
 macro_rules! impl_aggregator {
@@ -126,7 +142,11 @@ macro_rules! impl_aggregator {
                 if let ArrayImpl::$input_variant(i) =
                     input.column_at(self.input_col_idx).array_ref()
                 {
-                    self.update_single_concrete(i, row_id)
+                    let filter_res = self.apply_filter_on_row(input, row_id)?;
+                    if filter_res {
+                        self.update_single_concrete(i, row_id)?;
+                    }
+                    Ok(())
                 } else {
                     Err(ErrorCode::InternalError(format!(
                         "Input fail to match {}.",
@@ -145,7 +165,7 @@ macro_rules! impl_aggregator {
                 if let ArrayImpl::$input_variant(i) =
                     input.column_at(self.input_col_idx).array_ref()
                 {
-                    self.update_multi_concrete(i, start_row_id, end_row_id)
+                    self.update_multi_concrete(i, input, start_row_id, end_row_id)
                 } else {
                     Err(ErrorCode::InternalError(format!(
                         "Input fail to match {}.",
@@ -224,7 +244,7 @@ mod tests {
     use risingwave_common::types::Decimal;
 
     use super::*;
-    use crate::expr::AggKind;
+    use crate::expr::{AggKind, Expression, LiteralExpression};
     use crate::vector_op::agg::aggregator::create_agg_state_unary;
 
     fn eval_agg(
@@ -236,7 +256,16 @@ mod tests {
     ) -> Result<ArrayImpl> {
         let len = input.len();
         let input_chunk = DataChunk::new(vec![Column::new(input)], len);
-        let mut agg_state = create_agg_state_unary(input_type, 0, agg_type, return_type, true)?;
+        let mut agg_state = create_agg_state_unary(
+            input_type,
+            0,
+            agg_type,
+            return_type,
+            true,
+            Arc::from(
+                LiteralExpression::new(DataType::Boolean, Some(ScalarImpl::Bool(true))).boxed(),
+            ),
+        )?;
         agg_state.update_multi(&input_chunk, 0, input_chunk.cardinality())?;
         agg_state.output(&mut builder)?;
         builder.finish().map_err(Into::into)
diff --git a/src/frontend/src/optimizer/plan_node/logical_agg.rs b/src/frontend/src/optimizer/plan_node/logical_agg.rs
index 78bbb2590f1e7..7b5ec56eeadc9 100644
--- a/src/frontend/src/optimizer/plan_node/logical_agg.rs
+++ b/src/frontend/src/optimizer/plan_node/logical_agg.rs
@@ -161,6 +161,10 @@ impl PlanAggCall {
             return_type: Some(self.return_type.to_protobuf()),
             args: self.inputs.iter().map(InputRef::to_agg_arg_proto).collect(),
             distinct: self.distinct,
+            filter: self
+                .filter
+                .as_expr_unless_true()
+                .map(|expr| expr.to_expr_proto()),
         }
     }
 
diff --git a/src/frontend/src/optimizer/rule/distinct_agg.rs b/src/frontend/src/optimizer/rule/distinct_agg.rs
index 8dcbb0d933d22..b859dc81e183e 100644
--- a/src/frontend/src/optimizer/rule/distinct_agg.rs
+++ b/src/frontend/src/optimizer/rule/distinct_agg.rs
@@ -161,7 +161,7 @@ impl DistinctAgg {
             ExprType::GreaterThan,
             vec![
                 InputRef::new(index_of_middle_agg, DataType::Int64).into(),
-                Literal::new(Some(0.into()), DataType::Int64).into(),
+                Literal::new(Some(0_i64.into()), DataType::Int64).into(),
             ],
         )
         .unwrap();
diff --git a/src/frontend/src/stream_fragmenter/mod.rs b/src/frontend/src/stream_fragmenter/mod.rs
index 60636bc0dd957..3bae3c0d0d461 100644
--- a/src/frontend/src/stream_fragmenter/mod.rs
+++ b/src/frontend/src/stream_fragmenter/mod.rs
@@ -397,6 +397,7 @@ mod tests {
                 ..Default::default()
             }),
             distinct: false,
+            filter: None,
         }
     }
 
diff --git a/src/meta/src/stream/test_fragmenter.rs b/src/meta/src/stream/test_fragmenter.rs
index b86afe7e2f29d..2f3249e943466 100644
--- a/src/meta/src/stream/test_fragmenter.rs
+++ b/src/meta/src/stream/test_fragmenter.rs
@@ -64,6 +64,7 @@ fn make_sum_aggcall(idx: i32) -> AggCall {
             ..Default::default()
         }),
         distinct: false,
+        filter: None,
     }
 }
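
The sketch below is a minimal, self-contained illustration of the FILTER semantics that the new slt tests exercise: a row that fails the predicate is simply skipped before it ever reaches the accumulator, which is the same rule the per-row `apply_filter_on_row` checks in the patch implement. It is not RisingWave code; `Row`, `filtered_count`, and `filtered_min` are hypothetical names introduced only for this example.

// count(*) FILTER (WHERE pred): only rows passing the predicate are counted.
struct Row {
    v1: i32,
}

fn filtered_count(rows: &[Row], pred: impl Fn(&Row) -> bool) -> i64 {
    rows.iter().filter(|r| pred(r)).count() as i64
}

// min(v1) FILTER (WHERE pred): rows failing the predicate contribute nothing,
// so the result is None (SQL NULL) when no row passes.
fn filtered_min(rows: &[Row], pred: impl Fn(&Row) -> bool) -> Option<i32> {
    rows.iter().filter(|r| pred(r)).map(|r| r.v1).min()
}

fn main() {
    // The v1 column from selective_count.slt.part: (1,4,2), (2,3,3), (3,4,4), (4,3,5).
    let rows = [Row { v1: 1 }, Row { v1: 2 }, Row { v1: 3 }, Row { v1: 4 }];
    // Mirrors `select count(*) filter (where v1 > 2) from t`, which returns 2.
    assert_eq!(filtered_count(&rows, |r| r.v1 > 2), 2);
    assert_eq!(filtered_min(&rows, |r| r.v1 > 2), Some(3));
    assert_eq!(filtered_min(&rows, |r| r.v1 > 9), None);
}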