Skip to content

Commit

Permalink
Merge 541b5ad into f1a831f
Browse files Browse the repository at this point in the history
  • Loading branch information
jhorstmann authored Jun 30, 2021
2 parents f1a831f + 541b5ad commit 2423fad
Showing 1 changed file with 88 additions and 3 deletions.
91 changes: 88 additions & 3 deletions arrow/src/compute/kernels/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,13 @@ where
// Soundness: `slice.map` is `TrustedLen`.
let buffer = unsafe { Buffer::try_from_trusted_len_iter(values)? };

Ok((buffer, indices.data_ref().null_buffer().cloned()))
Ok((
buffer,
indices
.data_ref()
.null_buffer()
.map(|b| b.bit_slice(indices.offset(), indices.len())),
))
}

// take implementation when both values and indices contain nulls
Expand Down Expand Up @@ -516,7 +522,7 @@ where
nulls = match indices.data_ref().null_buffer() {
Some(buffer) => Some(buffer_bin_and(
buffer,
0,
indices.offset(),
&null_buf.into(),
0,
indices.len(),
Expand Down Expand Up @@ -805,6 +811,24 @@ mod tests {
Ok(())
}

fn test_take_primitive_arrays_non_null<T>(
data: Vec<T::Native>,
index: &UInt32Array,
options: Option<TakeOptions>,
expected_data: Vec<Option<T::Native>>,
) -> Result<()>
where
T: ArrowPrimitiveType,
PrimitiveArray<T>: From<Vec<T::Native>>,
PrimitiveArray<T>: From<Vec<Option<T::Native>>>,
{
let output = PrimitiveArray::<T>::from(data);
let expected = Arc::new(PrimitiveArray::<T>::from(expected_data)) as ArrayRef;
let output = take(&output, index, options)?;
assert_eq!(&output, &expected);
Ok(())
}

fn test_take_impl_primitive_arrays<T, I>(
data: Vec<Option<T::Native>>,
index: &PrimitiveArray<I>,
Expand Down Expand Up @@ -876,6 +900,48 @@ mod tests {
.unwrap();
}

#[test]
fn test_take_primitive_nullable_indices_non_null_values_with_offset() {
let index =
UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
let index = index.slice(2, 4);
let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();

assert_eq!(
index,
&UInt32Array::from(vec![Some(2), Some(3), None, None])
);

test_take_primitive_arrays_non_null::<Int64Type>(
vec![0, 10, 20, 30, 40, 50],
&index,
None,
vec![Some(20), Some(30), None, None],
)
.unwrap();
}

#[test]
fn test_take_primitive_nullable_indices_nullable_values_with_offset() {
let index =
UInt32Array::from(vec![Some(0), Some(1), Some(2), Some(3), None, None]);
let index = index.slice(2, 4);
let index = index.as_any().downcast_ref::<UInt32Array>().unwrap();

assert_eq!(
index,
&UInt32Array::from(vec![Some(2), Some(3), None, None])
);

test_take_primitive_arrays::<Int64Type>(
vec![None, None, Some(20), Some(30), Some(40), Some(50)],
&index,
None,
vec![Some(20), Some(30), None, None],
)
.unwrap();
}

#[test]
fn test_take_primitive() {
let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
Expand Down Expand Up @@ -1100,7 +1166,7 @@ mod tests {
}

#[test]
fn test_take_primitive_bool() {
fn test_take_bool() {
let index = UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2)]);
// boolean
test_take_boolean_arrays(
Expand All @@ -1111,6 +1177,25 @@ mod tests {
);
}

#[test]
fn test_take_bool_with_offset() {
let index =
UInt32Array::from(vec![Some(3), None, Some(1), Some(3), Some(2), None]);
let index = index.slice(2, 4);
let index = index
.as_any()
.downcast_ref::<PrimitiveArray<UInt32Type>>()
.unwrap();

// boolean
test_take_boolean_arrays(
vec![Some(false), None, Some(true), Some(false), None],
&index,
None,
vec![None, Some(false), Some(true), None],
);
}

fn _test_take_string<'a, K: 'static>()
where
K: Array + PartialEq + From<Vec<Option<&'a str>>>,
Expand Down

0 comments on commit 2423fad

Please sign in to comment.