Skip to content

Commit e9a7fe5

Browse files
pepijnvemartin-g
andauthored
Add FilterPredicate::filter_record_batch (#8693)
# Which issue does this PR close? - Closes #8692. # Rationale for this change Explained in issue. # What changes are included in this PR? - Adds `FilterPredicate::filter_record_batch` - Adapts the free function `filter_record_batch` to use the new function - Uses `new_unchecked` to create the filtered result. The rationale for this is identical to #8583 # Are these changes tested? Covered by existing tests for `filter_record_batch` # Are there any user-facing changes? No --------- Co-authored-by: Martin Grigorov <martin-g@users.noreply.github.com>
1 parent 79575aa commit e9a7fe5

File tree

1 file changed

+43
-9
lines changed

1 file changed

+43
-9
lines changed

arrow-select/src/filter.rs

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,12 @@ pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
122122
/// Returns a filtered `values` [`Array`] where the corresponding elements of
123123
/// `predicate` are `true`.
124124
///
125+
/// If multiple arrays (or record batches) need to be filtered using the same predicate array,
126+
/// consider using [FilterBuilder] to create a single [FilterPredicate] and then
127+
/// calling [FilterPredicate::filter_record_batch].
128+
/// In contrast to this function, it is then the responsibility of the caller
129+
/// to use [FilterBuilder::optimize] if appropriate.
130+
///
125131
/// # See also
126132
/// * [`FilterBuilder`] for more control over the filtering process.
127133
/// * [`filter_record_batch`] to filter a [`RecordBatch`]
@@ -168,25 +174,28 @@ fn multiple_arrays(data_type: &DataType) -> bool {
168174
/// `predicate` are true.
169175
///
170176
/// This is the equivalent of calling [filter] on each column of the [RecordBatch].
177+
///
178+
/// If multiple record batches (or arrays) need to be filtered using the same predicate array,
179+
/// consider using [FilterBuilder] to create a single [FilterPredicate] and then
180+
/// calling [FilterPredicate::filter_record_batch].
181+
/// In contrast to this function, it is then the responsibility of the caller
182+
/// to use [FilterBuilder::optimize] if appropriate.
171183
pub fn filter_record_batch(
172184
record_batch: &RecordBatch,
173185
predicate: &BooleanArray,
174186
) -> Result<RecordBatch, ArrowError> {
175187
let mut filter_builder = FilterBuilder::new(predicate);
176-
if record_batch.num_columns() > 1 {
177-
// Only optimize if filtering more than one column
188+
let num_cols = record_batch.num_columns();
189+
if num_cols > 1
190+
|| (num_cols > 0 && multiple_arrays(record_batch.schema_ref().field(0).data_type()))
191+
{
192+
// Only optimize if filtering more than one column or if the column contains multiple internal arrays
178193
// Otherwise, the overhead of optimization can be more than the benefit
179194
filter_builder = filter_builder.optimize();
180195
}
181196
let filter = filter_builder.build();
182197

183-
let filtered_arrays = record_batch
184-
.columns()
185-
.iter()
186-
.map(|a| filter_array(a, &filter))
187-
.collect::<Result<Vec<_>, _>>()?;
188-
let options = RecordBatchOptions::default().with_row_count(Some(filter.count()));
189-
RecordBatch::try_new_with_options(record_batch.schema(), filtered_arrays, &options)
198+
filter.filter_record_batch(record_batch)
190199
}
191200

192201
/// A builder to construct [`FilterPredicate`]
@@ -300,6 +309,31 @@ impl FilterPredicate {
300309
filter_array(values, self)
301310
}
302311

312+
/// Returns a filtered [`RecordBatch`] containing only the rows that are selected by this
313+
/// [`FilterPredicate`].
314+
///
315+
/// This is the equivalent of calling [filter] on each column of the [`RecordBatch`].
316+
pub fn filter_record_batch(
317+
&self,
318+
record_batch: &RecordBatch,
319+
) -> Result<RecordBatch, ArrowError> {
320+
let filtered_arrays = record_batch
321+
.columns()
322+
.iter()
323+
.map(|a| filter_array(a, self))
324+
.collect::<Result<Vec<_>, _>>()?;
325+
326+
// SAFETY: we know that the set of filtered arrays will match the schema of the original
327+
// record batch
328+
unsafe {
329+
Ok(RecordBatch::new_unchecked(
330+
record_batch.schema(),
331+
filtered_arrays,
332+
self.count,
333+
))
334+
}
335+
}
336+
303337
/// Number of rows being selected based on this [`FilterPredicate`]
304338
pub fn count(&self) -> usize {
305339
self.count

0 commit comments

Comments
 (0)