Skip to content

Commit

Permalink
Add specialized primitive filter kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Jan 30, 2022
1 parent f809d33 commit 8ed556e
Showing 1 changed file with 303 additions and 20 deletions.
323 changes: 303 additions & 20 deletions arrow/src/compute/kernels/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,38 @@
//! Defines miscellaneous array kernels.

use crate::array::*;
use crate::buffer::buffer_bin_and;
use crate::datatypes::DataType;
use crate::error::Result;
use crate::buffer::{buffer_bin_and, Buffer, MutableBuffer};
use crate::datatypes::*;
use crate::error::{ArrowError, Result};
use crate::record_batch::RecordBatch;
use crate::util::bit_chunk_iterator::{UnalignedBitChunk, UnalignedBitChunkIterator};
use crate::util::bit_util;
use std::sync::Arc;
use TimeUnit::*;

/// Downcasts `$values` to a `PrimitiveArray<$type>` and filters it with
/// [`filter_primitive`], evaluating to `Ok(Arc<_>)` of the filtered array.
///
/// Panics if `$values` is not a `PrimitiveArray<$type>`.
macro_rules! downcast_filter {
    ($type: ty, $values: expr, $filter: expr, $filter_count: expr) => {{
        let primitive = $values
            .as_any()
            .downcast_ref::<PrimitiveArray<$type>>()
            .expect("Unable to downcast to a primitive array");

        let filtered = filter_primitive::<$type>(primitive, $filter, $filter_count);
        Ok(Arc::new(filtered))
    }};
}

/// Function that can filter arbitrary arrays
pub type Filter<'a> = Box<dyn Fn(&ArrayData) -> ArrayData + 'a>;

/// An iterator of `(usize, usize)` each representing an interval `[start,end[` whose
/// slots of a [BooleanArray] are true. Each interval corresponds to a contiguous region of memory to be
/// "taken" from an array to be filtered.
///
/// This is most performant for filters that select most of the slots, in long contiguous runs
#[derive(Debug)]
pub struct SlicesIterator<'a> {
iter: UnalignedBitChunkIterator<'a>,
Expand Down Expand Up @@ -114,6 +134,53 @@ impl<'a> Iterator for SlicesIterator<'a> {
}
}

/// An iterator over the positions (`usize`) of the set bits of a [`BooleanArray`] filter
///
/// This provides the best performance on all but the most selective predicates, where the
/// benefit of copying large contiguous runs instead favours [`SlicesIterator`]
struct IndexIterator<'a> {
/// Bits of the chunk currently being scanned; a set bit is cleared once it has been yielded
current_chunk: u64,
/// Bit offset one past the end of `current_chunk`, measured from the first slot of the filter
chunk_end_offset: usize,
/// Source of the remaining 64-bit chunks of the filter
iter: UnalignedBitChunkIterator<'a>,
}

impl<'a> IndexIterator<'a> {
    /// Creates an iterator over the indices of the set bits of `filter`.
    ///
    /// Panics if `filter` contains nulls.
    fn new(filter: &'a BooleanArray) -> Self {
        assert_eq!(filter.null_count(), 0);

        let filter_data = filter.data();
        let bit_chunks = UnalignedBitChunk::new(
            &filter_data.buffers()[0],
            filter_data.offset(),
            filter_data.len(),
        );

        // The first chunk may start with padding bits that precede the filter's
        // first slot, so its end offset is shortened by that padding.
        let chunk_end_offset = 64 - bit_chunks.lead_padding();

        let mut iter = bit_chunks.iter();
        // An empty chunk sequence is treated as a single all-zero chunk
        let current_chunk = match iter.next() {
            Some(chunk) => chunk,
            None => 0,
        };

        Self {
            current_chunk,
            chunk_end_offset,
            iter,
        }
    }
}

impl<'a> Iterator for IndexIterator<'a> {
    type Item = usize;

