From 6c5e54d5a0d2020003f3e870f2b627505ac19678 Mon Sep 17 00:00:00 2001 From: mgt Date: Wed, 6 Jul 2022 14:18:52 +0800 Subject: [PATCH 01/14] feat(batch): support selective aggregation for count --- proto/expr.proto | 1 + src/batch/src/executor/hash_agg.rs | 2 + src/batch/src/executor/sort_agg.rs | 5 ++ src/expr/src/vector_op/agg/aggregator.rs | 17 ++++++- src/expr/src/vector_op/agg/count_star.rs | 49 ++++++++++++++++--- .../vector_op/agg/general_sorted_grouper.rs | 1 + .../src/optimizer/plan_node/logical_agg.rs | 4 ++ src/frontend/src/stream_fragmenter/mod.rs | 1 + src/meta/src/stream/test_fragmenter.rs | 1 + 9 files changed, 71 insertions(+), 10 deletions(-) diff --git a/proto/expr.proto b/proto/expr.proto index d137dec671a8a..98245309f4da3 100644 --- a/proto/expr.proto +++ b/proto/expr.proto @@ -151,4 +151,5 @@ message AggCall { repeated Arg args = 2; data.DataType return_type = 3; bool distinct = 4; + ExprNode filter = 5; } diff --git a/src/batch/src/executor/hash_agg.rs b/src/batch/src/executor/hash_agg.rs index 29c101071ebf3..50334aec730a6 100644 --- a/src/batch/src/executor/hash_agg.rs +++ b/src/batch/src/executor/hash_agg.rs @@ -310,6 +310,7 @@ mod tests { ..Default::default() }), distinct: false, + filter: None, }; let agg_prost = HashAggNode { @@ -375,6 +376,7 @@ mod tests { ..Default::default() }), distinct: false, + filter: None, }; let agg_prost = HashAggNode { diff --git a/src/batch/src/executor/sort_agg.rs b/src/batch/src/executor/sort_agg.rs index ee84d2601ff7c..fd27001a5a744 100644 --- a/src/batch/src/executor/sort_agg.rs +++ b/src/batch/src/executor/sort_agg.rs @@ -319,6 +319,7 @@ mod tests { ..Default::default() }), distinct: false, + filter: None, }; let count_star = AggStateFactory::new(&prost)?.create_agg_state()?; @@ -411,6 +412,7 @@ mod tests { ..Default::default() }), distinct: false, + filter: None, }; let count_star = AggStateFactory::new(&prost)?.create_agg_state()?; @@ -538,6 +540,7 @@ mod tests { ..Default::default() }), distinct: false, + filter: None, }; let sum_agg = AggStateFactory::new(&prost)?.create_agg_state()?; @@ -621,6 +624,7 @@ mod tests { ..Default::default() }), distinct: false, + filter: None, }; let sum_agg = AggStateFactory::new(&prost)?.create_agg_state()?; @@ -743,6 +747,7 @@ mod tests { ..Default::default() }), distinct: false, + filter: None, }; let sum_agg = AggStateFactory::new(&prost)?.create_agg_state()?; diff --git a/src/expr/src/vector_op/agg/aggregator.rs b/src/expr/src/vector_op/agg/aggregator.rs index a54d8873a9b81..b08a0d701b185 100644 --- a/src/expr/src/vector_op/agg/aggregator.rs +++ b/src/expr/src/vector_op/agg/aggregator.rs @@ -12,12 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::sync::Arc; + use risingwave_common::array::*; use risingwave_common::error::{ErrorCode, Result}; use risingwave_common::types::*; use risingwave_pb::expr::AggCall; -use crate::expr::AggKind; +use crate::expr::{build_from_prost, AggKind, ExpressionRef}; use crate::vector_op::agg::approx_count_distinct::ApproxCountDistinct; use crate::vector_op::agg::count_star::CountStar; use crate::vector_op::agg::functions::*; @@ -64,6 +66,7 @@ pub struct AggStateFactory { agg_kind: AggKind, return_type: DataType, distinct: bool, + filter: Option, } impl AggStateFactory { @@ -71,6 +74,10 @@ impl AggStateFactory { let return_type = DataType::from(prost.get_return_type()?); let agg_kind = AggKind::try_from(prost.get_type()?)?; let distinct = prost.distinct; + let filter = match prost.filter { + Some(ref expr) => Some(Arc::from(build_from_prost(expr)?)), + None => None, + }; match &prost.get_args()[..] { [ref arg] => { let input_type = DataType::from(arg.get_type()?); @@ -81,6 +88,7 @@ impl AggStateFactory { agg_kind, return_type, distinct, + filter, }) } [] => match (&agg_kind, return_type.clone()) { @@ -90,6 +98,7 @@ impl AggStateFactory { agg_kind, return_type, distinct, + filter, }), _ => Err(ErrorCode::InternalError(format!( "Agg {:?} without args not supported", @@ -120,7 +129,11 @@ impl AggStateFactory { self.distinct, ) } else { - Ok(Box::new(CountStar::new(self.return_type.clone(), 0))) + Ok(Box::new(CountStar::new( + self.return_type.clone(), + 0, + self.filter.clone(), + ))) } } diff --git a/src/expr/src/vector_op/agg/count_star.rs b/src/expr/src/vector_op/agg/count_star.rs index d12cd7cac78ec..13d3776adadb3 100644 --- a/src/expr/src/vector_op/agg/count_star.rs +++ b/src/expr/src/vector_op/agg/count_star.rs @@ -16,6 +16,7 @@ use risingwave_common::array::*; use risingwave_common::error::{ErrorCode, Result}; use risingwave_common::types::*; +use crate::expr::ExpressionRef; use crate::vector_op::agg::aggregator::Aggregator; use crate::vector_op::agg::general_sorted_grouper::EqGroups; @@ -23,14 +24,16 @@ pub struct CountStar { return_type: DataType, result: usize, reached_limit: bool, + filter: Option, } impl CountStar { - pub fn new(return_type: DataType, result: usize) -> Self { + pub fn new(return_type: DataType, result: usize, filter: Option) -> Self { Self { return_type, result, reached_limit: false, + filter, } } } @@ -41,7 +44,19 @@ impl Aggregator for CountStar { } fn update(&mut self, input: &DataChunk) -> Result<()> { - self.result += input.cardinality(); + if let Some(ref filter) = self.filter { + self.result += filter + .eval(input)? + .iter() + .filter(|res| match res { + Some(scalar) => *scalar.into_scalar_impl().as_bool(), + _ => false, + }) + .count(); + } else { + self.result += input.cardinality(); + } + Ok(()) } @@ -75,6 +90,18 @@ impl Aggregator for CountStar { // in the process of counting, we set the `reached_limit` flag and save the start // index of previous group to `self.result`. let mut groups_iter = groups.starting_indices().iter(); + let filter_cnt = if let Some(ref filter) = self.filter { + filter + .eval(input)? + .iter() + .filter(|res| match res { + Some(scalar) => *scalar.into_scalar_impl().as_bool(), + _ => false, + }) + .count() + } else { + input.cardinality() + }; if let Some(first) = groups_iter.next() { let first_count = { if self.reached_limit { @@ -100,22 +127,28 @@ impl Aggregator for CountStar { } if group_cnt == groups.len() { self.reached_limit = false; - self.result = input.cardinality() - prev; + self.result = filter_cnt - prev; } } else { - self.result += input.cardinality(); + self.result += filter_cnt; } Ok(()) } fn update_with_row(&mut self, input: &DataChunk, row_id: usize) -> Result<()> { - if let Some(visibility) = input.visibility() { - if visibility.is_set(row_id)? { + if let (row, true) = input.row_at(row_id)? { + let filter_res = if let Some(ref filter) = self.filter { + match filter.eval_row(&Row::from(row))? { + Some(scalar) => *scalar.as_bool(), + None => false, + } + } else { + true + }; + if filter_res { self.result += 1; } - } else { - self.result += 1; } Ok(()) } diff --git a/src/expr/src/vector_op/agg/general_sorted_grouper.rs b/src/expr/src/vector_op/agg/general_sorted_grouper.rs index 98239b7a85a8e..6c2ddb540ebae 100644 --- a/src/expr/src/vector_op/agg/general_sorted_grouper.rs +++ b/src/expr/src/vector_op/agg/general_sorted_grouper.rs @@ -411,6 +411,7 @@ mod tests { ..Default::default() }), distinct: false, + filter: None, }; let mut agg = AggStateFactory::new(&prost) .unwrap() diff --git a/src/frontend/src/optimizer/plan_node/logical_agg.rs b/src/frontend/src/optimizer/plan_node/logical_agg.rs index cc9f4463870de..ba33a1c8a9bd2 100644 --- a/src/frontend/src/optimizer/plan_node/logical_agg.rs +++ b/src/frontend/src/optimizer/plan_node/logical_agg.rs @@ -92,6 +92,10 @@ impl PlanAggCall { return_type: Some(self.return_type.to_protobuf()), args: self.inputs.iter().map(InputRef::to_agg_arg_proto).collect(), distinct: self.distinct, + filter: self + .filter + .as_expr_unless_true() + .map(|expr| expr.to_expr_proto()), } } diff --git a/src/frontend/src/stream_fragmenter/mod.rs b/src/frontend/src/stream_fragmenter/mod.rs index 1a80280f16828..c5ddfc6314045 100644 --- a/src/frontend/src/stream_fragmenter/mod.rs +++ b/src/frontend/src/stream_fragmenter/mod.rs @@ -387,6 +387,7 @@ mod tests { ..Default::default() }), distinct: false, + filter: None, } } diff --git a/src/meta/src/stream/test_fragmenter.rs b/src/meta/src/stream/test_fragmenter.rs index 374000672ef47..c769d40e39b35 100644 --- a/src/meta/src/stream/test_fragmenter.rs +++ b/src/meta/src/stream/test_fragmenter.rs @@ -77,6 +77,7 @@ fn make_sum_aggcall(idx: i32) -> AggCall { ..Default::default() }), distinct: false, + filter: None, } } From d740deef34d2f70c4078f36be080c91f478e4388 Mon Sep 17 00:00:00 2001 From: mgt Date: Wed, 6 Jul 2022 14:36:32 +0800 Subject: [PATCH 02/14] add e2e test for selective count --- .../batch/aggregate/selective_count.slt.part | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 e2e_test/batch/aggregate/selective_count.slt.part diff --git a/e2e_test/batch/aggregate/selective_count.slt.part b/e2e_test/batch/aggregate/selective_count.slt.part new file mode 100644 index 0000000000000..ebfccb376f143 --- /dev/null +++ b/e2e_test/batch/aggregate/selective_count.slt.part @@ -0,0 +1,21 @@ +statement ok +SET RW_IMPLICIT_FLUSH TO true; + +statement ok +create table t (v1 int not null, v2 int not null, v3 int not null) + +statement ok +insert into t values (1,4,2), (2,3,3), (3,4,4), (4,3,5) + +query I +select count(*) filter (where v1 > 2) from t +---- +2 + +query I +select count(*) filter (where v1 <= v2 and v2 <= v3) from t; +---- +2 + +statement ok +drop table t From 24917fde1c08d34f07bb26beb2eac1fc9caeb490 Mon Sep 17 00:00:00 2001 From: mgt Date: Wed, 6 Jul 2022 14:46:55 +0800 Subject: [PATCH 03/14] add group by test for selective count --- e2e_test/batch/aggregate/selective_count.slt.part | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/e2e_test/batch/aggregate/selective_count.slt.part b/e2e_test/batch/aggregate/selective_count.slt.part index ebfccb376f143..30ae47f7033ca 100644 --- a/e2e_test/batch/aggregate/selective_count.slt.part +++ b/e2e_test/batch/aggregate/selective_count.slt.part @@ -17,5 +17,12 @@ select count(*) filter (where v1 <= v2 and v2 <= v3) from t; ---- 2 +query I +select v2, count(*) filter (where v1 > 1) as cnt from t group by v2 order by v2 ; +---- +3 2 +4 1 + + statement ok drop table t From 6ebd790dad2731adf05865db06870fb636e93376 Mon Sep 17 00:00:00 2001 From: mgt Date: Wed, 6 Jul 2022 16:41:23 +0800 Subject: [PATCH 04/14] remove option --- src/expr/src/vector_op/agg/aggregator.rs | 12 +++-- src/expr/src/vector_op/agg/count_star.rs | 67 ++++++++++++------------ 2 files changed, 40 insertions(+), 39 deletions(-) diff --git a/src/expr/src/vector_op/agg/aggregator.rs b/src/expr/src/vector_op/agg/aggregator.rs index b08a0d701b185..666d704827ffb 100644 --- a/src/expr/src/vector_op/agg/aggregator.rs +++ b/src/expr/src/vector_op/agg/aggregator.rs @@ -19,7 +19,7 @@ use risingwave_common::error::{ErrorCode, Result}; use risingwave_common::types::*; use risingwave_pb::expr::AggCall; -use crate::expr::{build_from_prost, AggKind, ExpressionRef}; +use crate::expr::{build_from_prost, AggKind, Expression, ExpressionRef, LiteralExpression}; use crate::vector_op::agg::approx_count_distinct::ApproxCountDistinct; use crate::vector_op::agg::count_star::CountStar; use crate::vector_op::agg::functions::*; @@ -66,7 +66,7 @@ pub struct AggStateFactory { agg_kind: AggKind, return_type: DataType, distinct: bool, - filter: Option, + filter: ExpressionRef, } impl AggStateFactory { @@ -74,9 +74,11 @@ impl AggStateFactory { let return_type = DataType::from(prost.get_return_type()?); let agg_kind = AggKind::try_from(prost.get_type()?)?; let distinct = prost.distinct; - let filter = match prost.filter { - Some(ref expr) => Some(Arc::from(build_from_prost(expr)?)), - None => None, + let filter: ExpressionRef = match prost.filter { + Some(ref expr) => Arc::from(build_from_prost(expr)?), + None => Arc::from( + LiteralExpression::new(DataType::Boolean, Some(ScalarImpl::Bool(true))).boxed(), + ), }; match &prost.get_args()[..] { [ref arg] => { diff --git a/src/expr/src/vector_op/agg/count_star.rs b/src/expr/src/vector_op/agg/count_star.rs index 13d3776adadb3..d832eb09c2724 100644 --- a/src/expr/src/vector_op/agg/count_star.rs +++ b/src/expr/src/vector_op/agg/count_star.rs @@ -24,11 +24,11 @@ pub struct CountStar { return_type: DataType, result: usize, reached_limit: bool, - filter: Option, + filter: ExpressionRef, } impl CountStar { - pub fn new(return_type: DataType, result: usize, filter: Option) -> Self { + pub fn new(return_type: DataType, result: usize, filter: ExpressionRef) -> Self { Self { return_type, result, @@ -44,18 +44,18 @@ impl Aggregator for CountStar { } fn update(&mut self, input: &DataChunk) -> Result<()> { - if let Some(ref filter) = self.filter { - self.result += filter - .eval(input)? - .iter() - .filter(|res| match res { - Some(scalar) => *scalar.into_scalar_impl().as_bool(), - _ => false, - }) - .count(); - } else { - self.result += input.cardinality(); - } + self.result += self + .filter + .eval(input)? + .iter() + .filter(|res| { + if let Some(ScalarRefImpl::Bool(v)) = res { + *v + } else { + false + } + }) + .count(); Ok(()) } @@ -90,18 +90,18 @@ impl Aggregator for CountStar { // in the process of counting, we set the `reached_limit` flag and save the start // index of previous group to `self.result`. let mut groups_iter = groups.starting_indices().iter(); - let filter_cnt = if let Some(ref filter) = self.filter { - filter - .eval(input)? - .iter() - .filter(|res| match res { - Some(scalar) => *scalar.into_scalar_impl().as_bool(), - _ => false, - }) - .count() - } else { - input.cardinality() - }; + let filter_cnt = self + .filter + .eval(input)? + .iter() + .filter(|res| { + if let Some(ScalarRefImpl::Bool(v)) = res { + *v + } else { + false + } + }) + .count(); if let Some(first) = groups_iter.next() { let first_count = { if self.reached_limit { @@ -138,14 +138,13 @@ impl Aggregator for CountStar { fn update_with_row(&mut self, input: &DataChunk, row_id: usize) -> Result<()> { if let (row, true) = input.row_at(row_id)? { - let filter_res = if let Some(ref filter) = self.filter { - match filter.eval_row(&Row::from(row))? { - Some(scalar) => *scalar.as_bool(), - None => false, - } - } else { - true - }; + let filter_res = + if let Some(ScalarImpl::Bool(v)) = self.filter.eval_row(&Row::from(row))? { + v + } else { + false + }; + if filter_res { self.result += 1; } From d6e9d58692c1944a653541612425e84127afd5ee Mon Sep 17 00:00:00 2001 From: mgt Date: Thu, 7 Jul 2022 15:27:30 +0800 Subject: [PATCH 05/14] add filter to approx count distinct --- src/expr/src/vector_op/agg/aggregator.rs | 1 + .../vector_op/agg/approx_count_distinct.rs | 64 +++++++++++++++---- src/expr/src/vector_op/agg/count_star.rs | 16 ++--- 3 files changed, 56 insertions(+), 25 deletions(-) diff --git a/src/expr/src/vector_op/agg/aggregator.rs b/src/expr/src/vector_op/agg/aggregator.rs index 666d704827ffb..2c54b9096b232 100644 --- a/src/expr/src/vector_op/agg/aggregator.rs +++ b/src/expr/src/vector_op/agg/aggregator.rs @@ -121,6 +121,7 @@ impl AggStateFactory { Ok(Box::new(ApproxCountDistinct::new( self.return_type.clone(), self.input_col_idx, + self.filter.clone() ))) } else if let Some(input_type) = self.input_type.clone() { create_agg_state_unary( diff --git a/src/expr/src/vector_op/agg/approx_count_distinct.rs b/src/expr/src/vector_op/agg/approx_count_distinct.rs index 152823093c72c..198d2b83d1c42 100644 --- a/src/expr/src/vector_op/agg/approx_count_distinct.rs +++ b/src/expr/src/vector_op/agg/approx_count_distinct.rs @@ -19,6 +19,7 @@ use risingwave_common::array::*; use risingwave_common::error::{ErrorCode, Result}; use risingwave_common::types::*; +use crate::expr::ExpressionRef; use crate::vector_op::agg::aggregator::Aggregator; use crate::vector_op::agg::general_sorted_grouper::EqGroups; @@ -37,14 +38,16 @@ pub struct ApproxCountDistinct { return_type: DataType, input_col_idx: usize, registers: [u8; NUM_OF_REGISTERS], + filter: ExpressionRef, } impl ApproxCountDistinct { - pub fn new(return_type: DataType, input_col_idx: usize) -> Self { + pub fn new(return_type: DataType, input_col_idx: usize, filter: ExpressionRef) -> Self { Self { return_type, input_col_idx, registers: [0; NUM_OF_REGISTERS], + filter, } } @@ -115,6 +118,19 @@ impl ApproxCountDistinct { answer as i64 } + + fn apply_filter_on_row(&self, input: &DataChunk, row_id: usize) -> Result { + let (row, visible) = input.row_at(row_id)?; + // SAFETY: when performing approx_count_distinct, the data chunk should already be + // compacted. + assert!(visible); + let filter_res = if let Some(ScalarImpl::Bool(v)) = self.filter.eval_row(&Row::from(row))? { + v + } else { + false + }; + Ok(filter_res) + } } impl Aggregator for ApproxCountDistinct { @@ -123,17 +139,23 @@ impl Aggregator for ApproxCountDistinct { } fn update_with_row(&mut self, input: &DataChunk, row_id: usize) -> Result<()> { - let array = input.column_at(self.input_col_idx).array_ref(); - let datum_ref = array.value_at(row_id); - self.add_datum(datum_ref); - + let filter_res = self.apply_filter_on_row(input, row_id)?; + if filter_res { + let array = input.column_at(self.input_col_idx).array_ref(); + let datum_ref = array.value_at(row_id); + self.add_datum(datum_ref); + } Ok(()) } fn update(&mut self, input: &DataChunk) -> Result<()> { let array = input.column_at(self.input_col_idx).array_ref(); - for datum_ref in array.iter() { - self.add_datum(datum_ref); + for row_id in 0..array.len() { + let filter_res = self.apply_filter_on_row(input, row_id)?; + if filter_res { + let datum_ref = array.value_at(row_id); + self.add_datum(datum_ref); + } } Ok(()) } @@ -166,7 +188,7 @@ impl Aggregator for ApproxCountDistinct { let mut group_cnt = 0; let mut groups_iter = groups.starting_indices().iter().peekable(); let chunk_offset = groups.chunk_offset(); - for (i, datum_ref) in array.iter().skip(chunk_offset).enumerate() { + for (row_id, (i, datum_ref)) in array.iter().enumerate().skip(chunk_offset).enumerate() { // reset state and output result when new group is found if groups_iter.peek() == Some(&&i) { groups_iter.next(); @@ -174,9 +196,10 @@ impl Aggregator for ApproxCountDistinct { builder.append(Some(self.calculate_result()))?; self.registers = [0; NUM_OF_REGISTERS]; } - - self.add_datum(datum_ref); - + let filter_res = self.apply_filter_on_row(input, row_id)?; + if filter_res { + self.add_datum(datum_ref); + } // reset state and exit when reach limit if groups.is_reach_limit(group_cnt) { self.registers = [0; NUM_OF_REGISTERS]; @@ -196,8 +219,9 @@ mod tests { use risingwave_common::array::{ ArrayBuilder, ArrayBuilderImpl, DataChunk, I32Array, I64ArrayBuilder, }; - use risingwave_common::types::DataType; + use risingwave_common::types::{DataType, ScalarImpl}; + use crate::expr::{Expression, LiteralExpression}; use crate::vector_op::agg::aggregator::Aggregator; use crate::vector_op::agg::approx_count_distinct::ApproxCountDistinct; use crate::vector_op::agg::EqGroups; @@ -222,7 +246,13 @@ mod tests { let inputs_size: [usize; 3] = [20000, 10000, 5000]; let inputs_start: [i32; 3] = [0, 20000, 30000]; - let mut agg = ApproxCountDistinct::new(DataType::Int64, 0); + let mut agg = ApproxCountDistinct::new( + DataType::Int64, + 0, + Arc::from( + LiteralExpression::new(DataType::Boolean, Some(ScalarImpl::Bool(true))).boxed(), + ), + ); let mut builder = ArrayBuilderImpl::Int64(I64ArrayBuilder::new(3)); for i in 0..3 { @@ -237,7 +267,13 @@ mod tests { #[test] fn test_update_and_output_with_sorted_groups() { - let mut a = ApproxCountDistinct::new(DataType::Int64, 0); + let mut a = ApproxCountDistinct::new( + DataType::Int64, + 0, + Arc::from( + LiteralExpression::new(DataType::Boolean, Some(ScalarImpl::Bool(true))).boxed(), + ), + ); let data_chunk = generate_data_chunk(30001, 0); let mut builder = ArrayBuilderImpl::Int64(I64ArrayBuilder::new(5)); diff --git a/src/expr/src/vector_op/agg/count_star.rs b/src/expr/src/vector_op/agg/count_star.rs index d832eb09c2724..00117990cf373 100644 --- a/src/expr/src/vector_op/agg/count_star.rs +++ b/src/expr/src/vector_op/agg/count_star.rs @@ -49,11 +49,8 @@ impl Aggregator for CountStar { .eval(input)? .iter() .filter(|res| { - if let Some(ScalarRefImpl::Bool(v)) = res { - *v - } else { - false - } + res.map(|x| *x.into_scalar_impl().as_bool()) + .unwrap_or(false) }) .count(); @@ -83,7 +80,7 @@ impl Aggregator for CountStar { }; // The first element continues the same group in `self.result`. The following // groups' sizes are simply distance between group start indices. The distance - // between last element and `input.cardinality()` is the ongoing group that + // between last element and `filter_cnt` is the ongoing group that // may continue in following chunks. // // Since the number of groups in an output chunk is limited, if we reach the limit @@ -95,11 +92,8 @@ impl Aggregator for CountStar { .eval(input)? .iter() .filter(|res| { - if let Some(ScalarRefImpl::Bool(v)) = res { - *v - } else { - false - } + res.map(|x| *x.into_scalar_impl().as_bool()) + .unwrap_or(false) }) .count(); if let Some(first) = groups_iter.next() { From 081691a8295fb69014d93b15d4d20784adabb2a6 Mon Sep 17 00:00:00 2001 From: mgt Date: Thu, 7 Jul 2022 21:54:18 +0800 Subject: [PATCH 06/14] add general agg and fix approx count --- src/expr/src/vector_op/agg/aggregator.rs | 11 +++- .../vector_op/agg/approx_count_distinct.rs | 2 +- src/expr/src/vector_op/agg/general_agg.rs | 60 +++++++++++++++---- .../src/vector_op/agg/general_distinct_agg.rs | 13 +++- .../vector_op/agg/general_sorted_grouper.rs | 21 ++++++- 5 files changed, 87 insertions(+), 20 deletions(-) diff --git a/src/expr/src/vector_op/agg/aggregator.rs b/src/expr/src/vector_op/agg/aggregator.rs index 2c54b9096b232..4ff8aff2be76b 100644 --- a/src/expr/src/vector_op/agg/aggregator.rs +++ b/src/expr/src/vector_op/agg/aggregator.rs @@ -121,7 +121,7 @@ impl AggStateFactory { Ok(Box::new(ApproxCountDistinct::new( self.return_type.clone(), self.input_col_idx, - self.filter.clone() + self.filter.clone(), ))) } else if let Some(input_type) = self.input_type.clone() { create_agg_state_unary( @@ -130,6 +130,7 @@ impl AggStateFactory { &self.agg_kind, self.return_type.clone(), self.distinct, + self.filter.clone(), ) } else { Ok(Box::new(CountStar::new( @@ -151,6 +152,7 @@ pub fn create_agg_state_unary( agg_type: &AggKind, return_type: DataType, distinct: bool, + filter: ExpressionRef, ) -> Result> { use crate::expr::data_types::*; @@ -169,6 +171,7 @@ pub fn create_agg_state_unary( input_col_idx, $fn, $init_result, + filter )) }, ($in! { type_match_pattern }, AggKind::$agg, $ret! { type_match_pattern }, true) => { @@ -269,7 +272,9 @@ mod tests { let decimal_type = DataType::Decimal; let bool_type = DataType::Boolean; let char_type = DataType::Varchar; - + let filter: ExpressionRef = Arc::from( + LiteralExpression::new(DataType::Boolean, Some(ScalarImpl::Bool(true))).boxed(), + ); macro_rules! test_create { ($input_type:expr, $agg:ident, $return_type:expr, $expected:ident) => { assert!(create_agg_state_unary( @@ -278,6 +283,7 @@ mod tests { &AggKind::$agg, $return_type.clone(), false, + filter.clone(), ) .$expected()); assert!(create_agg_state_unary( @@ -286,6 +292,7 @@ mod tests { &AggKind::$agg, $return_type.clone(), true, + filter.clone(), ) .$expected()); }; diff --git a/src/expr/src/vector_op/agg/approx_count_distinct.rs b/src/expr/src/vector_op/agg/approx_count_distinct.rs index 198d2b83d1c42..62ca529a6a5df 100644 --- a/src/expr/src/vector_op/agg/approx_count_distinct.rs +++ b/src/expr/src/vector_op/agg/approx_count_distinct.rs @@ -188,7 +188,7 @@ impl Aggregator for ApproxCountDistinct { let mut group_cnt = 0; let mut groups_iter = groups.starting_indices().iter().peekable(); let chunk_offset = groups.chunk_offset(); - for (row_id, (i, datum_ref)) in array.iter().enumerate().skip(chunk_offset).enumerate() { + for (i, (row_id, datum_ref)) in array.iter().enumerate().skip(chunk_offset).enumerate() { // reset state and output result when new group is found if groups_iter.peek() == Some(&&i) { groups_iter.next(); diff --git a/src/expr/src/vector_op/agg/general_agg.rs b/src/expr/src/vector_op/agg/general_agg.rs index fc7cb83221289..1b5e539d86a58 100644 --- a/src/expr/src/vector_op/agg/general_agg.rs +++ b/src/expr/src/vector_op/agg/general_agg.rs @@ -18,6 +18,7 @@ use risingwave_common::array::*; use risingwave_common::error::{ErrorCode, Result}; use risingwave_common::types::*; +use crate::expr::ExpressionRef; use crate::vector_op::agg::aggregator::Aggregator; use crate::vector_op::agg::functions::RTFn; use crate::vector_op::agg::general_sorted_grouper::EqGroups; @@ -32,6 +33,7 @@ where input_col_idx: usize, result: Option, f: F, + filter: ExpressionRef, _phantom: PhantomData, } impl GeneralAgg @@ -45,12 +47,14 @@ where input_col_idx: usize, f: F, init_result: Option, + filter: ExpressionRef, ) -> Self { Self { return_type, input_col_idx, result: init_result, f, + filter, _phantom: PhantomData, } } @@ -67,10 +71,13 @@ where Ok(()) } - pub(super) fn update_concrete(&mut self, input: &T) -> Result<()> { + pub(super) fn update_concrete(&mut self, array: &T, input: &DataChunk) -> Result<()> { let mut cur = self.result.as_ref().map(|x| x.as_scalar_ref()); - for datum in input.iter() { - cur = self.f.eval(cur, datum)?; + for row_id in 0..array.len() { + let filter_res = self.apply_filter_on_row(input, row_id)?; + if filter_res { + cur = self.f.eval(cur, array.value_at(row_id))?; + } } let r = cur.map(|x| x.to_owned_scalar()); self.result = r; @@ -85,23 +92,26 @@ where pub(super) fn update_and_output_with_sorted_groups_concrete( &mut self, - input: &T, + array: &T, builder: &mut R::Builder, groups: &EqGroups, + input: &DataChunk, ) -> Result<()> { let mut group_cnt = 0; let mut groups_iter = groups.starting_indices().iter().peekable(); let mut cur = self.result.as_ref().map(|x| x.as_scalar_ref()); let chunk_offset = groups.chunk_offset(); - for (i, v) in input.iter().skip(chunk_offset).enumerate() { + for (i, (row_id, v)) in array.iter().enumerate().skip(chunk_offset).enumerate() { if groups_iter.peek() == Some(&&(i + chunk_offset)) { groups_iter.next(); group_cnt += 1; builder.append(cur)?; cur = None; } - cur = self.f.eval(cur, v)?; - + let filter_res = self.apply_filter_on_row(input, row_id)?; + if filter_res { + cur = self.f.eval(cur, v)?; + } // reset state and exit when reach limit if groups.is_reach_limit(group_cnt) { cur = None; @@ -111,6 +121,19 @@ where self.result = cur.map(|x| x.to_owned_scalar()); Ok(()) } + + fn apply_filter_on_row(&self, input: &DataChunk, row_id: usize) -> Result { + let (row, visible) = input.row_at(row_id)?; + // SAFETY: when performing agg, the data chunk should already be + // compacted. + assert!(visible); + let filter_res = if let Some(ScalarImpl::Bool(v)) = self.filter.eval_row(&Row::from(row))? { + v + } else { + false + }; + Ok(filter_res) + } } macro_rules! impl_aggregator { @@ -127,7 +150,11 @@ macro_rules! impl_aggregator { if let ArrayImpl::$input_variant(i) = input.column_at(self.input_col_idx).array_ref() { - self.update_with_scalar_concrete(i, row_id) + let filter_res = self.apply_filter_on_row(input, row_id)?; + if filter_res { + self.update_with_scalar_concrete(i, row_id)?; + } + Ok(()) } else { Err(ErrorCode::InternalError(format!( "Input fail to match {}.", @@ -141,7 +168,7 @@ macro_rules! impl_aggregator { if let ArrayImpl::$input_variant(i) = input.column_at(self.input_col_idx).array_ref() { - self.update_concrete(i) + self.update_concrete(i, input) } else { Err(ErrorCode::InternalError(format!( "Input fail to match {}.", @@ -172,7 +199,7 @@ macro_rules! impl_aggregator { if let (ArrayImpl::$input_variant(i), ArrayBuilderImpl::$result_variant(b)) = (input.column_at(self.input_col_idx).array_ref(), builder) { - self.update_and_output_with_sorted_groups_concrete(i, b, groups) + self.update_and_output_with_sorted_groups_concrete(i, b, groups, input) } else { Err(ErrorCode::InternalError(format!( "Input fail to match {} or builder fail to match {}.", @@ -228,7 +255,7 @@ mod tests { use risingwave_common::types::Decimal; use super::*; - use crate::expr::AggKind; + use crate::expr::{AggKind, Expression, LiteralExpression}; use crate::vector_op::agg::aggregator::create_agg_state_unary; fn eval_agg( @@ -240,7 +267,16 @@ mod tests { ) -> Result { let len = input.len(); let input_chunk = DataChunk::new(vec![Column::new(input)], len); - let mut agg_state = create_agg_state_unary(input_type, 0, agg_type, return_type, false)?; + let mut agg_state = create_agg_state_unary( + input_type, + 0, + agg_type, + return_type, + false, + Arc::from( + LiteralExpression::new(DataType::Boolean, Some(ScalarImpl::Bool(true))).boxed(), + ), + )?; agg_state.update(&input_chunk)?; agg_state.output(&mut builder)?; builder.finish().map_err(Into::into) diff --git a/src/expr/src/vector_op/agg/general_distinct_agg.rs b/src/expr/src/vector_op/agg/general_distinct_agg.rs index e3fe7af8f482f..24fb2d4a5fe9b 100644 --- a/src/expr/src/vector_op/agg/general_distinct_agg.rs +++ b/src/expr/src/vector_op/agg/general_distinct_agg.rs @@ -244,7 +244,7 @@ mod tests { use risingwave_common::types::Decimal; use super::*; - use crate::expr::AggKind; + use crate::expr::{AggKind, LiteralExpression, Expression}; use crate::vector_op::agg::aggregator::create_agg_state_unary; fn eval_agg( @@ -256,7 +256,16 @@ mod tests { ) -> Result { let len = input.len(); let input_chunk = DataChunk::new(vec![Column::new(input)], len); - let mut agg_state = create_agg_state_unary(input_type, 0, agg_type, return_type, true)?; + let mut agg_state = create_agg_state_unary( + input_type, + 0, + agg_type, + return_type, + true, + Arc::from( + LiteralExpression::new(DataType::Boolean, Some(ScalarImpl::Bool(true))).boxed(), + ), + )?; agg_state.update(&input_chunk)?; agg_state.output(&mut builder)?; builder.finish().map_err(Into::into) diff --git a/src/expr/src/vector_op/agg/general_sorted_grouper.rs b/src/expr/src/vector_op/agg/general_sorted_grouper.rs index 6c2ddb540ebae..5fa117c2e1630 100644 --- a/src/expr/src/vector_op/agg/general_sorted_grouper.rs +++ b/src/expr/src/vector_op/agg/general_sorted_grouper.rs @@ -313,6 +313,7 @@ mod tests { use risingwave_pb::expr::AggCall; use super::*; + use crate::expr::{LiteralExpression, Expression}; use crate::vector_op::agg::functions::*; use crate::vector_op::agg::general_agg::GeneralAgg; use crate::vector_op::agg::AggStateFactory; @@ -358,7 +359,15 @@ mod tests { let mut g0_builder = I32ArrayBuilder::new(0); let mut g1 = GeneralSortedGrouper::::new(false, None); let mut g1_builder = I32ArrayBuilder::new(0); - let mut a = GeneralAgg::::new(DataType::Int64, 0, sum, None); + let mut a = GeneralAgg::::new( + DataType::Int64, + 0, + sum, + None, + Arc::from( + LiteralExpression::new(DataType::Boolean, Some(ScalarImpl::Bool(true))).boxed(), + ), + ); let mut a_builder = I64ArrayBuilder::new(0); let g0_input = I32Array::from_slice(&[Some(1), Some(1), Some(3)]).unwrap(); @@ -369,7 +378,13 @@ mod tests { g0.update_and_output_with_sorted_groups_concrete(&g0_input, &mut g0_builder, &eq)?; g1.update_and_output_with_sorted_groups_concrete(&g1_input, &mut g1_builder, &eq)?; let a_input = I32Array::from_slice(&[Some(1), Some(2), Some(3)]).unwrap(); - a.update_and_output_with_sorted_groups_concrete(&a_input, &mut a_builder, &eq)?; + let input_chunk = DataChunk::from_pretty(" + i + 1 + 2 + 3 + "); + a.update_and_output_with_sorted_groups_concrete(&a_input, &mut a_builder, &eq, &input_chunk)?; let g0_input = I32Array::from_slice(&[Some(3), Some(4), Some(4)]).unwrap(); let eq0 = g0.detect_groups_concrete(&g0_input)?; @@ -379,7 +394,7 @@ mod tests { g0.update_and_output_with_sorted_groups_concrete(&g0_input, &mut g0_builder, &eq)?; g1.update_and_output_with_sorted_groups_concrete(&g1_input, &mut g1_builder, &eq)?; let a_input = I32Array::from_slice(&[Some(1), Some(2), Some(3)]).unwrap(); - a.update_and_output_with_sorted_groups_concrete(&a_input, &mut a_builder, &eq)?; + a.update_and_output_with_sorted_groups_concrete(&a_input, &mut a_builder, &eq, &input_chunk)?; g0.output_concrete(&mut g0_builder)?; g1.output_concrete(&mut g1_builder)?; From 950a396ed37e576191f61f5a4e8bc7eb50dd520f Mon Sep 17 00:00:00 2001 From: mgt Date: Fri, 8 Jul 2022 13:42:01 +0800 Subject: [PATCH 07/14] add general distinct agg --- src/expr/src/vector_op/agg/aggregator.rs | 1 + .../src/vector_op/agg/general_distinct_agg.rs | 71 ++++++++++++++----- 2 files changed, 56 insertions(+), 16 deletions(-) diff --git a/src/expr/src/vector_op/agg/aggregator.rs b/src/expr/src/vector_op/agg/aggregator.rs index 4ff8aff2be76b..907b04ef12e46 100644 --- a/src/expr/src/vector_op/agg/aggregator.rs +++ b/src/expr/src/vector_op/agg/aggregator.rs @@ -179,6 +179,7 @@ pub fn create_agg_state_unary( return_type, input_col_idx, $fn, + filter, )) }, )* diff --git a/src/expr/src/vector_op/agg/general_distinct_agg.rs b/src/expr/src/vector_op/agg/general_distinct_agg.rs index 24fb2d4a5fe9b..da7f89af64c17 100644 --- a/src/expr/src/vector_op/agg/general_distinct_agg.rs +++ b/src/expr/src/vector_op/agg/general_distinct_agg.rs @@ -19,6 +19,7 @@ use risingwave_common::array::*; use risingwave_common::error::{ErrorCode, Result}; use risingwave_common::types::*; +use crate::expr::ExpressionRef; use crate::vector_op::agg::aggregator::Aggregator; use crate::vector_op::agg::functions::RTFn; use crate::vector_op::agg::general_sorted_grouper::EqGroups; @@ -40,6 +41,7 @@ where result: Option, f: F, exists: HashSet, + filter: ExpressionRef, _phantom: PhantomData, } impl GeneralDistinctAgg @@ -48,7 +50,7 @@ where F: for<'a> RTFn<'a, T, R>, R: Array, { - pub fn new(return_type: DataType, input_col_idx: usize, f: F) -> Self { + pub fn new(return_type: DataType, input_col_idx: usize, f: F, filter: ExpressionRef) -> Self { Self { return_type, input_col_idx, @@ -56,6 +58,7 @@ where f, exists: HashSet::new(), _phantom: PhantomData, + filter, } } @@ -76,13 +79,27 @@ where Ok(()) } - fn update_concrete(&mut self, input: &T) -> Result<()> { - let input = input.iter().filter(|scalar_ref| { - self.exists - .insert(scalar_ref.map(|scalar_ref| scalar_ref.to_owned_scalar().to_scalar_value())) - }); + fn update_concrete(&mut self, array: &T, input: &DataChunk) -> Result<()> { + let filtered_data: Vec<_> = array + .iter() + .enumerate() + .filter_map(|(row_id, datum)| { + if self + .apply_filter_on_row(input, row_id) + .ok() + .unwrap_or(false) + && self.exists.insert( + datum.map(|scalar_ref| scalar_ref.to_owned_scalar().to_scalar_value()), + ) + { + return Some(row_id); + } + None + }) + .collect(); let mut cur = self.result.as_ref().map(|x| x.as_scalar_ref()); - for datum in input { + for row_id in filtered_data { + let datum = array.value_at(row_id); cur = self.f.eval(cur, datum)?; } let r = cur.map(|x| x.to_owned_scalar()); @@ -98,24 +115,29 @@ where fn update_and_output_with_sorted_groups_concrete( &mut self, - input: &T, + array: &T, builder: &mut R::Builder, groups: &EqGroups, + input: &DataChunk, ) -> Result<()> { let mut group_cnt = 0; let mut groups_iter = groups.starting_indices().iter().peekable(); let mut cur = self.result.as_ref().map(|x| x.as_scalar_ref()); let chunk_offset = groups.chunk_offset(); - for (i, v) in input.iter().skip(chunk_offset).enumerate() { + for (i, (row_id, v)) in array.iter().enumerate().skip(chunk_offset).enumerate() { if groups_iter.peek() == Some(&&i) { groups_iter.next(); group_cnt += 1; builder.append(cur)?; cur = None; } - let scalar_impl = v.map(|scalar_ref| scalar_ref.to_owned_scalar().to_scalar_value()); - if self.exists.insert(scalar_impl) { - cur = self.f.eval(cur, v)?; + let filter_res = self.apply_filter_on_row(input, row_id)?; + if filter_res { + let scalar_impl = + v.map(|scalar_ref| scalar_ref.to_owned_scalar().to_scalar_value()); + if self.exists.insert(scalar_impl) { + cur = self.f.eval(cur, v)?; + } } // reset state and exit when reach limit @@ -127,6 +149,19 @@ where self.result = cur.map(|x| x.to_owned_scalar()); Ok(()) } + + fn apply_filter_on_row(&self, input: &DataChunk, row_id: usize) -> Result { + let (row, visible) = input.row_at(row_id)?; + // SAFETY: when performing agg, the data chunk should already be + // compacted. + assert!(visible); + let filter_res = if let Some(ScalarImpl::Bool(v)) = self.filter.eval_row(&Row::from(row))? { + v + } else { + false + }; + Ok(filter_res) + } } macro_rules! impl_aggregator { @@ -143,7 +178,11 @@ macro_rules! impl_aggregator { if let ArrayImpl::$input_variant(i) = input.column_at(self.input_col_idx).array_ref() { - self.update_with_scalar_concrete(i, row_id) + let filter_res = self.apply_filter_on_row(input, row_id)?; + if filter_res { + self.update_with_scalar_concrete(i, row_id)?; + } + Ok(()) } else { Err(ErrorCode::InternalError(format!( "Input fail to match {}.", @@ -157,7 +196,7 @@ macro_rules! impl_aggregator { if let ArrayImpl::$input_variant(i) = input.column_at(self.input_col_idx).array_ref() { - self.update_concrete(i) + self.update_concrete(i, input) } else { Err(ErrorCode::InternalError(format!( "Input fail to match {}.", @@ -188,7 +227,7 @@ macro_rules! impl_aggregator { if let (ArrayImpl::$input_variant(i), ArrayBuilderImpl::$result_variant(b)) = (input.column_at(self.input_col_idx).array_ref(), builder) { - self.update_and_output_with_sorted_groups_concrete(i, b, groups) + self.update_and_output_with_sorted_groups_concrete(i, b, groups, input) } else { Err(ErrorCode::InternalError(format!( "Input fail to match {} or builder fail to match {}.", @@ -244,7 +283,7 @@ mod tests { use risingwave_common::types::Decimal; use super::*; - use crate::expr::{AggKind, LiteralExpression, Expression}; + use crate::expr::{AggKind, Expression, LiteralExpression}; use crate::vector_op::agg::aggregator::create_agg_state_unary; fn eval_agg( From e35037f3e0cfd09346b9d2c6301a6c992c38a755 Mon Sep 17 00:00:00 2001 From: mgt Date: Fri, 8 Jul 2022 17:51:47 +0800 Subject: [PATCH 08/14] format --- .../vector_op/agg/general_sorted_grouper.rs | 22 ++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/src/expr/src/vector_op/agg/general_sorted_grouper.rs b/src/expr/src/vector_op/agg/general_sorted_grouper.rs index 5fa117c2e1630..996a868a38e18 100644 --- a/src/expr/src/vector_op/agg/general_sorted_grouper.rs +++ b/src/expr/src/vector_op/agg/general_sorted_grouper.rs @@ -313,7 +313,7 @@ mod tests { use risingwave_pb::expr::AggCall; use super::*; - use crate::expr::{LiteralExpression, Expression}; + use crate::expr::{Expression, LiteralExpression}; use crate::vector_op::agg::functions::*; use crate::vector_op::agg::general_agg::GeneralAgg; use crate::vector_op::agg::AggStateFactory; @@ -378,13 +378,20 @@ mod tests { g0.update_and_output_with_sorted_groups_concrete(&g0_input, &mut g0_builder, &eq)?; g1.update_and_output_with_sorted_groups_concrete(&g1_input, &mut g1_builder, &eq)?; let a_input = I32Array::from_slice(&[Some(1), Some(2), Some(3)]).unwrap(); - let input_chunk = DataChunk::from_pretty(" + let input_chunk = DataChunk::from_pretty( + " i 1 2 3 - "); - a.update_and_output_with_sorted_groups_concrete(&a_input, &mut a_builder, &eq, &input_chunk)?; + ", + ); + a.update_and_output_with_sorted_groups_concrete( + &a_input, + &mut a_builder, + &eq, + &input_chunk, + )?; let g0_input = I32Array::from_slice(&[Some(3), Some(4), Some(4)]).unwrap(); let eq0 = g0.detect_groups_concrete(&g0_input)?; @@ -394,7 +401,12 @@ mod tests { g0.update_and_output_with_sorted_groups_concrete(&g0_input, &mut g0_builder, &eq)?; g1.update_and_output_with_sorted_groups_concrete(&g1_input, &mut g1_builder, &eq)?; let a_input = I32Array::from_slice(&[Some(1), Some(2), Some(3)]).unwrap(); - a.update_and_output_with_sorted_groups_concrete(&a_input, &mut a_builder, &eq, &input_chunk)?; + a.update_and_output_with_sorted_groups_concrete( + &a_input, + &mut a_builder, + &eq, + &input_chunk, + )?; g0.output_concrete(&mut g0_builder)?; g1.output_concrete(&mut g1_builder)?; From b3cae56df734591ac5095698accd351d2ac0be51 Mon Sep 17 00:00:00 2001 From: mgt Date: Fri, 8 Jul 2022 22:02:51 +0800 Subject: [PATCH 09/14] add e2e test --- .../batch/aggregate/selective_count.slt.part | 6 ++-- .../aggregate/selective_general_agg.slt.part | 29 +++++++++++++++++++ 2 files changed, 32 insertions(+), 3 deletions(-) create mode 100644 e2e_test/batch/aggregate/selective_general_agg.slt.part diff --git a/e2e_test/batch/aggregate/selective_count.slt.part b/e2e_test/batch/aggregate/selective_count.slt.part index 30ae47f7033ca..3accedba35eff 100644 --- a/e2e_test/batch/aggregate/selective_count.slt.part +++ b/e2e_test/batch/aggregate/selective_count.slt.part @@ -13,12 +13,12 @@ select count(*) filter (where v1 > 2) from t 2 query I -select count(*) filter (where v1 <= v2 and v2 <= v3) from t; +select count(*) filter (where v1 <= v2 and v2 <= v3) from t ---- 2 -query I -select v2, count(*) filter (where v1 > 1) as cnt from t group by v2 order by v2 ; +query II +select v2, count(*) filter (where v1 > 1) as cnt from t group by v2 order by v2 ---- 3 2 4 1 diff --git a/e2e_test/batch/aggregate/selective_general_agg.slt.part b/e2e_test/batch/aggregate/selective_general_agg.slt.part new file mode 100644 index 0000000000000..7e283d35302ea --- /dev/null +++ b/e2e_test/batch/aggregate/selective_general_agg.slt.part @@ -0,0 +1,29 @@ +statement ok +SET RW_IMPLICIT_FLUSH TO true; + +statement ok +create table t(v1 int, v2 int) + +statement ok +insert into t values (2, 300), (1, 200), (3, 200), (3, 500), (2, 100), (1, 500), (6, 300) + +query II +select min(v1) filter (where v1 * 100 >= v2), min(v1) filter (where v1 * 100 < v2) from t +---- +2 1 + +query II +select v2, sum(v1) filter(where v1 >= 2) from t group by v2 order by v2 +---- +100 2 +200 3 +300 8 +500 3 + +query I +with sub(a, b) as (select sum(v1), count(v2) filter (where v1 > 5) from t) select b from sub +---- +1 + +statement ok +drop table t From 1c27bd3aa16d03dd50739c393e7b2e7aba94ca58 Mon Sep 17 00:00:00 2001 From: mgt Date: Fri, 8 Jul 2022 22:13:31 +0800 Subject: [PATCH 10/14] add selective distinct agg e2e --- .../aggregate/selective_distinct_agg.slt.part | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 e2e_test/batch/aggregate/selective_distinct_agg.slt.part diff --git a/e2e_test/batch/aggregate/selective_distinct_agg.slt.part b/e2e_test/batch/aggregate/selective_distinct_agg.slt.part new file mode 100644 index 0000000000000..a4c7e84fbb705 --- /dev/null +++ b/e2e_test/batch/aggregate/selective_distinct_agg.slt.part @@ -0,0 +1,22 @@ +statement ok +SET RW_IMPLICIT_FLUSH TO true; + +statement ok +create table t(v1 int, v2 int, v3 int) + +statement ok +insert into t values (2, 200, -1), (1, 200, -2), (3, 100, -2), (3, 300, -1), (2, 100, -1), (1, 500, -2), (6, 500, -1) + +query I +select sum(distinct v1) filter (where v1 * 100 >= v2) from t +---- +11 + +query II +select v3, sum(distinct v1) filter(where v1 * 100 >= v2) from t group by v3 order by v3 +---- +-2 3 +-1 11 + +statement ok +drop table t From d5f39dc12c309a1106cd9ddad033dcd512fd0fae Mon Sep 17 00:00:00 2001 From: mgt Date: Sun, 10 Jul 2022 17:11:41 +0800 Subject: [PATCH 11/14] avoid collect --- .../src/vector_op/agg/general_distinct_agg.rs | 28 ++++++------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/src/expr/src/vector_op/agg/general_distinct_agg.rs b/src/expr/src/vector_op/agg/general_distinct_agg.rs index da7f89af64c17..d156fcb3fbd52 100644 --- a/src/expr/src/vector_op/agg/general_distinct_agg.rs +++ b/src/expr/src/vector_op/agg/general_distinct_agg.rs @@ -80,27 +80,15 @@ where } fn update_concrete(&mut self, array: &T, input: &DataChunk) -> Result<()> { - let filtered_data: Vec<_> = array - .iter() - .enumerate() - .filter_map(|(row_id, datum)| { - if self - .apply_filter_on_row(input, row_id) - .ok() - .unwrap_or(false) - && self.exists.insert( - datum.map(|scalar_ref| scalar_ref.to_owned_scalar().to_scalar_value()), - ) - { - return Some(row_id); - } - None - }) - .collect(); let mut cur = self.result.as_ref().map(|x| x.as_scalar_ref()); - for row_id in filtered_data { - let datum = array.value_at(row_id); - cur = self.f.eval(cur, datum)?; + for (row_id, datum) in array.iter().enumerate() { + if self.apply_filter_on_row(input, row_id)? + && self + .exists + .insert(datum.map(|scalar_ref| scalar_ref.to_owned_scalar().to_scalar_value())) + { + cur = self.f.eval(cur, datum)?; + } } let r = cur.map(|x| x.to_owned_scalar()); self.result = r; From 9c3e1212b0b560a9206f87cadb45d8e96a609131 Mon Sep 17 00:00:00 2001 From: mgt Date: Wed, 13 Jul 2022 16:15:02 +0800 Subject: [PATCH 12/14] fix count star --- src/expr/src/vector_op/agg/count_star.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/expr/src/vector_op/agg/count_star.rs b/src/expr/src/vector_op/agg/count_star.rs index 47f440c50bd05..ac8cbc5f687d8 100644 --- a/src/expr/src/vector_op/agg/count_star.rs +++ b/src/expr/src/vector_op/agg/count_star.rs @@ -71,12 +71,12 @@ impl Aggregator for CountStar { .filter .eval(input)? .iter() + .skip(start_row_id) + .take(end_row_id - start_row_id) .filter(|res| { res.map(|x| *x.into_scalar_impl().as_bool()) .unwrap_or(false) }) - .skip(start_row_id) - .take(end_row_id - start_row_id) .count(); } Ok(()) From c2c193889115419706a29ecd7886f1476d769fec Mon Sep 17 00:00:00 2001 From: mgt Date: Wed, 13 Jul 2022 16:30:21 +0800 Subject: [PATCH 13/14] format --- src/expr/src/vector_op/agg/general_agg.rs | 2 +- src/expr/src/vector_op/agg/general_distinct_agg.rs | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/expr/src/vector_op/agg/general_agg.rs b/src/expr/src/vector_op/agg/general_agg.rs index f6f56ba7e43ba..5f63f40218510 100644 --- a/src/expr/src/vector_op/agg/general_agg.rs +++ b/src/expr/src/vector_op/agg/general_agg.rs @@ -111,7 +111,7 @@ where } /// `apply_filter_on_row` apply a filter on the given row, and return if the row satisfies the - /// filter or not + /// filter or not /// # SAFETY /// the given row must be visible fn apply_filter_on_row(&self, input: &DataChunk, row_id: usize) -> Result { diff --git a/src/expr/src/vector_op/agg/general_distinct_agg.rs b/src/expr/src/vector_op/agg/general_distinct_agg.rs index a46d0dfe11f34..e979c177add5b 100644 --- a/src/expr/src/vector_op/agg/general_distinct_agg.rs +++ b/src/expr/src/vector_op/agg/general_distinct_agg.rs @@ -87,12 +87,12 @@ where ) -> Result<()> { let mut cur = self.result.as_ref().map(|x| x.as_scalar_ref()); for row_id in start_row_id..end_row_id { - if self.apply_filter_on_row(input, row_id)? - { + if self.apply_filter_on_row(input, row_id)? { let datum = array.value_at(row_id); if self .exists - .insert(datum.map(|scalar_ref| scalar_ref.to_owned_scalar().to_scalar_value())) { + .insert(datum.map(|scalar_ref| scalar_ref.to_owned_scalar().to_scalar_value())) + { cur = self.f.eval(cur, datum)?; } } From 91c83feeed56c712c7a5d12b4799d5ff2fd39b0c Mon Sep 17 00:00:00 2001 From: mgt Date: Wed, 13 Jul 2022 22:14:53 +0800 Subject: [PATCH 14/14] fix bug in optimizer --- src/frontend/src/optimizer/rule/distinct_agg.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/frontend/src/optimizer/rule/distinct_agg.rs b/src/frontend/src/optimizer/rule/distinct_agg.rs index 8dcbb0d933d22..b859dc81e183e 100644 --- a/src/frontend/src/optimizer/rule/distinct_agg.rs +++ b/src/frontend/src/optimizer/rule/distinct_agg.rs @@ -161,7 +161,7 @@ impl DistinctAgg { ExprType::GreaterThan, vec![ InputRef::new(index_of_middle_agg, DataType::Int64).into(), - Literal::new(Some(0.into()), DataType::Int64).into(), + Literal::new(Some(0_i64.into()), DataType::Int64).into(), ], ) .unwrap();