Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(batch): support selective aggregation #3683

Merged
merged 20 commits into from
Jul 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions e2e_test/batch/aggregate/selective_count.slt.part
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
statement ok
SET RW_IMPLICIT_FLUSH TO true;

statement ok
create table t (v1 int not null, v2 int not null, v3 int not null)

statement ok
insert into t values (1,4,2), (2,3,3), (3,4,4), (4,3,5)

query I
select count(*) filter (where v1 > 2) from t
----
2

query I
select count(*) filter (where v1 <= v2 and v2 <= v3) from t
----
2

query II
select v2, count(*) filter (where v1 > 1) as cnt from t group by v2 order by v2
----
3 2
4 1


statement ok
drop table t
22 changes: 22 additions & 0 deletions e2e_test/batch/aggregate/selective_distinct_agg.slt.part
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
statement ok
SET RW_IMPLICIT_FLUSH TO true;

statement ok
create table t(v1 int, v2 int, v3 int)

statement ok
insert into t values (2, 200, -1), (1, 200, -2), (3, 100, -2), (3, 300, -1), (2, 100, -1), (1, 500, -2), (6, 500, -1)

query I
select sum(distinct v1) filter (where v1 * 100 >= v2) from t
----
11

query II
select v3, sum(distinct v1) filter(where v1 * 100 >= v2) from t group by v3 order by v3
----
-2 3
-1 11

statement ok
drop table t
29 changes: 29 additions & 0 deletions e2e_test/batch/aggregate/selective_general_agg.slt.part
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
statement ok
SET RW_IMPLICIT_FLUSH TO true;

statement ok
create table t(v1 int, v2 int)

statement ok
insert into t values (2, 300), (1, 200), (3, 200), (3, 500), (2, 100), (1, 500), (6, 300)

query II
select min(v1) filter (where v1 * 100 >= v2), min(v1) filter (where v1 * 100 < v2) from t
----
2 1

query II
select v2, sum(v1) filter(where v1 >= 2) from t group by v2 order by v2
----
100 2
200 3
300 8
500 3

query I
with sub(a, b) as (select sum(v1), count(v2) filter (where v1 > 5) from t) select b from sub
----
1

statement ok
drop table t
1 change: 1 addition & 0 deletions proto/expr.proto
Original file line number Diff line number Diff line change
Expand Up @@ -153,4 +153,5 @@ message AggCall {
repeated Arg args = 2;
data.DataType return_type = 3;
bool distinct = 4;
ExprNode filter = 5;
}
2 changes: 2 additions & 0 deletions src/batch/src/executor/hash_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,7 @@ mod tests {
..Default::default()
}),
distinct: false,
filter: None,
};

let agg_prost = HashAggNode {
Expand Down Expand Up @@ -375,6 +376,7 @@ mod tests {
..Default::default()
}),
distinct: false,
filter: None,
};

let agg_prost = HashAggNode {
Expand Down
5 changes: 5 additions & 0 deletions src/batch/src/executor/sort_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,7 @@ mod tests {
..Default::default()
}),
distinct: false,
filter: None,
};

let count_star = AggStateFactory::new(&prost)?.create_agg_state()?;
Expand Down Expand Up @@ -424,6 +425,7 @@ mod tests {
..Default::default()
}),
distinct: false,
filter: None,
};

let count_star = AggStateFactory::new(&prost)?.create_agg_state()?;
Expand Down Expand Up @@ -551,6 +553,7 @@ mod tests {
..Default::default()
}),
distinct: false,
filter: None,
};

let sum_agg = AggStateFactory::new(&prost)?.create_agg_state()?;
Expand Down Expand Up @@ -634,6 +637,7 @@ mod tests {
..Default::default()
}),
distinct: false,
filter: None,
};

let sum_agg = AggStateFactory::new(&prost)?.create_agg_state()?;
Expand Down Expand Up @@ -756,6 +760,7 @@ mod tests {
..Default::default()
}),
distinct: false,
filter: None,
};

let sum_agg = AggStateFactory::new(&prost)?.create_agg_state()?;
Expand Down
30 changes: 27 additions & 3 deletions src/expr/src/vector_op/agg/aggregator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::sync::Arc;

use risingwave_common::array::*;
use risingwave_common::error::{ErrorCode, Result};
use risingwave_common::types::*;
use risingwave_pb::expr::AggCall;

