Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 49 additions & 38 deletions vortex-array/src/arrays/list/compute/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,41 @@
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use vortex_buffer::BitBufferMut;
use vortex_dtype::{IntegerPType, Nullability};
use vortex_dtype::{IntegerPType, Nullability, match_each_integer_ptype};
use vortex_error::{VortexExpect, VortexResult, vortex_panic};
use vortex_mask::Mask;

use crate::arrays::{ListArray, ListVTable, PrimitiveArray, list_view_from_list};
use crate::arrays::{ListArray, ListVTable, PrimitiveArray};
use crate::builders::{ArrayBuilder, PrimitiveBuilder};
use crate::compute::{self, TakeKernel, TakeKernelAdapter};
use crate::compute::{TakeKernel, TakeKernelAdapter, take};
use crate::validity::Validity;
use crate::vtable::ValidityHelper;
use crate::{Array, ArrayRef, IntoArray, register_kernel};
use crate::{Array, ArrayRef, ToCanonical, register_kernel};

// TODO(connor)[ListView]: Re-revert to the version where we simply convert to a `ListView` and call
// the `ListView::take` compute function once `ListView` is more stable.

// TODO(connor): For very short arrays it is probably more efficient to build the list from scratch.
/// Take implementation for [`ListArray`].
///
/// This implementation converts the [`ListArray`] to a [`ListViewArray`] and then delegates to its
/// `take` implementation. This approach avoids the need to rebuild the `elements` array.
///
/// The resulting [`ListViewArray`] can represent non-contiguous and out-of-order lists, which would
/// violate [`ListArray`]'s invariants (but not [`ListViewArray`]'s).
///
/// [`ListViewArray`]: crate::arrays::ListViewArray
/// Unlike `ListView`, `ListArray` must rebuild the elements array to maintain its invariant
/// that lists are stored contiguously and in-order (`offset[i+1] >= offset[i]`). Taking
/// non-contiguous indices would violate this requirement.
impl TakeKernel for ListVTable {
fn take(&self, array: &ListArray, indices: &dyn Array) -> VortexResult<ArrayRef> {
let list_view = list_view_from_list(array.clone());
compute::take(&list_view.into_array(), indices)
let indices = indices.to_primitive();
let offsets = array.offsets().to_primitive();

match_each_integer_ptype!(offsets.dtype().as_ptype(), |O| {
match_each_integer_ptype!(indices.ptype(), |I| {
_take::<I, O>(
array,
offsets.as_slice::<O>(),
&indices,
array.validity_mask(),
indices.validity_mask(),
)
})
})
}
}

Expand Down Expand Up @@ -86,7 +96,7 @@ fn _take<I: IntegerPType, O: IntegerPType>(
let elements_to_take = elements_to_take.finish();
let new_offsets = new_offsets.finish();

let new_elements = compute::take(array.elements(), elements_to_take.as_ref())?;
let new_elements = take(array.elements(), elements_to_take.as_ref())?;

Ok(ListArray::try_new(
new_elements,
Expand Down Expand Up @@ -121,47 +131,48 @@ fn _take_nullable<I: IntegerPType, O: IntegerPType>(
let mut current_offset = O::zero();
new_offsets.append_zero();

let mut new_validity = BitBufferMut::with_capacity(indices.len());
// Set all bits to invalid and selectively set which values are valid.
let mut new_validity = BitBufferMut::new_unset(indices.len());

for (idx, data_idx) in indices.iter().enumerate() {
if !indices_validity.value(idx) {
new_offsets.append_value(current_offset);
new_validity.append_false();
// Bit buffer already has this set to invalid.
continue;
}

let data_idx = data_idx
.to_usize()
.unwrap_or_else(|| vortex_panic!("Failed to convert index to usize: {}", data_idx));

if data_validity.value(data_idx) {
let start = offsets[data_idx];
let stop = offsets[data_idx + 1];

// See the note it the `take` on the reasoning
let additional = (stop - start).to_usize().unwrap_or_else(|| {
vortex_panic!("Failed to convert range length to usize: {}", stop - start)
});

elements_to_take.reserve_exact(additional);
for i in 0..additional {
elements_to_take
.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
}
current_offset += stop - start;
if !data_validity.value(data_idx) {
new_offsets.append_value(current_offset);
new_validity.append_true()
} else {
new_offsets.append_value(current_offset);
new_validity.append_false();
// Bit buffer already has this set to invalid.
continue;
}

let start = offsets[data_idx];
let stop = offsets[data_idx + 1];

// See the note it the `take` on the reasoning
let additional = (stop - start).to_usize().unwrap_or_else(|| {
vortex_panic!("Failed to convert range length to usize: {}", stop - start)
});

elements_to_take.reserve_exact(additional);
for i in 0..additional {
elements_to_take.append_value(start + O::from_usize(i).vortex_expect("i < additional"));
}
current_offset += stop - start;
new_offsets.append_value(current_offset);
new_validity.set(idx);
}

let elements_to_take = elements_to_take.finish();
let new_offsets = new_offsets.finish();
let new_elements = compute::take(array.elements(), elements_to_take.as_ref())?;
let new_elements = take(array.elements(), elements_to_take.as_ref())?;

let new_validity: Validity = Validity::from(new_validity.freeze());
let new_validity = Validity::from(new_validity.freeze());
// data are indexes are nullable, so the final result is also nullable.

Ok(ListArray::try_new(new_elements, new_offsets, new_validity)?.to_array())
Expand Down
Loading