diff --git a/arrow-buffer/src/builder/null.rs b/arrow-buffer/src/builder/null.rs index 2ffd4dcd4c35..db94aa9fe6cd 100644 --- a/arrow-buffer/src/builder/null.rs +++ b/arrow-buffer/src/builder/null.rs @@ -196,6 +196,33 @@ impl NullBufferBuilder { } } + /// Extends this builder with validity values. + /// + /// # Safety + /// The caller must ensure that the iterator reports the correct length. + /// + /// # Example + /// ``` + /// # use arrow_buffer::NullBufferBuilder; + /// let mut builder = NullBufferBuilder::new(8); + /// let validities = [true, false, true, true]; + /// unsafe { builder.extend_trusted_len(validities.iter().copied()); } + /// assert_eq!(builder.len(), 4); + /// ``` + pub unsafe fn extend_trusted_len<I: Iterator<Item = bool>>(&mut self, iter: I) { + // Materialize since we're about to append bits + self.materialize_if_needed(); + + unsafe { + self.bitmap_builder + .as_mut() + .unwrap() + .extend_trusted_len(iter) + }; + } + + /// Builds the null buffer and resets the builder. + /// Returns `None` if the builder only contains `true`s. /// Builds the [`NullBuffer`] and resets the builder. /// /// Returns `None` if the builder only contains `true`s. Use [`Self::build`] @@ -412,4 +439,47 @@ mod tests { assert_eq!(builder.finish(), None); } + + #[test] + fn test_extend() { + // Test small extend (less than 64 bits) + let mut builder = NullBufferBuilder::new(0); + unsafe { + builder.extend_trusted_len([true, false, true, true].iter().copied()); + } + // bits: 0=true, 1=false, 2=true, 3=true -> 0b1101 = 13 + assert_eq!(builder.as_slice().unwrap(), &[0b1101_u8]); + + // Test extend with exactly 64 bits + let mut builder = NullBufferBuilder::new(0); + let pattern: Vec<bool> = (0..64).map(|i| i % 2 == 0).collect(); + unsafe { + builder.extend_trusted_len(pattern.iter().copied()); + } + // Even positions are true: 0, 2, 4, ... -> bits 0, 2, 4, ... 
+ // In little-endian: 0b01010101 repeated + assert_eq!( + builder.as_slice().unwrap(), + &[0x55, 0x55, 0x55, 0x55, 0x55, 0x55, 0x55, 0x55] + ); + + // Test extend with more than 64 bits (tests chunking) + let mut builder = NullBufferBuilder::new(0); + let pattern: Vec<bool> = (0..100).map(|i| i % 3 == 0).collect(); + unsafe { builder.extend_trusted_len(pattern.iter().copied()) }; + assert_eq!(builder.len(), 100); + // Verify a few specific bits + let buf = builder.finish().unwrap(); + assert!(buf.is_valid(0)); // 0 % 3 == 0 + assert!(!buf.is_valid(1)); // 1 % 3 != 0 + assert!(!buf.is_valid(2)); // 2 % 3 != 0 + assert!(buf.is_valid(3)); // 3 % 3 == 0 + assert!(buf.is_valid(99)); // 99 % 3 == 0 + + // Test extend with non-aligned start (tests bit-by-bit path) + let mut builder = NullBufferBuilder::new(0); + builder.append_non_null(); // Start at bit 1 (non-aligned) + unsafe { builder.extend_trusted_len([false, true, false, true].iter().copied()) }; + assert_eq!(builder.as_slice().unwrap(), &[0b10101_u8]); + } } diff --git a/arrow-select/src/coalesce.rs b/arrow-select/src/coalesce.rs index 5ea2d97e78ea..fd92ce8bc140 100644 --- a/arrow-select/src/coalesce.rs +++ b/arrow-select/src/coalesce.rs @@ -20,7 +20,8 @@ //! //! [`filter`]: crate::filter::filter //! [`take`]: crate::take::take -use crate::filter::filter_record_batch; +use crate::filter::{FilterBuilder, FilterPredicate, is_optimize_beneficial_record_batch}; + use crate::take::take_record_batch; use arrow_array::types::{BinaryViewType, StringViewType}; use arrow_array::{Array, ArrayRef, BooleanArray, RecordBatch, downcast_primitive}; @@ -212,7 +213,10 @@ impl BatchCoalescer { /// Push a batch into the Coalescer after applying a filter /// /// This is semantically equivalent of calling [`Self::push_batch`] - /// with the results from [`filter_record_batch`] + /// with the results from [`filter_record_batch`], but avoids + /// materializing the intermediate filtered batch. 
+ /// + /// [`filter_record_batch`]: crate::filter::filter_record_batch /// /// # Example /// ``` @@ -238,10 +242,103 @@ batch: RecordBatch, filter: &BooleanArray, ) -> Result<(), ArrowError> { - // TODO: optimize this to avoid materializing (copying the results - // of filter to a new batch) - let filtered_batch = filter_record_batch(&batch, filter)?; - self.push_batch(filtered_batch) + // We only support primitive now, fallback to filter_record_batch for other types + // Also, skip optimization when filter is not very selective + + // Build an optimized filter predicate that chooses the best iteration strategy + // Byteview does use a filter as part of calculating ideal buffer sizes, so optimizing is helpful even for + // a single array + let is_optimize_beneficial = is_optimize_beneficial_record_batch(&batch) + || batch.columns().len() == 1 + && matches!( + batch.columns()[0].data_type(), + DataType::BinaryView | DataType::Utf8View + ); + let selected_count = filter.true_count(); + let num_rows = batch.num_rows(); + + // Fast path: skip if no rows selected + if selected_count == 0 { + return Ok(()); + } + + // Fast path: if all rows selected, just push the batch + if selected_count == num_rows { + return self.push_batch(batch); + } + + let (_schema, arrays, _num_rows) = batch.into_parts(); + + let mut filter_builder = FilterBuilder::new(&filter); + + if is_optimize_beneficial { + filter_builder = filter_builder.optimize(); + } + + let filter = filter_builder.build(); + // Setup input arrays as sources + assert_eq!(arrays.len(), self.in_progress_arrays.len()); + self.in_progress_arrays + .iter_mut() + .zip(arrays) + .for_each(|(in_progress, array)| { + in_progress.set_source_from_filter(Some(array), &filter); + }); + + // Choose iteration strategy based on the optimized predicate + self.copy_from_filter(filter, selected_count)?; + // Clear sources to allow memory to be freed + for in_progress in self.in_progress_arrays.iter_mut() { + 
in_progress.set_source(None); + } + + Ok(()) + } + + /// Helper to copy rows at the given indices, handling batch boundaries efficiently + /// + /// This method batches the index iteration to avoid per-row batch boundary checks. + fn copy_from_filter( + &mut self, + filter: FilterPredicate, + count: usize, + ) -> Result<(), ArrowError> { + let mut remaining = count; + let mut filter_pos = 0; // Position in the filter array + + // Build an optimized filter predicate once for the whole input batch + + // We need to process the filter in chunks that fit the target batch size + while remaining > 0 { + let space_in_batch = self.target_batch_size - self.buffered_rows; + let to_copy = remaining.min(space_in_batch); + + // Find how many filter positions we need to cover `to_copy` set bits + // Skip the expensive search if all remaining rows fit in the current batch + let chunk_len = if remaining <= space_in_batch { + filter.len() - filter_pos + } else { + filter.find_nth_set_bit_position(filter_pos, to_copy) - filter_pos + }; + + let chunk_predicate = filter.slice_with_count(filter_pos, chunk_len, to_copy); + + // Copy all collected indices in one call per array + for in_progress in self.in_progress_arrays.iter_mut() { + in_progress.copy_rows_by_filter(&chunk_predicate, filter_pos, chunk_len)?; + } + + self.buffered_rows += to_copy; + filter_pos += chunk_len; + remaining -= to_copy; + + // If we've filled the batch, finish it + if self.buffered_rows >= self.target_batch_size { + self.finish_buffered_batch()?; + } + } + + Ok(()) } /// Push a batch into the Coalescer after applying a set of indices @@ -598,6 +695,13 @@ trait InProgressArray: std::fmt::Debug + Send + Sync { /// current in-progress array fn set_source(&mut self, source: Option); + /// Set the source array with a filter, allowing for calculating GC based on filter + /// + /// Default implementation just calls [`Self::set_source`] + fn set_source_from_filter(&mut self, source: Option, _filter: &FilterPredicate) 
{ + self.set_source(source); + } + /// Copy rows from the current source array into the in-progress array /// /// The source array is set by [`Self::set_source`]. @@ -605,6 +709,17 @@ trait InProgressArray: std::fmt::Debug + Send + Sync { /// Return an error if the source array is not set fn copy_rows(&mut self, offset: usize, len: usize) -> Result<(), ArrowError>; + /// Copy rows from the source array between the specified offset and len that + /// match the predicate to the output array + /// + /// TODO add an example + fn copy_rows_by_filter( + &mut self, + filter: &FilterPredicate, + offset: usize, + len: usize, + ) -> Result<(), ArrowError>; + /// Finish the currently in-progress array and return it as an `ArrayRef` fn finish(&mut self) -> Result; } @@ -613,6 +728,7 @@ trait InProgressArray: std::fmt::Debug + Send + Sync { mod tests { use super::*; use crate::concat::concat_batches; + use crate::filter::filter_record_batch; use arrow_array::builder::StringViewBuilder; use arrow_array::cast::AsArray; use arrow_array::types::Int32Type; diff --git a/arrow-select/src/coalesce/byte_view.rs b/arrow-select/src/coalesce/byte_view.rs index bca811fff1c6..a5b3946d9170 100644 --- a/arrow-select/src/coalesce/byte_view.rs +++ b/arrow-select/src/coalesce/byte_view.rs @@ -16,10 +16,11 @@ // under the License. 
use crate::coalesce::InProgressArray; +use crate::filter::{FilterPredicate, IndexIterator, IterationStrategy, SlicesIterator}; use arrow_array::cast::AsArray; use arrow_array::types::ByteViewType; -use arrow_array::{Array, ArrayRef, GenericByteViewArray}; -use arrow_buffer::{Buffer, NullBufferBuilder}; +use arrow_array::{Array, ArrayRef, BooleanArray, GenericByteViewArray}; +use arrow_buffer::{Buffer, NullBuffer, NullBufferBuilder}; use arrow_data::{ByteView, MAX_INLINE_VIEW_LEN}; use arrow_schema::ArrowError; use std::marker::PhantomData; @@ -276,6 +277,57 @@ impl InProgressByteViewArray { self.views.extend(new_views); self.current = Some(dst_buffer); } + + /// Translate a single view while GCing, updating `current` buffer as necessary + #[inline] + fn translate_view_gc( + v: u128, + buffers: &[Buffer], + current: &mut Vec, + completed: &mut Vec, + buffer_source: &mut BufferSource, + ideal_buffer_size: usize, + ) -> u128 { + if (v as u32) <= MAX_INLINE_VIEW_LEN { + return v; + } + + let mut b = ByteView::from(v); + let str_len = b.length as usize; + if current.len() + str_len > current.capacity() { + let next = buffer_source.next_buffer(ideal_buffer_size); + let prev = std::mem::replace(current, next); + completed.push(prev.into()); + } + + let old_idx = b.buffer_index as usize; + let old_offset = b.offset as usize; + b.offset = current.len() as u32; + b.buffer_index = completed.len() as u32; + + let src = unsafe { + // Safety: inputs are validly constructed + buffers + .get_unchecked(old_idx) + .get_unchecked(old_offset..old_offset + str_len) + }; + current.extend_from_slice(src); + b.as_u128() + } + + /// Update views and push to `self.views` + #[inline] + fn append_views_with_offset(&mut self, views: &[u128], offset: u32) { + let updated_views = views.iter().map(|v| { + let mut byte_view = ByteView::from(*v); + if byte_view.length > MAX_INLINE_VIEW_LEN { + byte_view.buffer_index += offset; + }; + byte_view.as_u128() + }); + + self.views.extend(updated_views); + 
} } impl InProgressArray for InProgressByteViewArray { @@ -286,11 +338,12 @@ impl InProgressArray for InProgressByteViewArray { let (need_gc, ideal_buffer_size) = if s.data_buffers().is_empty() { (false, 0) } else { + let actual_buffer_size = + s.data_buffers().iter().map(|b| b.capacity()).sum::(); + let ideal_buffer_size = s.total_buffer_bytes_used(); // We don't use get_buffer_memory_size here, because gc is for the contents of the // data buffers, not views and nulls. - let actual_buffer_size = - s.data_buffers().iter().map(|b| b.capacity()).sum::(); // copying strings is expensive, so only do it if the array is // sparse (uses at least 2x the memory it needs) let need_gc = @@ -306,6 +359,85 @@ impl InProgressArray for InProgressByteViewArray { }) } + fn set_source_from_filter(&mut self, source: Option, filter: &FilterPredicate) { + self.source = source.map(|array| { + let s = array.as_byte_view::(); + if s.data_buffers().is_empty() { + return Source { + array, + need_gc: false, + ideal_buffer_size: 0, + }; + } + + let views = s.views().as_ref(); + let mut ideal_buffer_size = 0; + + match filter.strategy() { + IterationStrategy::None => { + return Source { + array, + need_gc: false, + ideal_buffer_size: 0, + }; + } + IterationStrategy::All => { + // all rows selected + ideal_buffer_size = s.total_buffer_bytes_used(); + } + IterationStrategy::IndexIterator => { + for idx in IndexIterator::new(filter.filter_array(), filter.count()) { + let v = unsafe { *views.get_unchecked(idx) }; + let len: u32 = ByteView::from(v).length; + if len > MAX_INLINE_VIEW_LEN { + ideal_buffer_size += len as usize; + } + } + } + IterationStrategy::SlicesIterator => { + for (start, end) in SlicesIterator::new(filter.filter_array()) { + for v in &views[start..end] { + let len: u32 = ByteView::from(*v).length; + if len > MAX_INLINE_VIEW_LEN { + ideal_buffer_size += len as usize; + } + } + } + } + IterationStrategy::Indices(indices) => { + for &idx in indices { + let v = unsafe { 
*views.get_unchecked(idx) }; + let len: u32 = ByteView::from(v).length; + if len > MAX_INLINE_VIEW_LEN { + ideal_buffer_size += len as usize; + } + } + } + IterationStrategy::Slices(slices) => { + for (start, end) in slices { + for v in &views[*start..*end] { + let len: u32 = ByteView::from(*v).length; + if len > MAX_INLINE_VIEW_LEN { + ideal_buffer_size += len as usize; + } + } + } + } + } + + let actual_buffer_size = s.data_buffers().iter().map(|b| b.capacity()).sum::(); + + let need_gc = ideal_buffer_size != 0 && actual_buffer_size > (ideal_buffer_size * 2); + let fraction = filter.count() as f64 / s.len() as f64; + let ideal_buffer_size = (ideal_buffer_size as f64 * (1.0 / fraction)).round() as usize; + Source { + array, + need_gc, + ideal_buffer_size, + } + }); + } + fn copy_rows(&mut self, offset: usize, len: usize) -> Result<(), ArrowError> { self.ensure_capacity(); let source = self.source.take().ok_or_else(|| { @@ -347,6 +479,235 @@ impl InProgressArray for InProgressByteViewArray { Ok(()) } + fn copy_rows_by_filter( + &mut self, + filter: &FilterPredicate, + offset: usize, + len: usize, + ) -> Result<(), ArrowError> { + match filter.strategy() { + IterationStrategy::None => return Ok(()), + IterationStrategy::All => return self.copy_rows(offset, len), + _ => {} + } + + self.ensure_capacity(); + let source = self.source.take().ok_or_else(|| { + ArrowError::InvalidArgumentError( + "Internal Error: InProgressByteViewArray: source not set".to_string(), + ) + })?; + + let s = source.array.as_byte_view::(); + let views = &s.views().as_ref()[offset..offset + len]; + let buffers = s.data_buffers(); + + // Handle nulls + if let Some(nulls) = s.nulls() { + let null_array = BooleanArray::new(nulls.inner().clone(), None).slice(offset, len); + let filtered_nulls = filter.filter(&null_array)?; + let filtered_nulls = filtered_nulls.as_boolean(); + self.nulls + .append_buffer(&NullBuffer::new(filtered_nulls.values().clone())); + } else { + 
self.nulls.append_n_non_nulls(filter.count()); + } + + if source.ideal_buffer_size == 0 { + match filter.strategy() { + IterationStrategy::None | IterationStrategy::All => unreachable!(), + IterationStrategy::SlicesIterator => { + for (start, end) in SlicesIterator::new(filter.filter_array()) { + // Safety: filter created valid indices + self.views + .extend_from_slice(unsafe { views.get_unchecked(start..end) }); + } + } + IterationStrategy::IndexIterator => { + self.views.extend( + IndexIterator::new(filter.filter_array(), filter.count()) + .map(|idx| unsafe { *views.get_unchecked(idx) }), + ); + } + IterationStrategy::Indices(indices) => { + self.views.extend( + indices + .iter() + .map(|&idx| unsafe { *views.get_unchecked(idx) }), + ); + } + IterationStrategy::Slices(slices) => { + for (start, end) in slices { + // Safety: filter created valid indices + self.views + .extend_from_slice(unsafe { views.get_unchecked(*start..*end) }); + } + } + } + } else if source.need_gc { + let ideal_buffer_size = source.ideal_buffer_size; + let mut current = self + .current + .take() + .unwrap_or_else(|| self.buffer_source.next_buffer(ideal_buffer_size)); + + match filter.strategy() { + IterationStrategy::None | IterationStrategy::All => unreachable!(), + IterationStrategy::SlicesIterator => { + for (start, end) in SlicesIterator::new(filter.filter_array()) { + // Safety: filter created valid indices + let slice = unsafe { views.get_unchecked(start..end) }; + self.views.extend(slice.iter().map(|&v| { + Self::translate_view_gc( + v, + buffers, + &mut current, + &mut self.completed, + &mut self.buffer_source, + ideal_buffer_size, + ) + })); + } + } + IterationStrategy::IndexIterator => { + self.views.extend( + IndexIterator::new(filter.filter_array(), filter.count()).map(|idx| { + // Safety: filter created valid indices + let v = unsafe { *views.get_unchecked(idx) }; + Self::translate_view_gc( + v, + buffers, + &mut current, + &mut self.completed, + &mut self.buffer_source, + 
ideal_buffer_size, + ) + }), + ); + } + IterationStrategy::Indices(indices) => { + self.views.extend(indices.iter().map(|&idx| { + // Safety: filter created valid indices + let v = unsafe { *views.get_unchecked(idx) }; + Self::translate_view_gc( + v, + buffers, + &mut current, + &mut self.completed, + &mut self.buffer_source, + ideal_buffer_size, + ) + })); + } + IterationStrategy::Slices(slices) => { + for (start, end) in slices { + // Safety: filter created valid indices + let slice = unsafe { views.get_unchecked(*start..*end) }; + + self.views.extend(slice.iter().map(|&v| { + Self::translate_view_gc( + v, + buffers, + &mut current, + &mut self.completed, + &mut self.buffer_source, + ideal_buffer_size, + ) + })); + } + } + } + self.current = Some(current); + } else { + if let Some(buffer) = self.current.take() { + self.completed.push(buffer.into()); + } + let starting_buffer: u32 = self.completed.len().try_into().expect("too many buffers"); + self.completed.extend_from_slice(buffers); + + if starting_buffer == 0 { + match filter.strategy() { + IterationStrategy::None | IterationStrategy::All => unreachable!(), + IterationStrategy::SlicesIterator => { + for (start, end) in SlicesIterator::new(filter.filter_array()) { + // Safety: filter created valid indices + self.views + .extend_from_slice(unsafe { views.get_unchecked(start..end) }); + } + } + IterationStrategy::IndexIterator => { + self.views.extend( + IndexIterator::new(filter.filter_array(), filter.count()) + .map(|idx| unsafe { *views.get_unchecked(idx) }), + ); + } + IterationStrategy::Indices(indices) => { + self.views.extend( + indices + .iter() + .map(|&idx| unsafe { *views.get_unchecked(idx) }), + ); + } + IterationStrategy::Slices(slices) => { + for (start, end) in slices { + // Safety: filter created valid indices + self.views + .extend_from_slice(unsafe { views.get_unchecked(*start..*end) }); + } + } + } + return Ok(()); + } + + match filter.strategy() { + IterationStrategy::None | 
IterationStrategy::All => unreachable!(), + IterationStrategy::SlicesIterator => { + for (start, end) in SlicesIterator::new(filter.filter_array()) { + // Safety: filter created valid indices + self.append_views_with_offset( + unsafe { views.get_unchecked(start..end) }, + starting_buffer, + ); + } + } + IterationStrategy::IndexIterator => { + self.views.extend( + IndexIterator::new(filter.filter_array(), filter.count()).map(|idx| { + // Safety: filter created valid indices + let mut byte_view: ByteView = + ByteView::from(unsafe { *views.get_unchecked(idx) }); + if byte_view.length > MAX_INLINE_VIEW_LEN { + byte_view.buffer_index += starting_buffer; + }; + byte_view.as_u128() + }), + ); + } + IterationStrategy::Indices(indices) => { + self.views.extend(indices.iter().map(|&idx| { + // Safety: filter created valid indices + let mut byte_view: ByteView = + ByteView::from(unsafe { *views.get_unchecked(idx) }); + if byte_view.length > MAX_INLINE_VIEW_LEN { + byte_view.buffer_index += starting_buffer; + }; + byte_view.as_u128() + })); + } + IterationStrategy::Slices(slices) => { + for (start, end) in slices { + self.append_views_with_offset( + unsafe { views.get_unchecked(*start..*end) }, + starting_buffer, + ); + } + } + } + } + self.source = Some(source); + Ok(()) + } + fn finish(&mut self) -> Result { self.finish_current(); assert!(self.current.is_none()); diff --git a/arrow-select/src/coalesce/generic.rs b/arrow-select/src/coalesce/generic.rs index 1ea57dff929c..9150a738a381 100644 --- a/arrow-select/src/coalesce/generic.rs +++ b/arrow-select/src/coalesce/generic.rs @@ -17,7 +17,8 @@ use super::InProgressArray; use crate::concat::concat; -use arrow_array::ArrayRef; +use crate::filter::FilterPredicate; +use arrow_array::{Array, ArrayRef}; use arrow_schema::ArrowError; /// Generic implementation for [`InProgressArray`] that works with any type of @@ -60,6 +61,22 @@ impl InProgressArray for GenericInProgressArray { Ok(()) } + fn copy_rows_by_filter( + &mut self, + 
filter: &FilterPredicate, + offset: usize, + len: usize, + ) -> Result<(), ArrowError> { + let Some(source) = self.source.as_mut() else { + return Err(ArrowError::InvalidArgumentError( + "Internal Error: GenericInProgressArray: source not set".to_string(), + )); + }; + let filtered_array = filter.filter(&source.slice(offset, len))?; + self.buffered_arrays.push(filtered_array); + Ok(()) + } + fn finish(&mut self) -> Result { // Concatenate all buffered arrays into a single array, which uses 2x // peak memory diff --git a/arrow-select/src/coalesce/primitive.rs b/arrow-select/src/coalesce/primitive.rs index 69dad221bd52..4182a21173fa 100644 --- a/arrow-select/src/coalesce/primitive.rs +++ b/arrow-select/src/coalesce/primitive.rs @@ -16,9 +16,10 @@ // under the License. use crate::coalesce::InProgressArray; +use crate::filter::{FilterPredicate, IndexIterator, IterationStrategy, SlicesIterator}; use arrow_array::cast::AsArray; use arrow_array::{Array, ArrayRef, ArrowPrimitiveType, PrimitiveArray}; -use arrow_buffer::{NullBufferBuilder, ScalarBuffer}; +use arrow_buffer::{NullBufferBuilder, ScalarBuffer, bit_util}; use arrow_schema::{ArrowError, DataType}; use std::fmt::Debug; use std::sync::Arc; @@ -95,6 +96,151 @@ impl InProgressArray for InProgressPrimitiveArray Ok(()) } + /// Copy rows using a predicate + fn copy_rows_by_filter( + &mut self, + filter: &FilterPredicate, + offset: usize, + len: usize, + ) -> Result<(), ArrowError> { + self.ensure_capacity(); + + let s = self + .source + .as_ref() + .ok_or_else(|| { + ArrowError::InvalidArgumentError( + "Internal Error: InProgressPrimitiveArray: source not set".to_string(), + ) + })? 
+ .slice(offset, len); + let s = s.as_primitive::(); + + let values = s.values(); + let count = filter.count(); + + // Use the predicate's strategy for optimal iteration + match filter.strategy() { + IterationStrategy::SlicesIterator => { + // Copy values, nulls using slices + if let Some(nulls) = s.nulls().filter(|n| n.null_count() > 0) { + for (start, end) in SlicesIterator::new(filter.filter_array()) { + // SAFETY: slices are derived from filter predicate + self.current + .extend_from_slice(unsafe { values.get_unchecked(start..end) }); + let slice = nulls.slice(start, end - start); + self.nulls.append_buffer(&slice); + } + } else { + for (start, end) in SlicesIterator::new(filter.filter_array()) { + // SAFETY: SlicesIterator produces valid ranges derived from filter + self.current + .extend_from_slice(unsafe { values.get_unchecked(start..end) }); + } + self.nulls.append_n_non_nulls(count); + } + } + IterationStrategy::Slices(slices) => { + // Copy values and nulls using precomputed slices - single iteration + if let Some(nulls) = s.nulls().filter(|n| n.null_count() > 0) { + for &(start, end) in slices { + // SAFETY: slices are derived from filter predicate + self.current + .extend_from_slice(unsafe { values.get_unchecked(start..end) }); + let slice = nulls.slice(start, end - start); + self.nulls.append_buffer(&slice); + } + } else { + for &(start, end) in slices { + // SAFETY: slices are derived from filter predicate + self.current + .extend_from_slice(unsafe { values.get_unchecked(start..end) }); + } + self.nulls.append_n_non_nulls(count); + } + } + IterationStrategy::IndexIterator => { + // Copy values and nulls for each index + if let Some(nulls) = s.nulls().filter(|n| n.null_count() > 0) { + let null_buffer = nulls.inner(); + let null_ptr = null_buffer.values().as_ptr(); + let null_offset = null_buffer.offset(); + + // Collect indices for reuse (values + nulls) + let indices = IndexIterator::new(filter.filter_array(), count); + + // Efficiently extend null 
buffer + // SAFETY: indices iterator reports correct length + unsafe { + self.nulls.extend_trusted_len( + indices.map(|idx| bit_util::get_bit_raw(null_ptr, idx + null_offset)), + ); + } + let indices = IndexIterator::new(filter.filter_array(), count); + + // Copy values + // SAFETY: indices are derived from filter predicate + self.current + .extend(indices.map(|idx| unsafe { *values.get_unchecked(idx) })); + } else { + self.nulls.append_n_non_nulls(count); + let indices = IndexIterator::new(filter.filter_array(), count); + // SAFETY: indices are derived from filter predicate + self.current + .extend(indices.map(|idx: usize| unsafe { *values.get_unchecked(idx) })); + } + } + IterationStrategy::Indices(indices) => { + // Copy values and nulls using precomputed indices + if let Some(nulls) = s.nulls().filter(|n| n.null_count() > 0) { + let null_buffer = nulls.inner(); + let null_ptr = null_buffer.values().as_ptr(); + let null_offset = null_buffer.offset(); + + // Efficiently extend null buffer + // SAFETY: indices iterator reports correct length + unsafe { + self.nulls.extend_trusted_len( + indices + .iter() + .map(|&idx| bit_util::get_bit_raw(null_ptr, idx + null_offset)), + ); + } + + // Copy values + // SAFETY: indices are derived from filter predicate + self.current.extend( + indices + .iter() + .map(|&idx| unsafe { *values.get_unchecked(idx) }), + ); + } else { + self.nulls.append_n_non_nulls(count); + // SAFETY: indices are derived from filter predicate + self.current.extend( + indices + .iter() + .map(|&idx| unsafe { *values.get_unchecked(idx) }), + ) + }; + } + IterationStrategy::All => { + // Copy all values + self.current.extend_from_slice(values); + if let Some(nulls) = s.nulls() { + self.nulls.append_buffer(nulls); + } else { + self.nulls.append_n_non_nulls(values.len()); + } + } + IterationStrategy::None => { + // Nothing to copy + } + } + + Ok(()) + } + fn finish(&mut self) -> Result { // take and reset the current values and nulls let values = 
std::mem::take(&mut self.current); diff --git a/arrow-select/src/filter.rs b/arrow-select/src/filter.rs index e95d01f2b592..b45065545978 100644 --- a/arrow-select/src/filter.rs +++ b/arrow-select/src/filter.rs @@ -37,9 +37,9 @@ use arrow_schema::*; /// [`SlicesIterator`] to copy ranges of values. Otherwise iterate /// over individual rows using [`IndexIterator`] /// -/// Threshold of 0.8 chosen based on +/// Threshold of 0.9 chosen based on benchmarking results /// -const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.8; +const FILTER_SLICES_SELECTIVITY_THRESHOLD: f64 = 0.9; /// An iterator of `(usize, usize)` each representing an interval /// `[start, end)` whose slots of a bitmap [Buffer] are true. @@ -80,13 +80,17 @@ impl Iterator for SlicesIterator<'_> { /// /// This provides the best performance on most predicates, apart from those which keep /// large runs and therefore favour [`SlicesIterator`] -struct IndexIterator<'a> { +pub struct IndexIterator<'a> { remaining: usize, iter: BitIndexIterator<'a>, } impl<'a> IndexIterator<'a> { - fn new(filter: &'a BooleanArray, remaining: usize) -> Self { + /// Creates a new [`IndexIterator`] from a [`BooleanArray`] + /// + /// # Panics + /// Panics if `filter` has null values + pub fn new(filter: &'a BooleanArray, remaining: usize) -> Self { assert_eq!(filter.null_count(), 0); let iter = filter.values().set_indices(); Self { remaining, iter } @@ -216,6 +220,21 @@ pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result bool { + let num_cols = record_batch.num_columns(); + if num_cols > 1 { + return true; + } + if num_cols == 1 { + return FilterBuilder::is_optimize_beneficial( + record_batch.schema_ref().field(0).data_type(), + ); + } + false +} + /// Returns a filtered [RecordBatch] where the corresponding elements of /// `predicate` are true. 
/// @@ -231,13 +250,7 @@ pub fn filter_record_batch( predicate: &BooleanArray, ) -> Result { let mut filter_builder = FilterBuilder::new(predicate); - let num_cols = record_batch.num_columns(); - if num_cols > 1 - || (num_cols > 0 - && FilterBuilder::is_optimize_beneficial( - record_batch.schema_ref().field(0).data_type(), - )) - { + if is_optimize_beneficial_record_batch(record_batch) { // Only optimize if filtering more than one column or if the column contains multiple internal arrays // Otherwise, the overhead of optimization can be more than the benefit filter_builder = filter_builder.optimize(); @@ -325,15 +338,22 @@ impl FilterBuilder { } /// The iteration strategy used to evaluate [`FilterPredicate`] -#[derive(Debug)] -enum IterationStrategy { - /// A lazily evaluated iterator of ranges +/// +/// This determines how the filter will iterate over the selected rows. +/// The strategy is chosen based on the selectivity of the filter. +#[derive(Debug, Clone)] +pub enum IterationStrategy { + /// A lazily evaluated iterator of ranges (slices) + /// + /// Best for low selectivity filters (which select a relatively large number of rows) SlicesIterator, /// A lazily evaluated iterator of indices + /// + /// Best for high selectivity filters (which select a relatively low number of rows) IndexIterator, /// A precomputed list of indices Indices(Vec), - /// A precomputed array of ranges + /// A precomputed array of ranges (start, end) Slices(Vec<(usize, usize)>), /// Select all rows All, @@ -344,7 +364,13 @@ enum IterationStrategy { impl IterationStrategy { /// The default [`IterationStrategy`] for a filter of length `filter_length` /// and selecting `filter_count` rows - fn default_strategy(filter_length: usize, filter_count: usize) -> Self { + /// + /// Returns: + /// - [`IterationStrategy::None`] if `filter_count` is 0 + /// - [`IterationStrategy::All`] if `filter_count == filter_length` + /// - [`IterationStrategy::SlicesIterator`] if selectivity > 80% + /// - 
[`IterationStrategy::IndexIterator`] otherwise + pub fn default_strategy(filter_length: usize, filter_count: usize) -> Self { if filter_length == 0 || filter_count == 0 { return IterationStrategy::None; } @@ -363,6 +389,35 @@ impl IterationStrategy { } IterationStrategy::IndexIterator } + + fn slice(&self, offset: usize, len: usize) -> Self { + match self { + IterationStrategy::SlicesIterator => IterationStrategy::SlicesIterator, + IterationStrategy::IndexIterator => IterationStrategy::IndexIterator, + IterationStrategy::Indices(indices) => { + let start = indices.partition_point(|&idx| idx < offset); + let end = indices.partition_point(|&idx| idx < offset + len); + let new_indices = indices[start..end] + .iter() + .map(|&idx| idx - offset) + .collect(); + IterationStrategy::Indices(new_indices) + } + IterationStrategy::Slices(slices) => { + let mut new_slices = Vec::new(); + for &(start, end) in slices { + let max_start = start.max(offset); + let min_end = end.min(offset + len); + if max_start < min_end { + new_slices.push((max_start - offset, min_end - offset)); + } + } + IterationStrategy::Slices(new_slices) + } + IterationStrategy::All => IterationStrategy::All, + IterationStrategy::None => IterationStrategy::None, + } + } } /// A filtering predicate that can be applied to an [`Array`] @@ -408,6 +463,91 @@ impl FilterPredicate { pub fn count(&self) -> usize { self.count } + + /// Number of rows in the filter predicate + pub fn len(&self) -> usize { + self.filter.len() + } + + /// Slices this [`FilterPredicate`] + /// + /// # Panics + /// + /// Panics if `offset + len > self.len()` + pub fn slice(&self, offset: usize, len: usize) -> Self { + let filter = self.filter.slice(offset, len); + let count = filter.true_count(); + let strategy = self.strategy.slice(offset, len); + Self { + filter, + count, + strategy, + } + } + + /// Slices this [`FilterPredicate`] with a precomputed count + /// + /// # Panics + /// + /// Panics if `offset + len > self.len()` + pub fn 
slice_with_count(&self, offset: usize, len: usize, count: usize) -> Self { + let filter = self.filter.slice(offset, len); + let strategy = self.strategy.slice(offset, len); + Self { + filter, + count, + strategy, + } + } + + /// Returns the iteration strategy used by this [`FilterPredicate`] + pub fn strategy(&self) -> &IterationStrategy { + &self.strategy + } + + /// Returns the underlying filter array + pub fn filter_array(&self) -> &BooleanArray { + &self.filter + } + + /// Returns the bit position of the `n`-th set bit in the filter, starting the search at `start`. + /// + /// # Panics + /// + /// Panics if `n` bits are not found. + pub fn find_nth_set_bit_position(&self, start: usize, n: usize) -> usize { + if n == 0 { + return start; + } + + match &self.strategy { + IterationStrategy::Indices(indices) => { + // If we have precomputed indices, we can find the nth bit directly. + // Since this predicate might be a slice, the indices are relative to the start of this predicate. + // However, the `start` parameter is also relative to the start of this predicate. + let offset = indices.partition_point(|&idx| idx < start); + indices[offset + n - 1] + 1 + } + IterationStrategy::Slices(slices) => { + let mut remaining = n; + for &(s_start, s_end) in slices { + if s_end <= start { + continue; + } + let effective_start = s_start.max(start); + let len = s_end - effective_start; + if len >= remaining { + return effective_start + remaining; + } + remaining -= len; + } + panic!("n bits not found in slices") + } + IterationStrategy::All => start + n, + IterationStrategy::None => panic!("No bits in None strategy"), + _ => self.filter.values().find_nth_set_bit_position(start, n), + } + } } fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result {