diff --git a/encodings/sequence/src/compute/filter.rs b/encodings/sequence/src/compute/filter.rs new file mode 100644 index 00000000000..757ce093807 --- /dev/null +++ b/encodings/sequence/src/compute/filter.rs @@ -0,0 +1,68 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright the Vortex contributors + +use vortex_array::arrays::PrimitiveArray; +use vortex_array::compute::{FilterKernel, FilterKernelAdapter}; +use vortex_array::validity::Validity; +use vortex_array::{ArrayRef, IntoArray, register_kernel}; +use vortex_buffer::BufferMut; +use vortex_dtype::{NativePType, match_each_native_ptype}; +use vortex_error::{VortexExpect, VortexResult}; +use vortex_mask::{AllOr, Mask}; + +use crate::{SequenceArray, SequenceVTable}; + +impl FilterKernel for SequenceVTable { + fn filter(&self, array: &SequenceArray, selection_mask: &Mask) -> VortexResult { + let validity = Validity::from(array.dtype().nullability()); + match_each_native_ptype!(array.ptype(), |P| { + let mul = array.multiplier().as_primitive::

(); + let base = array.base().as_primitive::

(); + Ok(filter_impl(mul, base, selection_mask, validity)) + }) + } +} + +register_kernel!(FilterKernelAdapter(SequenceVTable).lift()); + +fn filter_impl(mul: T, base: T, mask: &Mask, validity: Validity) -> ArrayRef { + match mask.boolean_buffer() { + AllOr::All | AllOr::None => unreachable!("Handled by entrypoint function"), + AllOr::Some(mask) => { + let mut buffer = BufferMut::::with_capacity(mask.count_set_bits()); + buffer.extend(mask.set_indices().map(|idx| { + let i = T::from_usize(idx).vortex_expect("all valid indices fit"); + base + i * mul + })); + PrimitiveArray::new(buffer.freeze(), validity).into_array() + } + } +} + +#[cfg(test)] +mod tests { + use rstest::rstest; + use vortex_array::compute::conformance::filter::{ + LARGE_SIZE, MEDIUM_SIZE, test_filter_conformance, + }; + use vortex_dtype::Nullability; + + use crate::SequenceArray; + + #[rstest] + #[case(SequenceArray::typed_new(0i32, 1, Nullability::NonNullable, 5).unwrap())] + #[case(SequenceArray::typed_new(10i32, 2, Nullability::NonNullable, 5).unwrap())] + #[case(SequenceArray::typed_new(100i32, -3, Nullability::NonNullable, 5).unwrap())] + #[case(SequenceArray::typed_new(0i32, 1, Nullability::NonNullable, 1).unwrap())] + #[case(SequenceArray::typed_new(0i32, 1, Nullability::NonNullable, MEDIUM_SIZE).unwrap())] + #[case(SequenceArray::typed_new(0i32, 1, Nullability::NonNullable, LARGE_SIZE).unwrap())] + #[case(SequenceArray::typed_new(0i64, 1, Nullability::NonNullable, 5).unwrap())] + #[case(SequenceArray::typed_new(1000i64, 50, Nullability::NonNullable, 5).unwrap())] + #[case(SequenceArray::typed_new(-100i64, 10, Nullability::NonNullable, MEDIUM_SIZE).unwrap())] + #[case(SequenceArray::typed_new(0u32, 1, Nullability::NonNullable, 5).unwrap())] + #[case(SequenceArray::typed_new(0u32, 5, Nullability::NonNullable, MEDIUM_SIZE).unwrap())] + #[case(SequenceArray::typed_new(0u64, 1, Nullability::NonNullable, LARGE_SIZE).unwrap())] + fn test_filter_sequence_conformance(#[case] array: SequenceArray) { + test_filter_conformance(array.as_ref()); + } +} diff --git a/encodings/sequence/src/compute/mod.rs b/encodings/sequence/src/compute/mod.rs index 3a9adf42bf5..b3d35a88b20 100644 --- a/encodings/sequence/src/compute/mod.rs +++ b/encodings/sequence/src/compute/mod.rs @@ -3,6 +3,7 @@ mod cast; mod compare; +mod filter; mod is_sorted; mod list_contains; mod min_max;