diff --git a/Cargo.lock b/Cargo.lock index 6147b39c41..750a11e768 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5173,6 +5173,7 @@ dependencies = [ name = "vortex-fuzz" version = "0.21.1" dependencies = [ + "arrow-buffer", "libfuzzer-sys", "vortex-array", "vortex-buffer", diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml index 4460e73d9b..925e178d62 100644 --- a/fuzz/Cargo.toml +++ b/fuzz/Cargo.toml @@ -19,6 +19,7 @@ cargo-fuzz = true [dependencies] libfuzzer-sys = { workspace = true } +arrow-buffer = { workspace = true } vortex-array = { workspace = true, features = ["arbitrary"] } vortex-buffer = { workspace = true } vortex-dtype = { workspace = true, features = ["arbitrary"] } diff --git a/fuzz/fuzz_targets/array_ops.rs b/fuzz/fuzz_targets/array_ops.rs index 283f21ca0c..d71fee4be1 100644 --- a/fuzz/fuzz_targets/array_ops.rs +++ b/fuzz/fuzz_targets/array_ops.rs @@ -3,7 +3,8 @@ use libfuzzer_sys::{fuzz_target, Corpus}; use vortex_array::aliases::hash_set::HashSet; use vortex_array::array::{ - BoolEncoding, PrimitiveEncoding, StructEncoding, VarBinEncoding, VarBinViewEncoding, + BoolEncoding, ListEncoding, PrimitiveEncoding, StructEncoding, VarBinEncoding, + VarBinViewEncoding, }; use vortex_array::compute::{ filter, scalar_at, search_sorted, slice, take, SearchResult, SearchSortedSide, @@ -48,6 +49,7 @@ fuzz_target!(|fuzz_action: FuzzArrayAction| -> Corpus { &VarBinViewEncoding, &BoolEncoding, &StructEncoding, + &ListEncoding, ]) .contains(&current_array.encoding()) { diff --git a/fuzz/src/filter.rs b/fuzz/src/filter.rs index 97eb0c3f79..9c2bb87ba8 100644 --- a/fuzz/src/filter.rs +++ b/fuzz/src/filter.rs @@ -7,6 +7,8 @@ use vortex_buffer::Buffer; use vortex_dtype::{match_each_native_ptype, DType}; use vortex_error::VortexExpect; +use crate::take::take_canonical_array; + pub fn filter_canonical_array(array: &ArrayData, filter: &[bool]) -> ArrayData { let validity = if array.dtype().is_nullable() { let validity_buff = array @@ -83,6 +85,15 @@ pub fn 
filter_canonical_array(array: &ArrayData, filter: &[bool]) -> ArrayData { .unwrap() .into_array() } + DType::List(..) => { + let mut indices = Vec::new(); + for (idx, bool) in filter.iter().enumerate() { + if *bool { + indices.push(idx); + } + } + take_canonical_array(array, &indices) + } _ => unreachable!("Not a canonical array"), } } diff --git a/fuzz/src/lib.rs b/fuzz/src/lib.rs index dd0ed3a78f..9b9ad229fe 100644 --- a/fuzz/src/lib.rs +++ b/fuzz/src/lib.rs @@ -6,12 +6,15 @@ mod take; use std::fmt::Debug; use std::iter; -use std::ops::Range; +use std::ops::{Range, RangeInclusive}; use libfuzzer_sys::arbitrary::Error::EmptyChoose; use libfuzzer_sys::arbitrary::{Arbitrary, Result, Unstructured}; pub use sort::sort_canonical_array; +use vortex_array::aliases::hash_set::HashSet; +use vortex_array::array::ListEncoding; use vortex_array::compute::{scalar_at, FilterMask, SearchResult, SearchSortedSide}; +use vortex_array::encoding::{Encoding, EncodingRef}; use vortex_array::{ArrayDType, ArrayData, IntoArrayData}; use vortex_buffer::Buffer; use vortex_sampling_compressor::SamplingCompressor; @@ -64,10 +67,13 @@ impl<'a> Arbitrary<'a> for FuzzArrayAction { fn arbitrary(u: &mut Unstructured<'a>) -> Result<Self> { let array = ArrayData::arbitrary(u)?; let mut current_array = array.clone(); + + let valid_actions = actions_for_array(&current_array); + let mut actions = Vec::new(); let action_count = u.int_in_range(1..=4)?; for _ in 0..action_count { - actions.push(match u.int_in_range(0..=4)? { + actions.push(match random_value_from_list(u, valid_actions.as_slice())? 
{ 0 => { if actions .last() @@ -164,3 +170,28 @@ fn random_vec_in_range(u: &mut Unstructured<'_>, min: usize, max: usize) -> Resu }) .collect::<Result<Vec<_>>>() } + +fn random_value_from_list(u: &mut Unstructured<'_>, vec: &[usize]) -> Result<usize> { + u.choose_iter(vec).cloned() +} + +const ALL_ACTIONS: RangeInclusive<usize> = 0..=4; + +fn actions_for_encoding(encoding: EncodingRef) -> HashSet<usize> { + if ListEncoding::ID == encoding.id() { + // compress, slice and filter + vec![0, 1, 4].into_iter().collect() + } else { + ALL_ACTIONS.collect() + } +} + +fn actions_for_array(array: &ArrayData) -> Vec<usize> { + array + .depth_first_traversal() + .map(|child| actions_for_encoding(child.encoding())) + .fold(ALL_ACTIONS.collect::<Vec<_>>(), |mut acc, actions| { + acc.retain(|a| actions.contains(a)); + acc + }) +} diff --git a/fuzz/src/search_sorted.rs b/fuzz/src/search_sorted.rs index 8c0f11fae7..24fdd7154e 100644 --- a/fuzz/src/search_sorted.rs +++ b/fuzz/src/search_sorted.rs @@ -121,6 +121,12 @@ pub fn search_sorted_canonical_array( .collect::<Vec<_>>(); scalar_vals.search_sorted(&scalar.cast(array.dtype()).unwrap(), side) } + DType::List(..) 
=> { + let scalar_vals = (0..array.len()) + .map(|i| scalar_at(array, i).unwrap()) + .collect::>(); + scalar_vals.search_sorted(&scalar.cast(array.dtype()).unwrap(), side) + } _ => unreachable!("Not a canonical array"), } } diff --git a/fuzz/src/slice.rs b/fuzz/src/slice.rs index 23d18609bc..1b46952bfc 100644 --- a/fuzz/src/slice.rs +++ b/fuzz/src/slice.rs @@ -1,9 +1,10 @@ +use arrow_buffer::ArrowNativeType; use vortex_array::accessor::ArrayAccessor; -use vortex_array::array::{BoolArray, PrimitiveArray, StructArray, VarBinViewArray}; +use vortex_array::array::{BoolArray, ListArray, PrimitiveArray, StructArray, VarBinViewArray}; use vortex_array::validity::{ArrayValidity, Validity}; -use vortex_array::variants::StructArrayTrait; -use vortex_array::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant}; -use vortex_dtype::{match_each_native_ptype, DType}; +use vortex_array::variants::{PrimitiveArrayTrait, StructArrayTrait}; +use vortex_array::{ArrayDType, ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant}; +use vortex_dtype::{match_each_native_ptype, DType, NativePType}; use vortex_error::VortexExpect; pub fn slice_canonical_array(array: &ArrayData, start: usize, stop: usize) -> ArrayData { @@ -28,10 +29,12 @@ pub fn slice_canonical_array(array: &ArrayData, start: usize, stop: usize) -> Ar .vortex_expect("Validity length cannot mismatch") .into_array() } - DType::Primitive(p, _) => match_each_native_ptype!(p, |$P| { + DType::Primitive(p, _) => { let primitive_array = array.clone().into_primitive().unwrap(); - PrimitiveArray::new(primitive_array.buffer::<$P>().slice(start..stop), validity).into_array() - }), + match_each_native_ptype!(p, |$P| { + PrimitiveArray::new(primitive_array.buffer::<$P>().slice(start..stop), validity).into_array() + }) + } DType::Utf8(_) | DType::Binary(_) => { let utf8 = array.clone().into_varbinview().unwrap(); let values = utf8 @@ -55,6 +58,34 @@ pub fn slice_canonical_array(array: &ArrayData, start: usize, stop: usize) -> Ar .unwrap() 
.into_array() } + DType::List(..) => { + let list_array = array.clone().into_list().unwrap(); + let offsets = slice_canonical_array(&list_array.offsets(), start, stop + 1) + .into_primitive() + .unwrap(); + + let elements = slice_canonical_array( + &list_array.elements(), + offsets.get_as_cast::(0) as usize, + offsets.get_as_cast::(offsets.len() - 1) as usize, + ); + let offsets = match_each_native_ptype!(offsets.ptype(), |$P| { + shift_offsets::<$P>(offsets) + }) + .into_array(); + ListArray::try_new(elements, offsets, validity) + .unwrap() + .into_array() + } _ => unreachable!("Not a canonical array"), } } + +fn shift_offsets(offsets: PrimitiveArray) -> PrimitiveArray { + if offsets.is_empty() { + return offsets; + } + let offsets: Vec = offsets.as_slice().to_vec(); + let start = offsets[0]; + PrimitiveArray::from_iter(offsets.into_iter().map(|o| o - start).collect::>()) +} diff --git a/fuzz/src/sort.rs b/fuzz/src/sort.rs index a8bfc82a9b..50c0195216 100644 --- a/fuzz/src/sort.rs +++ b/fuzz/src/sort.rs @@ -70,7 +70,17 @@ pub fn sort_canonical_array(array: &ArrayData) -> ArrayData { }); take_canonical_array(array, &sort_indices) } - _ => unreachable!("Not a canonical array"), + DType::List(..) 
=> { + let mut sort_indices = (0..array.len()).collect::>(); + sort_indices.sort_by(|a, b| { + scalar_at(array, *a) + .unwrap() + .partial_cmp(&scalar_at(array, *b).unwrap()) + .unwrap() + }); + take_canonical_array(array, &sort_indices) + } + a => unreachable!("Not a canonical array {:?}", a), } } diff --git a/fuzz/src/take.rs b/fuzz/src/take.rs index 256574fd69..080a989796 100644 --- a/fuzz/src/take.rs +++ b/fuzz/src/take.rs @@ -1,10 +1,13 @@ +use arrow_buffer::ArrowNativeType; use vortex_array::accessor::ArrayAccessor; use vortex_array::array::{BoolArray, PrimitiveArray, StructArray, VarBinViewArray}; +use vortex_array::builders::{builder_with_capacity, ArrayBuilderExt}; +use vortex_array::compute::scalar_at; use vortex_array::validity::{ArrayValidity, Validity}; use vortex_array::variants::StructArrayTrait; use vortex_array::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant}; use vortex_buffer::Buffer; -use vortex_dtype::{match_each_native_ptype, DType}; +use vortex_dtype::{match_each_native_ptype, DType, NativePType}; use vortex_error::VortexExpect; pub fn take_canonical_array(array: &ArrayData, indices: &[usize]) -> ArrayData { @@ -31,16 +34,12 @@ pub fn take_canonical_array(array: &ArrayData, indices: &[usize]) -> ArrayData { .vortex_expect("Validity length cannot mismatch") .into_array() } - DType::Primitive(p, _) => match_each_native_ptype!(p, |$P| { + DType::Primitive(p, _) => { let primitive_array = array.clone().into_primitive().unwrap(); - let vec_values = primitive_array - .as_slice::<$P>() - .iter() - .copied() - .collect::>(); - PrimitiveArray::new(indices.iter().map(|i| vec_values[*i]).collect::>(), validity) - .into_array() - }), + match_each_native_ptype!(p, |$P| { + take_primitive::<$P>(primitive_array, validity, indices) + }) + } DType::Utf8(_) | DType::Binary(_) => { let utf8 = array.clone().into_varbinview().unwrap(); let values = utf8 @@ -68,6 +67,31 @@ pub fn take_canonical_array(array: &ArrayData, indices: &[usize]) -> ArrayData { 
.unwrap() .into_array() } + DType::List(..) => { + let mut builder = builder_with_capacity(array.dtype(), indices.len()); + for idx in indices { + builder + .append_scalar(&scalar_at(array, *idx).unwrap()) + .unwrap(); + } + builder.finish().unwrap() + } _ => unreachable!("Not a canonical array"), } } + +fn take_primitive( + primitive_array: PrimitiveArray, + validity: Validity, + indices: &[usize], +) -> ArrayData { + let vec_values = primitive_array.as_slice::().to_vec(); + PrimitiveArray::new( + indices + .iter() + .map(|i| vec_values[*i]) + .collect::>(), + validity, + ) + .into_array() +} diff --git a/vortex-array/src/array/arbitrary.rs b/vortex-array/src/array/arbitrary.rs index 462e4a51bc..60f6d081b7 100644 --- a/vortex-array/src/array/arbitrary.rs +++ b/vortex-array/src/array/arbitrary.rs @@ -88,7 +88,7 @@ fn random_array(u: &mut Unstructured, dtype: &DType, len: Option) -> Resu .vortex_unwrap() .into_array()) } - DType::List(ldt, n) => random_list(u, ldt, n), + DType::List(ldt, n) => random_list(u, ldt, n, chunk_len), DType::Extension(..) => { todo!("Extension arrays are not implemented") } @@ -106,14 +106,19 @@ fn random_array(u: &mut Unstructured, dtype: &DType, len: Option) -> Resu } } -fn random_list(u: &mut Unstructured, ldt: &Arc, n: &Nullability) -> Result { +fn random_list( + u: &mut Unstructured, + ldt: &Arc, + n: &Nullability, + chunk_len: Option, +) -> Result { match u.int_in_range(0..=5)? 
{ - 0 => random_list_offset::(u, ldt, n), - 1 => random_list_offset::(u, ldt, n), - 2 => random_list_offset::(u, ldt, n), - 3 => random_list_offset::(u, ldt, n), - 4 => random_list_offset::(u, ldt, n), - 5 => random_list_offset::(u, ldt, n), + 0 => random_list_offset::(u, ldt, n, chunk_len), + 1 => random_list_offset::(u, ldt, n, chunk_len), + 2 => random_list_offset::(u, ldt, n, chunk_len), + 3 => random_list_offset::(u, ldt, n, chunk_len), + 4 => random_list_offset::(u, ldt, n, chunk_len), + 5 => random_list_offset::(u, ldt, n, chunk_len), _ => unreachable!("int_in_range returns a value in the above range"), } } @@ -122,14 +127,15 @@ fn random_list_offset( u: &mut Unstructured, ldt: &Arc, n: &Nullability, + chunk_len: Option, ) -> Result where O: PrimInt + NativePType, Scalar: From, usize: AsPrimitive, { - let list_len = u.int_in_range(0..=20)?; - let mut builder = ListBuilder::::with_capacity(ldt.clone(), *n, 1); + let list_len = chunk_len.unwrap_or(u.int_in_range(0..=20)?); + let mut builder = ListBuilder::::with_capacity(ldt.clone(), *n, 10); for _ in 0..list_len { if matches!(n, Nullability::Nullable) || u.arbitrary::()? 
{ let elem_len = u.int_in_range(0..=20)?; diff --git a/vortex-array/src/array/list/mod.rs b/vortex-array/src/array/list/mod.rs index 04df958da5..14255e541d 100644 --- a/vortex-array/src/array/list/mod.rs +++ b/vortex-array/src/array/list/mod.rs @@ -225,6 +225,7 @@ impl ListArray { mod test { use std::sync::Arc; + use arrow_buffer::BooleanBuffer; use vortex_dtype::Nullability; use vortex_dtype::Nullability::NonNullable; use vortex_dtype::PType::I32; @@ -232,7 +233,7 @@ mod test { use crate::array::list::ListArray; use crate::array::PrimitiveArray; - use crate::compute::scalar_at; + use crate::compute::{filter, scalar_at, FilterMask}; use crate::validity::Validity; use crate::{ArrayLen, IntoArrayData}; @@ -301,4 +302,22 @@ mod test { scalar_at(&list_from_iter, 1).unwrap() ); } + + #[test] + fn test_simple_list_filter() { + let elements = PrimitiveArray::from_option_iter([None, Some(2), Some(3), Some(4), Some(5)]); + let offsets = PrimitiveArray::from_iter([0, 2, 4, 5]); + let validity = Validity::AllValid; + + let list = ListArray::try_new(elements.into_array(), offsets.into_array(), validity) + .unwrap() + .into_array(); + + let filtered = filter( + &list, + FilterMask::from(BooleanBuffer::from(vec![false, true, true])), + ); + + assert!(filtered.is_ok()) + } } diff --git a/vortex-array/src/arrow/array.rs b/vortex-array/src/arrow/array.rs index 77b6b1534a..e46d920ac6 100644 --- a/vortex-array/src/arrow/array.rs +++ b/vortex-array/src/arrow/array.rs @@ -188,8 +188,14 @@ impl FromArrowArray<&ArrowStructArray> for ArrayData { impl FromArrowArray<&GenericListArray> for ArrayData { fn from_arrow(value: &GenericListArray, nullable: bool) -> Self { + // Extract the validity of the underlying element array + let elem_nullable = match value.data_type() { + DataType::List(field) => field.is_nullable(), + DataType::LargeList(field) => field.is_nullable(), + dt => vortex_panic!("Invalid data type for ListArray: {dt}"), + }; ListArray::try_new( - 
Self::from_arrow(value.values().clone(), value.values().is_nullable()), + Self::from_arrow(value.values().clone(), elem_nullable), // offsets are always non-nullable ArrayData::from(value.offsets().clone()), nulls(value.nulls(), nullable), diff --git a/vortex-dtype/src/arbitrary.rs b/vortex-dtype/src/arbitrary.rs index 161141dae6..98095d499d 100644 --- a/vortex-dtype/src/arbitrary.rs +++ b/vortex-dtype/src/arbitrary.rs @@ -11,15 +11,21 @@ impl<'a> Arbitrary<'a> for DType { } fn random_dtype(u: &mut Unstructured<'_>, depth: u8) -> Result { - let max_dtype_kind = if depth == 0 { 3 } else { 4 }; - Ok(match u.int_in_range(0..=max_dtype_kind)? { - 0 => DType::Bool(u.arbitrary()?), - 1 => DType::Primitive(u.arbitrary()?, u.arbitrary()?), - 2 => DType::Utf8(u.arbitrary()?), - 3 => DType::Binary(u.arbitrary()?), - 4 => DType::Struct(random_struct_dtype(u, depth - 1)?, u.arbitrary()?), + const BASE_TYPE_COUNT: i32 = 4; + const CONTAINER_TYPE_COUNT: i32 = 2; + let max_dtype_kind = if depth == 0 { + BASE_TYPE_COUNT + } else { + CONTAINER_TYPE_COUNT + BASE_TYPE_COUNT + }; + Ok(match u.int_in_range(1..=max_dtype_kind)? { + 1 => DType::Bool(u.arbitrary()?), + 2 => DType::Primitive(u.arbitrary()?, u.arbitrary()?), + 3 => DType::Utf8(u.arbitrary()?), + 4 => DType::Binary(u.arbitrary()?), + 5 => DType::Struct(random_struct_dtype(u, depth - 1)?, u.arbitrary()?), + 6 => DType::List(Arc::new(random_dtype(u, depth - 1)?), u.arbitrary()?), // Null, - // List(Arc, Nullability), // Extension(ExtDType, Nullability), _ => unreachable!("Number out of range"), })