    /// Yields the index of the next set bit, or `None` once every chunk is exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        // Pull in chunks until one with at least one set bit is found
        while self.current_chunk == 0 {
            self.current_chunk = self.iter.next()?;
            self.chunk_end_offset += 64;
        }

        // Take the lowest set bit and clear it so that it is not yielded again
        let bit_pos = self.current_chunk.trailing_zeros();
        self.current_chunk ^= 1u64 << bit_pos;

        // `chunk_end_offset` points one past the chunk, hence the `- 64`
        Some(self.chunk_end_offset + (bit_pos as usize) - 64)
    }
}

/// Counts the number of set bits in `filter`
fn filter_count(filter: &BooleanArray) -> usize {
filter
.values()
Expand Down Expand Up @@ -180,37 +247,146 @@ pub fn prep_null_mask_filter(filter: &BooleanArray) -> BooleanArray {
/// # Ok(())
/// # }
/// ```
pub fn filter(array: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef> {
pub fn filter(values: &dyn Array, predicate: &BooleanArray) -> Result<ArrayRef> {
if predicate.null_count() > 0 {
// this greatly simplifies subsequent filtering code
// now we only have a boolean mask to deal with
let predicate = prep_null_mask_filter(predicate);
return filter(array, &predicate);
return filter(values, &predicate);
}

if predicate.len() > values.len() {
return Err(ArrowError::InvalidArgumentError(format!(
"Filter predicate of length {} is larger than target array of length {}",
predicate.len(),
values.len()
)));
}

let filter_count = filter_count(predicate);

match filter_count {
0 => {
// return empty
Ok(new_empty_array(array.data_type()))
Ok(new_empty_array(values.data_type()))
}
len if len == array.len() => {
len if len == values.len() => {
// return all
let data = array.data().clone();
let data = values.data().clone();
Ok(make_array(data))
}
_ => {
// actually filter
let mut mutable =
MutableArrayData::new(vec![array.data_ref()], false, filter_count);
// actually filter
_ => match values.data_type() {
DataType::Boolean => {
let values = values.as_any().downcast_ref::<BooleanArray>().unwrap();
Ok(Arc::new(filter_boolean(values, predicate, filter_count)))
}
DataType::Int8 => downcast_filter!(Int8Type, values, predicate, filter_count),
DataType::Int16 => {
downcast_filter!(Int16Type, values, predicate, filter_count)
}
DataType::Int32 => {
downcast_filter!(Int32Type, values, predicate, filter_count)
}
DataType::Int64 => {
downcast_filter!(Int64Type, values, predicate, filter_count)
}
DataType::UInt8 => {
downcast_filter!(UInt8Type, values, predicate, filter_count)
}
DataType::UInt16 => {
downcast_filter!(UInt16Type, values, predicate, filter_count)
}
DataType::UInt32 => {
downcast_filter!(UInt32Type, values, predicate, filter_count)
}
DataType::UInt64 => {
downcast_filter!(UInt64Type, values, predicate, filter_count)
}
DataType::Float32 => {
downcast_filter!(Float32Type, values, predicate, filter_count)
}
DataType::Float64 => {
downcast_filter!(Float64Type, values, predicate, filter_count)
}
DataType::Date32 => {
downcast_filter!(Date32Type, values, predicate, filter_count)
}
DataType::Date64 => {
downcast_filter!(Date64Type, values, predicate, filter_count)
}
DataType::Time32(Second) => {
downcast_filter!(Time32SecondType, values, predicate, filter_count)
}
DataType::Time32(Millisecond) => {
downcast_filter!(Time32MillisecondType, values, predicate, filter_count)
}
DataType::Time64(Microsecond) => {
downcast_filter!(Time64MicrosecondType, values, predicate, filter_count)
}
DataType::Time64(Nanosecond) => {
downcast_filter!(Time64NanosecondType, values, predicate, filter_count)
}
DataType::Timestamp(Second, _) => {
downcast_filter!(TimestampSecondType, values, predicate, filter_count)
}
DataType::Timestamp(Millisecond, _) => {
downcast_filter!(
TimestampMillisecondType,
values,
predicate,
filter_count
)
}
DataType::Timestamp(Microsecond, _) => {
downcast_filter!(
TimestampMicrosecondType,
values,
predicate,
filter_count
)
}
DataType::Timestamp(Nanosecond, _) => {
downcast_filter!(TimestampNanosecondType, values, predicate, filter_count)
}
DataType::Interval(IntervalUnit::YearMonth) => {
downcast_filter!(IntervalYearMonthType, values, predicate, filter_count)
}
DataType::Interval(IntervalUnit::DayTime) => {
downcast_filter!(IntervalDayTimeType, values, predicate, filter_count)
}
DataType::Interval(IntervalUnit::MonthDayNano) => {
downcast_filter!(
IntervalMonthDayNanoType,
values,
predicate,
filter_count
)
}
DataType::Duration(TimeUnit::Second) => {
downcast_filter!(DurationSecondType, values, predicate, filter_count)
}
DataType::Duration(TimeUnit::Millisecond) => {
downcast_filter!(DurationMillisecondType, values, predicate, filter_count)
}
DataType::Duration(TimeUnit::Microsecond) => {
downcast_filter!(DurationMicrosecondType, values, predicate, filter_count)
}
DataType::Duration(TimeUnit::Nanosecond) => {
downcast_filter!(DurationNanosecondType, values, predicate, filter_count)
}
_ => {
// fallback to using MutableArrayData
let mut mutable =
MutableArrayData::new(vec![values.data_ref()], false, filter_count);

let iter = SlicesIterator::new(predicate);
iter.for_each(|(start, end)| mutable.extend(0, start, end));
let iter = SlicesIterator::new(predicate);
iter.for_each(|(start, end)| mutable.extend(0, start, end));

let data = mutable.freeze();
Ok(make_array(data))
}
let data = mutable.freeze();
Ok(make_array(data))
}
},
}
}

Expand Down Expand Up @@ -244,6 +420,110 @@ pub fn filter_record_batch(
RecordBatch::try_new(record_batch.schema(), filtered_arrays)
}

/// Computes the null bitmap of `data` after `filter` has been applied.
///
/// Returns `None` when `data` has no nulls, has no null buffer, or when none of the
/// selected slots is null; otherwise returns the number of selected nulls together
/// with the filtered null bitmap.
fn filter_null_mask(
    data: &ArrayData,
    filter: &BooleanArray,
    filter_count: usize,
) -> Option<(usize, Buffer)> {
    if data.null_count() == 0 {
        return None;
    }

    let validity = data.null_buffer()?;
    let nulls = filter_bits(validity, data.offset(), filter, filter_count);

    // Every selected slot whose bit is unset in the filtered bitmap is a null
    match filter_count - nulls.count_set_bits() {
        0 => None,
        null_count => Some((null_count, nulls)),
    }
}

/// Packs the bits of `buffer` selected by `filter` into a new bitmap.
///
/// `offset` is added to every index produced by `filter` before `buffer` is read.
/// The result is sized for `filter_count` bits.
fn filter_bits(
    buffer: &Buffer,
    offset: usize,
    filter: &BooleanArray,
    filter_count: usize,
) -> Buffer {
    let src = buffer.as_slice();

    // TODO: Optimise this
    let byte_len = bit_util::ceil(filter_count, 8);
    let mut buf = MutableBuffer::from_len_zeroed(byte_len);
    let dst = buf.as_slice_mut();

    // The output starts zeroed, so only the set source bits need to be written
    IndexIterator::new(filter)
        .enumerate()
        .filter(|(_, src_idx)| bit_util::get_bit(src, *src_idx + offset))
        .for_each(|(dst_idx, _)| bit_util::set_bit(dst, dst_idx));

    buf.into()
}

/// `filter` kernel specialised for [`BooleanArray`] values
///
/// Both the value bitmap and, where present, the null bitmap are filtered bit by bit.
/// Panics if `values` does not have exactly one buffer and no child data.
fn filter_boolean(
    values: &BooleanArray,
    filter: &BooleanArray,
    filter_count: usize,
) -> BooleanArray {
    let array_data = values.data();
    assert_eq!(array_data.buffers().len(), 1);
    assert_eq!(array_data.child_data().len(), 0);

    let filtered_values = filter_bits(
        &array_data.buffers()[0],
        array_data.offset(),
        filter,
        filter_count,
    );

    let builder = ArrayDataBuilder::new(DataType::Boolean)
        .len(filter_count)
        .add_buffer(filtered_values);

    // Only attach a null bitmap when at least one selected slot is null
    let builder = match filter_null_mask(array_data, filter, filter_count) {
        Some((null_count, nulls)) => {
            builder.null_count(null_count).null_bit_buffer(nulls)
        }
        None => builder,
    };

    // TODO: unsafe { builder.build_unchecked() };
    BooleanArray::from(builder.build().unwrap())
}

/// `filter` kernel specialised for [`PrimitiveArray`] values
///
/// When more than 0.8 of the filter's slots are selected the values are copied run by
/// run using [`SlicesIterator`]; otherwise they are copied one at a time using
/// [`IndexIterator`].
///
/// Panics if `values` does not have exactly one buffer and no child data.
fn filter_primitive<T>(
    values: &PrimitiveArray<T>,
    filter: &BooleanArray,
    filter_count: usize,
) -> PrimitiveArray<T>
where
    T: ArrowPrimitiveType,
{
    let array_data = values.data();
    assert_eq!(array_data.buffers().len(), 1);
    assert_eq!(array_data.child_data().len(), 0);

    let source = array_data.buffer::<T::Native>(0);

    let mut buffer = MutableBuffer::with_capacity(filter_count * T::get_byte_width());

    let selectivity_frac = filter_count as f64 / filter.len() as f64;
    if selectivity_frac > 0.8 {
        SlicesIterator::new(filter)
            .for_each(|(start, end)| buffer.extend_from_slice(&source[start..end]));
    } else {
        IndexIterator::new(filter).for_each(|idx| {
            // SAFETY: capacity for `filter_count` values was reserved above; this relies
            // on the caller passing the number of set bits in `filter` as `filter_count`
            unsafe { buffer.push_unchecked(source[idx]) }
        });
    }

    let builder = ArrayDataBuilder::new(array_data.data_type().clone())
        .len(filter_count)
        .add_buffer(buffer.into());

    // Only attach a null bitmap when at least one selected slot is null
    let builder = match filter_null_mask(array_data, filter, filter_count) {
        Some((null_count, nulls)) => {
            builder.null_count(null_count).null_bit_buffer(nulls)
        }
        None => builder,
    };

    // SAFETY: validation is skipped; the single values buffer and optional null bitmap
    // were both built above for `filter_count` slots of the source array's data type
    let filtered = unsafe { builder.build_unchecked() };
    PrimitiveArray::from(filtered)
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -649,12 +929,14 @@ mod tests {
.build()
.unwrap();

let bool_array = BooleanArray::from(data);
let filter = BooleanArray::from(data);

let bits: Vec<_> = SlicesIterator::new(&bool_array)
let slice_bits: Vec<_> = SlicesIterator::new(&filter)
.flat_map(|(start, end)| start..end)
.collect();

let index_bits: Vec<_> = IndexIterator::new(&filter).collect();

let expected_bits: Vec<_> = bools
.iter()
.skip(offset)
Expand All @@ -663,7 +945,8 @@ mod tests {
.flat_map(|(idx, v)| v.then(|| idx))
.collect();

assert_eq!(bits, expected_bits);
assert_eq!(slice_bits, expected_bits);
assert_eq!(index_bits, expected_bits);
}

#[test]
Expand Down

0 comments on commit 8ed556e

Please sign in to comment.