use crate::expr::AggKind;
use crate::expr::{build_from_prost, AggKind, Expression, ExpressionRef, LiteralExpression};
use crate::vector_op::agg::approx_count_distinct::ApproxCountDistinct;
use crate::vector_op::agg::count_star::CountStar;
use crate::vector_op::agg::functions::*;
Expand Down Expand Up @@ -55,13 +57,20 @@ pub struct AggStateFactory {
agg_kind: AggKind,
return_type: DataType,
distinct: bool,
filter: ExpressionRef,
}

impl AggStateFactory {
pub fn new(prost: &AggCall) -> Result<Self> {
let return_type = DataType::from(prost.get_return_type()?);
let agg_kind = AggKind::try_from(prost.get_type()?)?;
let distinct = prost.distinct;
let filter: ExpressionRef = match prost.filter {
Some(ref expr) => Arc::from(build_from_prost(expr)?),
None => Arc::from(
LiteralExpression::new(DataType::Boolean, Some(ScalarImpl::Bool(true))).boxed(),
),
};
match &prost.get_args()[..] {
[ref arg] => {
let input_type = DataType::from(arg.get_type()?);
Expand All @@ -72,6 +81,7 @@ impl AggStateFactory {
agg_kind,
return_type,
distinct,
filter,
})
}
[] => match (&agg_kind, return_type.clone()) {
Expand All @@ -81,6 +91,7 @@ impl AggStateFactory {
agg_kind,
return_type,
distinct,
filter,
}),
_ => Err(ErrorCode::InternalError(format!(
"Agg {:?} without args not supported",
Expand All @@ -101,6 +112,7 @@ impl AggStateFactory {
Ok(Box::new(ApproxCountDistinct::new(
self.return_type.clone(),
self.input_col_idx,
self.filter.clone(),
)))
} else if let Some(input_type) = self.input_type.clone() {
create_agg_state_unary(
Expand All @@ -109,9 +121,14 @@ impl AggStateFactory {
&self.agg_kind,
self.return_type.clone(),
self.distinct,
self.filter.clone(),
)
} else {
Ok(Box::new(CountStar::new(self.return_type.clone())))
Ok(Box::new(CountStar::new(
self.return_type.clone(),
0,
self.filter.clone(),
)))
}
}

Expand All @@ -126,6 +143,7 @@ pub fn create_agg_state_unary(
agg_type: &AggKind,
return_type: DataType,
distinct: bool,
filter: ExpressionRef,
) -> Result<Box<dyn Aggregator>> {
use crate::expr::data_types::*;

Expand All @@ -144,13 +162,15 @@ pub fn create_agg_state_unary(
input_col_idx,
$fn,
$init_result,
filter
))
},
($in! { type_match_pattern }, AggKind::$agg, $ret! { type_match_pattern }, true) => {
Box::new(GeneralDistinctAgg::<$in! { type_array }, _, $ret! { type_array }>::new(
return_type,
input_col_idx,
$fn,
filter,
))
},
)*
Expand Down Expand Up @@ -244,7 +264,9 @@ mod tests {
let decimal_type = DataType::Decimal;
let bool_type = DataType::Boolean;
let char_type = DataType::Varchar;

let filter: ExpressionRef = Arc::from(
LiteralExpression::new(DataType::Boolean, Some(ScalarImpl::Bool(true))).boxed(),
);
macro_rules! test_create {
($input_type:expr, $agg:ident, $return_type:expr, $expected:ident) => {
assert!(create_agg_state_unary(
Expand All @@ -253,6 +275,7 @@ mod tests {
&AggKind::$agg,
$return_type.clone(),
false,
filter.clone(),
)
.$expected());
assert!(create_agg_state_unary(
Expand All @@ -261,6 +284,7 @@ mod tests {
&AggKind::$agg,
$return_type.clone(),
true,
filter.clone(),
)
.$expected());
};
Expand Down
59 changes: 46 additions & 13 deletions src/expr/src/vector_op/agg/approx_count_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ use risingwave_common::array::*;
use risingwave_common::error::{ErrorCode, Result};
use risingwave_common::types::*;

use crate::expr::ExpressionRef;
use crate::vector_op::agg::aggregator::Aggregator;

const INDEX_BITS: u8 = 14; // number of bits used for finding the index of each 64-bit hash
Expand All @@ -36,14 +37,16 @@ pub struct ApproxCountDistinct {
return_type: DataType,
input_col_idx: usize,
registers: [u8; NUM_OF_REGISTERS],
filter: ExpressionRef,
}

impl ApproxCountDistinct {
pub fn new(return_type: DataType, input_col_idx: usize) -> Self {
pub fn new(return_type: DataType, input_col_idx: usize, filter: ExpressionRef) -> Self {
Self {
return_type,
input_col_idx,
registers: [0; NUM_OF_REGISTERS],
filter,
}
}

Expand Down Expand Up @@ -114,6 +117,20 @@ impl ApproxCountDistinct {

answer as i64
}

/// Evaluates this aggregator's filter expression against a single row of `input`.
///
/// Returns `Ok(true)` only when the filter evaluates to a boolean `true` for the row at
/// `row_id`; any other outcome (a `None` datum, `false`, or a non-boolean scalar) yields
/// `Ok(false)`.
///
/// # Errors
///
/// Propagates any error from fetching the row or from evaluating the filter.
///
/// # Panics
///
/// Panics if the row at `row_id` is not visible in `input`.
fn apply_filter_on_row(&self, input: &DataChunk, row_id: usize) -> Result<bool> {
    let (row_ref, is_visible) = input.row_at(row_id)?;
    // Callers are expected to pass only visible rows.
    assert!(is_visible);
    let verdict = self.filter.eval_row(&Row::from(row_ref))?;
    Ok(matches!(verdict, Some(ScalarImpl::Bool(true))))
}
}

impl Aggregator for ApproxCountDistinct {
Expand All @@ -122,9 +139,12 @@ impl Aggregator for ApproxCountDistinct {
}

fn update_single(&mut self, input: &DataChunk, row_id: usize) -> Result<()> {
let array = input.column_at(self.input_col_idx).array_ref();
let datum_ref = array.value_at(row_id);
self.add_datum(datum_ref);
let filter_res = self.apply_filter_on_row(input, row_id)?;
if filter_res {
let array = input.column_at(self.input_col_idx).array_ref();
let datum_ref = array.value_at(row_id);
self.add_datum(datum_ref);
}
Ok(())
}

Expand All @@ -135,12 +155,12 @@ impl Aggregator for ApproxCountDistinct {
end_row_id: usize,
) -> Result<()> {
let array = input.column_at(self.input_col_idx).array_ref();
for datum_ref in array
.iter()
.skip(start_row_id)
.take(end_row_id - start_row_id)
{
self.add_datum(datum_ref);
for row_id in start_row_id..end_row_id {
let filter_res = self.apply_filter_on_row(input, row_id)?;
if filter_res {
let datum_ref = array.value_at(row_id);
self.add_datum(datum_ref);
}
}
Ok(())
}
Expand Down Expand Up @@ -168,8 +188,9 @@ mod tests {
use risingwave_common::array::{
ArrayBuilder, ArrayBuilderImpl, DataChunk, I32Array, I64ArrayBuilder,
};
use risingwave_common::types::DataType;
use risingwave_common::types::{DataType, ScalarImpl};

use crate::expr::{Expression, LiteralExpression};
use crate::vector_op::agg::aggregator::Aggregator;
use crate::vector_op::agg::approx_count_distinct::ApproxCountDistinct;

Expand All @@ -193,7 +214,13 @@ mod tests {
let inputs_size: [usize; 3] = [20000, 10000, 5000];
let inputs_start: [i32; 3] = [0, 20000, 30000];

let mut agg = ApproxCountDistinct::new(DataType::Int64, 0);
let mut agg = ApproxCountDistinct::new(
DataType::Int64,
0,
Arc::from(
LiteralExpression::new(DataType::Boolean, Some(ScalarImpl::Bool(true))).boxed(),
),
);
let mut builder = ArrayBuilderImpl::Int64(I64ArrayBuilder::new(3));

for i in 0..3 {
Expand All @@ -213,7 +240,13 @@ mod tests {
let inputs_size: [usize; 3] = [20000, 10000, 5000];
let inputs_start: [i32; 3] = [0, 20000, 30000];

let mut agg = ApproxCountDistinct::new(DataType::Int64, 0);
let mut agg = ApproxCountDistinct::new(
DataType::Int64,
0,
Arc::from(
LiteralExpression::new(DataType::Boolean, Some(ScalarImpl::Bool(true))).boxed(),
),
);
let mut builder = ArrayBuilderImpl::Int64(I64ArrayBuilder::new(3));

for i in 0..3 {
Expand Down
Loading