Skip to content

Commit

Permalink
feat: support filter_row_groups by EqPruner (#380)
Browse files Browse the repository at this point in the history
  • Loading branch information
ShiKaiWi authored Nov 8, 2022
1 parent 9699e85 commit ce93fbf
Showing 1 changed file with 124 additions and 7 deletions.
131 changes: 124 additions & 7 deletions components/parquet_ext/src/prune/equal.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,72 @@
// Copyright 2022 CeresDB Project Authors. Licensed under Apache-2.0.

use arrow::datatypes::SchemaRef;
use datafusion::{logical_plan::Column, scalar::ScalarValue};
use datafusion_expr::{Expr, Operator};

const MAX_ELEMS_IN_LIST_FOR_FILTER: usize = 100;

/// A position used to describe the location of a column in the row groups.
///
/// This is a small, plain value type (two indexes), so besides being `Copy`
/// it can be compared for equality and used as a key in hash-based
/// collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ColumnPosition {
    /// Index of the row group the column chunk belongs to.
    pub row_group_idx: usize,
    /// Index of the column; `filter_row_groups` fills this with the index of
    /// the column in the schema it is given.
    pub column_idx: usize,
}

/// Filter the row groups according to the `exprs`.
///
/// Every row group index in `0..row_group_num` starts out as a candidate and
/// is dropped as soon as the [`EqPruner`] built from one of the `exprs`
/// reports `false` for it. The indexes of the surviving row groups are
/// returned in ascending order.
///
/// The `is_equal` closure receives three parameters:
/// - The position of the column in the row groups;
/// - The value of the column used to determine equality;
/// - Whether this comparison is negated.
///
/// A comparison on a column that cannot be found in the `schema` is treated
/// as `true` without consulting `is_equal`.
pub fn filter_row_groups<E>(
    schema: SchemaRef,
    exprs: &[Expr],
    row_group_num: usize,
    is_equal: E,
) -> Vec<usize>
where
    E: Fn(ColumnPosition, &ScalarValue, bool) -> bool,
{
    // One flag per row group, all of them kept until a predicate says otherwise.
    let mut keep_flags = vec![true; row_group_num];

    for expr in exprs {
        let pruner = EqPruner::new(expr);

        // Row groups already ruled out by an earlier predicate are not
        // evaluated again.
        let candidates = keep_flags
            .iter_mut()
            .enumerate()
            .filter(|(_, keep)| **keep);

        for (row_group_idx, keep) in candidates {
            // Adapt the column-name based callback expected by the pruner to
            // the position based callback supplied by the caller.
            let check = |column: &Column, val: &ScalarValue, negated: bool| {
                schema
                    .column_with_name(&column.name)
                    .map_or(true, |(column_idx, _)| {
                        let pos = ColumnPosition {
                            row_group_idx,
                            column_idx,
                        };
                        is_equal(pos, val, negated)
                    })
            };

            *keep = pruner.prune(&check);
        }
    }

    keep_flags
        .into_iter()
        .enumerate()
        .filter(|&(_, keep)| keep)
        .map(|(row_group_idx, _)| row_group_idx)
        .collect()
}

/// A pruner based on (not-)equal predicates, including in-list predicates.
#[derive(Debug, Clone)]
pub struct EqPruner {
Expand All @@ -22,9 +84,9 @@ impl EqPruner {
/// Use the prune function provided by caller to finish pruning.
///
/// The prune function receives three parameters:
/// - the column
/// - the value of the column
/// - equal or not
/// - the column to compare;
/// - the value of the column used to determine equality;
/// - Whether this compare is negated;
pub fn prune<F>(&self, f: &F) -> bool
where
F: Fn(&Column, &ScalarValue, bool) -> bool,
Expand Down Expand Up @@ -70,8 +132,8 @@ impl NormalizedExpr {
match self {
NormalizedExpr::And { left, right } => left.compute(f) && right.compute(f),
NormalizedExpr::Or { left, right } => left.compute(f) || right.compute(f),
NormalizedExpr::Eq { column, value } => f(column, value, true),
NormalizedExpr::NotEq { column, value } => f(column, value, false),
NormalizedExpr::Eq { column, value } => f(column, value, false),
NormalizedExpr::NotEq { column, value } => f(column, value, true),
NormalizedExpr::True => true,
NormalizedExpr::False => false,
}
Expand Down Expand Up @@ -151,6 +213,10 @@ fn normalize_equal_expr(left: &Expr, right: &Expr, is_equal: bool) -> Normalized

#[cfg(test)]
mod tests {
use std::sync::Arc;

use arrow::datatypes::{DataType, Field, Schema};

use super::*;

fn make_column_expr(name: &str) -> Expr {
Expand Down Expand Up @@ -311,7 +377,7 @@ mod tests {

#[test]
fn test_prune() {
let f = |column: &Column, val: &ScalarValue, equal: bool| -> bool {
let f = |column: &Column, val: &ScalarValue, negated: bool| -> bool {
let val = match val {
ScalarValue::Int32(v) => v.unwrap(),
_ => panic!("Unexpected value type"),
Expand All @@ -323,7 +389,7 @@ mod tests {
"c2" => val == 2,
_ => panic!("Unexpected column"),
};
if !equal {
if negated {
!res
} else {
res
Expand Down Expand Up @@ -361,4 +427,55 @@ mod tests {
);
assert!(!EqPruner::new(&false_expr).prune(&f));
}

#[test]
fn test_filter_row_groups() {
    // Three row groups, each holding exactly one row:
    // | c0 | c1 | c2 |
    // | 0  | 1  | 2  |
    // | 1  | 2  | 3  |
    // | 2  | 3  | 4  |
    let row_groups = [[0, 1, 2], [1, 2, 3], [2, 3, 4]];
    let is_equal = |pos: ColumnPosition, val: &ScalarValue, negated: bool| -> bool {
        let stored = row_groups[pos.row_group_idx][pos.column_idx];
        let wanted = match val {
            ScalarValue::Int32(v) => v.expect("Unexpected value"),
            _ => panic!("Unexpected value type"),
        };

        // A negated comparison flips the outcome of the equality check.
        (stored == wanted) != negated
    };

    // (c0 in [1, 3]) or c1 not in [1, 2]
    let c0_in_list = Expr::in_list(
        make_column_expr("c0"),
        vec![make_literal_expr(1), make_literal_expr(3)],
        false,
    );
    let c1_not_in_list = Expr::in_list(
        make_column_expr("c1"),
        vec![make_literal_expr(1), make_literal_expr(2)],
        true,
    );
    let predicate1 = Expr::or(c0_in_list, c1_not_in_list);

    // c2 != 2 (written with the literal on the left-hand side)
    let predicate2 = Expr::not_eq(make_literal_expr(2), make_column_expr("c2"));

    let fields = vec![
        Field::new("c0", DataType::Int32, false),
        Field::new("c1", DataType::Int32, false),
        Field::new("c2", DataType::Int32, false),
    ];
    let schema = Arc::new(Schema::new(fields));

    let filtered_row_groups = filter_row_groups(schema, &[predicate1, predicate2], 3, is_equal);

    // The first row group is dropped: c0 = 0 is not in the list, c1 = 1 is in
    // the excluded list, and c2 = 2 also fails the second predicate.
    assert_eq!(vec![1, 2], filtered_row_groups);
}
}

0 comments on commit ce93fbf

Please sign in to comment.