Skip to content

Commit 8fb2270

Browse files
authored
Improve documentation and add examples for ArrowPredicateFn (#7480)
1 parent 0e48877 commit 8fb2270

File tree

1 file changed

+69
-8
lines changed

1 file changed

+69
-8
lines changed

parquet/src/arrow/arrow_reader/filter.rs

Lines changed: 69 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,18 @@ use arrow_schema::ArrowError;
2121

2222
/// A predicate operating on [`RecordBatch`]
2323
///
24-
/// See [`RowFilter`] for more information on the use of this trait.
24+
/// See also:
25+
/// * [`RowFilter`] for more information on applying filters during the
26+
/// Parquet decoding process.
27+
/// * [`ArrowPredicateFn`] for a concrete implementation based on a function
2528
pub trait ArrowPredicate: Send + 'static {
2629
/// Returns the [`ProjectionMask`] that describes the columns required
27-
/// to evaluate this predicate. All projected columns will be provided in the `batch`
28-
/// passed to [`evaluate`](Self::evaluate)
30+
/// to evaluate this predicate.
31+
///
32+
/// All projected columns will be provided in the `batch` passed to
33+
/// [`evaluate`](Self::evaluate). The projection mask should be as small as
34+
/// possible because any columns needed for the overall projection mask are
35+
/// decoded again after a predicate is applied.
2936
fn projection(&self) -> &ProjectionMask;
3037

3138
/// Evaluate this predicate for the given [`RecordBatch`] containing the columns
@@ -38,7 +45,63 @@ pub trait ArrowPredicate: Send + 'static {
3845
fn evaluate(&mut self, batch: RecordBatch) -> Result<BooleanArray, ArrowError>;
3946
}
4047

41-
/// An [`ArrowPredicate`] created from an [`FnMut`]
48+
/// An [`ArrowPredicate`] created from an [`FnMut`] and a [`ProjectionMask`]
49+
///
50+
/// See [`RowFilter`] for more information on applying filters during the
51+
/// Parquet decoding process.
52+
///
53+
/// The function is passed `RecordBatch`es with only the columns specified in
54+
/// the [`ProjectionMask`].
55+
///
56+
/// The function must return a [`BooleanArray`] that has the same length as the
57+
/// input `batch` where each row indicates whether the row should be returned:
58+
/// * `true`: the row should be returned
59+
/// * `false` or `null`: the row should not be returned
60+
///
61+
/// # Example:
62+
///
63+
/// Given an input schema: `"a:int64", "b:int64"`, you can create a predicate that
64+
/// evaluates `b > 0` like this:
65+
///
66+
/// ```
67+
/// # use std::sync::Arc;
68+
/// # use arrow::compute::kernels::cmp::gt;
69+
/// # use arrow_array::{BooleanArray, Int64Array, RecordBatch};
70+
/// # use arrow_array::cast::AsArray;
71+
/// # use arrow_array::types::Int64Type;
72+
/// # use parquet::arrow::arrow_reader::ArrowPredicateFn;
73+
/// # use parquet::arrow::ProjectionMask;
74+
/// # use parquet::schema::types::{SchemaDescriptor, Type};
75+
/// # use parquet::basic; // note there are two `Type`s that are different
76+
/// # // Schema for a table with one columns: "a" (int64) and "b" (int64)
77+
/// # let descriptor = SchemaDescriptor::new(
78+
/// # Arc::new(
79+
/// # Type::group_type_builder("my_schema")
80+
/// # .with_fields(vec![
81+
/// # Arc::new(
82+
/// # Type::primitive_type_builder("a", basic::Type::INT64)
83+
/// # .build().unwrap()
84+
/// # ),
85+
/// # Arc::new(
86+
/// # Type::primitive_type_builder("b", basic::Type::INT64)
87+
/// # .build().unwrap()
88+
/// # ),
89+
/// # ])
90+
/// # .build().unwrap()
91+
/// # )
92+
/// # );
93+
/// // Create a mask for selecting only the second column "b" (index 1)
94+
/// let projection_mask = ProjectionMask::leaves(&descriptor, [1]);
95+
/// // Closure that evaluates "b > 0"
96+
/// let predicate = |batch: RecordBatch| {
97+
/// let scalar_0 = Int64Array::new_scalar(0);
98+
/// let column = batch.column(0).as_primitive::<Int64Type>();
99+
/// // call the gt kernel to compute `>` which returns a BooleanArray
100+
/// gt(column, &scalar_0)
101+
/// };
102+
/// // Create ArrowPredicateFn that can be passed to RowFilter
103+
/// let arrow_predicate = ArrowPredicateFn::new(projection_mask, predicate);
104+
/// ```
42105
pub struct ArrowPredicateFn<F> {
43106
f: F,
44107
projection: ProjectionMask,
@@ -48,10 +111,8 @@ impl<F> ArrowPredicateFn<F>
48111
where
49112
F: FnMut(RecordBatch) -> Result<BooleanArray, ArrowError> + Send + 'static,
50113
{
51-
/// Create a new [`ArrowPredicateFn`]. `f` will be passed batches
52-
/// that contains the columns specified in `projection`
53-
/// and returns a [`BooleanArray`] that describes which rows should
54-
/// be passed along
114+
/// Create a new [`ArrowPredicateFn`] that invokes `f` on the columns
115+
/// specified in `projection`.
55116
pub fn new(projection: ProjectionMask, f: F) -> Self {
56117
Self { f, projection }
57118
}

0 commit comments

Comments
 (0)