diff --git a/arrow-select/src/filter.rs b/arrow-select/src/filter.rs index dace2bab728f..5c21a4adcab7 100644 --- a/arrow-select/src/filter.rs +++ b/arrow-select/src/filter.rs @@ -122,6 +122,12 @@ pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray { /// Returns a filtered `values` [`Array`] where the corresponding elements of /// `predicate` are `true`. /// +/// If multiple arrays (or record batches) need to be filtered using the same predicate array, +/// consider using [FilterBuilder] to create a single [FilterPredicate] and then +/// calling [FilterPredicate::filter] (or [FilterPredicate::filter_record_batch] for record batches). +/// In contrast to this function, it is then the responsibility of the caller +/// to use [FilterBuilder::optimize] if appropriate. +/// /// # See also /// * [`FilterBuilder`] for more control over the filtering process. /// * [`filter_record_batch`] to filter a [`RecordBatch`] @@ -168,25 +174,28 @@ fn multiple_arrays(data_type: &DataType) -> bool { /// `predicate` are true. /// /// This is the equivalent of calling [filter] on each column of the [RecordBatch]. +/// +/// If multiple record batches (or arrays) need to be filtered using the same predicate array, +/// consider using [FilterBuilder] to create a single [FilterPredicate] and then +/// calling [FilterPredicate::filter_record_batch] (or [FilterPredicate::filter] for arrays). +/// In contrast to this function, it is then the responsibility of the caller +/// to use [FilterBuilder::optimize] if appropriate. 
pub fn filter_record_batch( record_batch: &RecordBatch, predicate: &BooleanArray, ) -> Result<RecordBatch, ArrowError> { let mut filter_builder = FilterBuilder::new(predicate); - if record_batch.num_columns() > 1 { - // Only optimize if filtering more than one column + let num_cols = record_batch.num_columns(); + if num_cols > 1 + || (num_cols > 0 && multiple_arrays(record_batch.schema_ref().field(0).data_type())) + { + // Only optimize if filtering more than one column or if the column contains multiple internal arrays // Otherwise, the overhead of optimization can be more than the benefit filter_builder = filter_builder.optimize(); } let filter = filter_builder.build(); - let filtered_arrays = record_batch - .columns() - .iter() - .map(|a| filter_array(a, &filter)) - .collect::<Result<Vec<_>, _>>()?; - let options = RecordBatchOptions::default().with_row_count(Some(filter.count())); - RecordBatch::try_new_with_options(record_batch.schema(), filtered_arrays, &options) + filter.filter_record_batch(record_batch) } /// A builder to construct [`FilterPredicate`] @@ -300,6 +309,31 @@ impl FilterPredicate { filter_array(values, self) } + /// Returns a filtered [`RecordBatch`] containing only the rows that are selected by this + /// [`FilterPredicate`]. + /// + /// This is the equivalent of calling [filter] on each column of the [`RecordBatch`]. + pub fn filter_record_batch( + &self, + record_batch: &RecordBatch, + ) -> Result<RecordBatch, ArrowError> { + let filtered_arrays = record_batch + .columns() + .iter() + .map(|a| filter_array(a, self)) + .collect::<Result<Vec<_>, _>>()?; + + // SAFETY: we know that the set of filtered arrays will match the schema of the original + // record batch + unsafe { + Ok(RecordBatch::new_unchecked( + record_batch.schema(), + filtered_arrays, + self.count, + )) + } + } + /// Number of rows being selected based on this [`FilterPredicate`] pub fn count(&self) -> usize { self.count