@@ -122,6 +122,12 @@ pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
122122/// Returns a filtered `values` [`Array`] where the corresponding elements of
123123/// `predicate` are `true`.
124124///
125+ /// If multiple arrays (or record batches) need to be filtered using the same predicate array,
126+ /// consider using [FilterBuilder] to create a single [FilterPredicate] and then
127+ /// calling [FilterPredicate::filter_record_batch].
128+ /// In contrast to this function, it is then the responsibility of the caller
129+ /// to use [FilterBuilder::optimize] if appropriate.
130+ ///
125131/// # See also
126132/// * [`FilterBuilder`] for more control over the filtering process.
127133/// * [`filter_record_batch`] to filter a [`RecordBatch`]
@@ -168,25 +174,28 @@ fn multiple_arrays(data_type: &DataType) -> bool {
168174/// `predicate` are true.
169175///
170176/// This is the equivalent of calling [filter] on each column of the [RecordBatch].
177+ ///
178+ /// If multiple record batches (or arrays) need to be filtered using the same predicate array,
179+ /// consider using [FilterBuilder] to create a single [FilterPredicate] and then
180+ /// calling [FilterPredicate::filter_record_batch].
181+ /// In contrast to this function, it is then the responsibility of the caller
182+ /// to use [FilterBuilder::optimize] if appropriate.
171183pub fn filter_record_batch (
172184 record_batch : & RecordBatch ,
173185 predicate : & BooleanArray ,
174186) -> Result < RecordBatch , ArrowError > {
175187 let mut filter_builder = FilterBuilder :: new ( predicate) ;
176- if record_batch. num_columns ( ) > 1 {
177- // Only optimize if filtering more than one column
188+ let num_cols = record_batch. num_columns ( ) ;
189+ if num_cols > 1
190+ || ( num_cols > 0 && multiple_arrays ( record_batch. schema_ref ( ) . field ( 0 ) . data_type ( ) ) )
191+ {
192+ // Only optimize if filtering more than one column or if the column contains multiple internal arrays
178193 // Otherwise, the overhead of optimization can be more than the benefit
179194 filter_builder = filter_builder. optimize ( ) ;
180195 }
181196 let filter = filter_builder. build ( ) ;
182197
183- let filtered_arrays = record_batch
184- . columns ( )
185- . iter ( )
186- . map ( |a| filter_array ( a, & filter) )
187- . collect :: < Result < Vec < _ > , _ > > ( ) ?;
188- let options = RecordBatchOptions :: default ( ) . with_row_count ( Some ( filter. count ( ) ) ) ;
189- RecordBatch :: try_new_with_options ( record_batch. schema ( ) , filtered_arrays, & options)
198+ filter. filter_record_batch ( record_batch)
190199}
191200
192201/// A builder to construct [`FilterPredicate`]
@@ -300,6 +309,31 @@ impl FilterPredicate {
300309 filter_array ( values, self )
301310 }
302311
312+ /// Returns a filtered [`RecordBatch`] containing only the rows that are selected by this
313+ /// [`FilterPredicate`].
314+ ///
315+ /// This is the equivalent of calling [filter] on each column of the [`RecordBatch`].
316+ pub fn filter_record_batch (
317+ & self ,
318+ record_batch : & RecordBatch ,
319+ ) -> Result < RecordBatch , ArrowError > {
320+ let filtered_arrays = record_batch
321+ . columns ( )
322+ . iter ( )
323+ . map ( |a| filter_array ( a, self ) )
324+ . collect :: < Result < Vec < _ > , _ > > ( ) ?;
325+
326+ // SAFETY: we know that the set of filtered arrays will match the schema of the original
327+ // record batch
328+ unsafe {
329+ Ok ( RecordBatch :: new_unchecked (
330+ record_batch. schema ( ) ,
331+ filtered_arrays,
332+ self . count ,
333+ ) )
334+ }
335+ }
336+
303337 /// Number of rows being selected based on this [`FilterPredicate`]
304338 pub fn count ( & self ) -> usize {
305339 self . count
0 commit comments