diff --git a/components/parquet_ext/src/prune/equal.rs b/components/parquet_ext/src/prune/equal.rs index 63d73e68dc..b33bdf2eab 100644 --- a/components/parquet_ext/src/prune/equal.rs +++ b/components/parquet_ext/src/prune/equal.rs @@ -1,10 +1,72 @@ // Copyright 2022 CeresDB Project Authors. Licensed under Apache-2.0. +use arrow::datatypes::SchemaRef; use datafusion::{logical_plan::Column, scalar::ScalarValue}; use datafusion_expr::{Expr, Operator}; const MAX_ELEMS_IN_LIST_FOR_FILTER: usize = 100; +/// A position used to describe the location of a column in the row groups. +#[derive(Debug, Clone, Copy)] +pub struct ColumnPosition { + pub row_group_idx: usize, + pub column_idx: usize, +} + +/// Filter the row groups according to the `exprs`. +/// +/// The return value is the filtered row group indexes. And the `is_equal` +/// closure receive three parameters: +/// - The position of the column in the row groups; +/// - The value of the column used to determine equality; +/// - Whether this compare is negated; +pub fn filter_row_groups( + schema: SchemaRef, + exprs: &[Expr], + row_group_num: usize, + is_equal: E, +) -> Vec +where + E: Fn(ColumnPosition, &ScalarValue, bool) -> bool, +{ + let mut should_reads = vec![true; row_group_num]; + for expr in exprs { + let pruner = EqPruner::new(expr); + for (row_group_idx, should_read) in should_reads.iter_mut().enumerate() { + if !*should_read { + continue; + } + + let f = |column: &Column, val: &ScalarValue, negated: bool| -> bool { + match schema.column_with_name(&column.name) { + Some((column_idx, _)) => { + let pos = ColumnPosition { + row_group_idx, + column_idx, + }; + is_equal(pos, val, negated) + } + _ => true, + } + }; + + *should_read = pruner.prune(&f); + } + } + + should_reads + .iter() + .enumerate() + .filter_map(|(row_group_idx, should_read)| { + if *should_read { + Some(row_group_idx) + } else { + None + } + }) + .collect() +} + /// A pruner based on (not)equal predicates, including in-list predicate. #[derive(Debug, Clone)] pub struct EqPruner { @@ -22,9 +84,9 @@ impl EqPruner { /// Use the prune function provided by caller to finish pruning. /// /// The prune function receives three parameters: - /// - the column - /// - the value of the column - /// - equal or not + /// - the column to compare; + /// - the value of the column used to determine equality; + /// - Whether this compare is negated; pub fn prune(&self, f: &F) -> bool where F: Fn(&Column, &ScalarValue, bool) -> bool, @@ -70,8 +132,8 @@ impl NormalizedExpr { match self { NormalizedExpr::And { left, right } => left.compute(f) && right.compute(f), NormalizedExpr::Or { left, right } => left.compute(f) || right.compute(f), - NormalizedExpr::Eq { column, value } => f(column, value, true), - NormalizedExpr::NotEq { column, value } => f(column, value, false), + NormalizedExpr::Eq { column, value } => f(column, value, false), + NormalizedExpr::NotEq { column, value } => f(column, value, true), NormalizedExpr::True => true, NormalizedExpr::False => false, } @@ -151,6 +213,10 @@ fn normalize_equal_expr(left: &Expr, right: &Expr, is_equal: bool) -> Normalized #[cfg(test)] mod tests { + use std::sync::Arc; + + use arrow::datatypes::{DataType, Field, Schema}; + use super::*; fn make_column_expr(name: &str) -> Expr { @@ -311,7 +377,7 @@ mod tests { #[test] fn test_prune() { - let f = |column: &Column, val: &ScalarValue, equal: bool| -> bool { + let f = |column: &Column, val: &ScalarValue, negated: bool| -> bool { let val = match val { ScalarValue::Int32(v) => v.unwrap(), _ => panic!("Unexpected value type"), @@ -323,7 +389,7 @@ mod tests { "c2" => val == 2, _ => panic!("Unexpected column"), }; - if !equal { + if negated { !res } else { res @@ -361,4 +427,55 @@ mod tests { ); assert!(!EqPruner::new(&false_expr).prune(&f)); } + + #[test] + fn test_filter_row_groups() { + // Provide three row groups (one row in one row group). + // | c0 | c1 | c2 | + // | 0 | 1 | 2 | + // | 1 | 2 | 3 | + // | 2 | 3 | 4 | + let row_groups = vec![vec![0, 1, 2], vec![1, 2, 3], vec![2, 3, 4]]; + let is_equal = |pos: ColumnPosition, val: &ScalarValue, negated: bool| -> bool { + let expect_val = row_groups[pos.row_group_idx][pos.column_idx]; + let val = if let ScalarValue::Int32(v) = val { + v.expect("Unexpected value") + } else { + panic!("Unexpected value type") + }; + + if negated { + expect_val != val + } else { + expect_val == val + } + }; + + // (c0 in [1, 3]) or c1 not in [1, 2] + let predicate1 = Expr::or( + Expr::in_list( + make_column_expr("c0"), + vec![make_literal_expr(1), make_literal_expr(3)], + false, + ), + Expr::in_list( + make_column_expr("c1"), + vec![make_literal_expr(1), make_literal_expr(2)], + true, + ), + ); + + // c2 != 2 + let predicate2 = Expr::not_eq(make_literal_expr(2), make_column_expr("c2")); + + let schema = Schema::new(vec![ + Field::new("c0", DataType::Int32, false), + Field::new("c1", DataType::Int32, false), + Field::new("c2", DataType::Int32, false), + ]); + let filtered_row_groups = + filter_row_groups(Arc::new(schema), &vec![predicate1, predicate2], 3, is_equal); + + assert_eq!(vec![1, 2], filtered_row_groups) + } }