diff --git a/vortex-array/src/arrays/list/compute/take.rs b/vortex-array/src/arrays/list/compute/take.rs index 31dbc60a864..040400977e6 100644 --- a/vortex-array/src/arrays/list/compute/take.rs +++ b/vortex-array/src/arrays/list/compute/take.rs @@ -2,31 +2,41 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use vortex_buffer::BitBufferMut; -use vortex_dtype::{IntegerPType, Nullability}; +use vortex_dtype::{IntegerPType, Nullability, match_each_integer_ptype}; use vortex_error::{VortexExpect, VortexResult, vortex_panic}; use vortex_mask::Mask; -use crate::arrays::{ListArray, ListVTable, PrimitiveArray, list_view_from_list}; +use crate::arrays::{ListArray, ListVTable, PrimitiveArray}; use crate::builders::{ArrayBuilder, PrimitiveBuilder}; -use crate::compute::{self, TakeKernel, TakeKernelAdapter}; +use crate::compute::{TakeKernel, TakeKernelAdapter, take}; use crate::validity::Validity; use crate::vtable::ValidityHelper; -use crate::{Array, ArrayRef, IntoArray, register_kernel}; +use crate::{Array, ArrayRef, ToCanonical, register_kernel}; + +// TODO(connor)[ListView]: Re-revert to the version where we simply convert to a `ListView` and call +// the `ListView::take` compute function once `ListView` is more stable. -// TODO(connor): For very short arrays it is probably more efficient to build the list from scratch. /// Take implementation for [`ListArray`]. /// -/// This implementation converts the [`ListArray`] to a [`ListViewArray`] and then delegates to its -/// `take` implementation. This approach avoids the need to rebuild the `elements` array. -/// -/// The resulting [`ListViewArray`] can represent non-contiguous and out-of-order lists, which would -/// violate [`ListArray`]'s invariants (but not [`ListViewArray`]'s). -/// -/// [`ListViewArray`]: crate::arrays::ListViewArray +/// Unlike `ListView`, `ListArray` must rebuild the elements array to maintain its invariant +/// that lists are stored contiguously and in-order (`offset[i+1] >= offset[i]`). Taking +/// non-contiguous indices would violate this requirement. impl TakeKernel for ListVTable { fn take(&self, array: &ListArray, indices: &dyn Array) -> VortexResult { - let list_view = list_view_from_list(array.clone()); - compute::take(&list_view.into_array(), indices) + let indices = indices.to_primitive(); + let offsets = array.offsets().to_primitive(); + + match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| { + match_each_integer_ptype!(indices.ptype(), |I| { + _take::( + array, + offsets.as_slice::(), + &indices, + array.validity_mask(), + indices.validity_mask(), + ) + }) + }) } } @@ -86,7 +96,7 @@ fn _take( let elements_to_take = elements_to_take.finish(); let new_offsets = new_offsets.finish(); - let new_elements = compute::take(array.elements(), elements_to_take.as_ref())?; + let new_elements = take(array.elements(), elements_to_take.as_ref())?; Ok(ListArray::try_new( new_elements, @@ -121,12 +131,13 @@ fn _take_nullable( let mut current_offset = O::zero(); new_offsets.append_zero(); - let mut new_validity = BitBufferMut::with_capacity(indices.len()); + // Set all bits to invalid and selectively set which values are valid. + let mut new_validity = BitBufferMut::new_unset(indices.len()); for (idx, data_idx) in indices.iter().enumerate() { if !indices_validity.value(idx) { new_offsets.append_value(current_offset); - new_validity.append_false(); + // Bit buffer already has this set to invalid. continue; } @@ -134,34 +145,34 @@ fn _take_nullable( .to_usize() .unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx)); - if data_validity.value(data_idx) { - let start = offsets[data_idx]; - let stop = offsets[data_idx + 1]; - - // See the note it the `take` on the reasoning - let additional = (stop - start).to_usize().unwrap_or_else(|| { - vortex_panic!("Failed to convert range length to usize: {}", stop - start) - }); - - elements_to_take.reserve_exact(additional); - for i in 0..additional { - elements_to_take - .append_value(start + O::from_usize(i).vortex_expect("i < additional")); - } - current_offset += stop - start; + if !data_validity.value(data_idx) { new_offsets.append_value(current_offset); - new_validity.append_true() - } else { - new_offsets.append_value(current_offset); - new_validity.append_false(); + // Bit buffer already has this set to invalid. + continue; } + + let start = offsets[data_idx]; + let stop = offsets[data_idx + 1]; + + // See the note it the `take` on the reasoning + let additional = (stop - start).to_usize().unwrap_or_else(|| { + vortex_panic!("Failed to convert range length to usize: {}", stop - start) + }); + + elements_to_take.reserve_exact(additional); + for i in 0..additional { + elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional")); + } + current_offset += stop - start; + new_offsets.append_value(current_offset); + new_validity.set(idx); } let elements_to_take = elements_to_take.finish(); let new_offsets = new_offsets.finish(); - let new_elements = compute::take(array.elements(), elements_to_take.as_ref())?; + let new_elements = take(array.elements(), elements_to_take.as_ref())?; - let new_validity: Validity = Validity::from(new_validity.freeze()); + let new_validity = Validity::from(new_validity.freeze()); // data are indexes are nullable, so the final result is also nullable. Ok(ListArray::try_new(new_elements, new_offsets, new_validity)?.to_array())