diff --git a/arrow/src/compute/kernels/filter.rs b/arrow/src/compute/kernels/filter.rs index 68feb0a546e2..4da07b89edde 100644 --- a/arrow/src/compute/kernels/filter.rs +++ b/arrow/src/compute/kernels/filter.rs @@ -17,6 +17,8 @@ //! Defines miscellaneous array kernels. +use crate::buffer::buffer_bin_and; +use crate::datatypes::DataType; use crate::error::Result; use crate::record_batch::RecordBatch; use crate::{array::*, util::bit_chunk_iterator::BitChunkIterator}; @@ -204,8 +206,7 @@ pub fn build_filter(filter: &BooleanArray) -> Result { } /// Filters an [Array], returning elements matching the filter (i.e. where the values are true). -/// WARNING: the nulls of `filter` are ignored and the value on its slot is considered. -/// Therefore, it is considered undefined behavior to pass `filter` with null values. +/// /// # Example /// ```rust /// # use arrow::array::{Int32Array, BooleanArray}; @@ -221,6 +222,25 @@ pub fn build_filter(filter: &BooleanArray) -> Result { /// # } /// ``` pub fn filter(array: &Array, filter: &BooleanArray) -> Result { + if filter.null_count() > 0 { + // this greatly simplifies subsequent filtering code + // now we only have a boolean mask to deal with + let array_data = filter.data_ref(); + let null_bitmap = array_data.null_buffer().unwrap(); + let mask = filter.values(); + let offset = filter.offset(); + + let new_mask = buffer_bin_and(mask, offset, null_bitmap, offset, filter.len()); + + let array_data = ArrayData::builder(DataType::Boolean) + .len(filter.len()) + .add_buffer(new_mask) + .build(); + let filter = BooleanArray::from(array_data); + // fully qualified syntax, because we have an argument with the same name + return crate::compute::kernels::filter::filter(array, &filter); + } + let iter = SlicesIterator::new(filter); let mut mutable = @@ -249,6 +269,7 @@ pub fn filter_record_batch( #[cfg(test)] mod tests { use super::*; + use crate::datatypes::Int64Type; use crate::{ buffer::Buffer, datatypes::{DataType, Field}, @@ -581,4 
+602,27 @@ mod tests { assert_eq!(chunks, vec![(1, 62), (63, 124), (125, 130)]); assert_eq!(filter_count, 61 + 61 + 5); } + + #[test] + fn test_null_mask() -> Result<()> { + use crate::compute::kernels::comparison; + let a: PrimitiveArray<Int64Type> = + PrimitiveArray::from(vec![Some(1), Some(2), None]); + let mask0 = comparison::eq(&a, &a)?; + let out0 = filter(&a, &mask0)?; + let out_arr0 = out0 + .as_any() + .downcast_ref::<PrimitiveArray<Int64Type>>() + .unwrap(); + + let mask1 = BooleanArray::from(vec![Some(true), Some(true), None]); + let out1 = filter(&a, &mask1)?; + let out_arr1 = out1 + .as_any() + .downcast_ref::<PrimitiveArray<Int64Type>>() + .unwrap(); + assert_eq!(mask0, mask1); + assert_eq!(out_arr0, out_arr1); + Ok(()) + } }