Skip to content

Commit

Permalink
Add typed buffers to UnionArray (apache#3880)
Browse files Browse the repository at this point in the history
  • Loading branch information
tustvold committed Mar 24, 2023
1 parent 33bbaa5 commit 1cbf082
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 115 deletions.
172 changes: 89 additions & 83 deletions arrow-array/src/array/union_array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

use crate::{make_array, Array, ArrayRef};
use arrow_buffer::buffer::NullBuffer;
use arrow_buffer::Buffer;
use arrow_buffer::{Buffer, ScalarBuffer};
use arrow_data::ArrayData;
use arrow_schema::{ArrowError, DataType, Field, UnionMode};
/// Contains the `UnionArray` type.
Expand Down Expand Up @@ -109,6 +109,8 @@ use std::sync::Arc;
#[derive(Clone)]
pub struct UnionArray {
data: ArrayData,
type_ids: ScalarBuffer<i8>,
offsets: Option<ScalarBuffer<i32>>,
boxed_fields: Vec<Option<ArrayRef>>,
}

Expand Down Expand Up @@ -241,21 +243,29 @@ impl UnionArray {
///
/// Panics if `index` is greater than the length of the array.
pub fn type_id(&self, index: usize) -> i8 {
assert!(index < self.len());
self.data().buffers()[0].as_slice()[self.offset() + index] as i8
self.type_ids[index]
}

/// Returns the `type_ids` for this array
pub fn type_ids(&self) -> &ScalarBuffer<i8> {
&self.type_ids
}

/// Returns the `offsets` buffer if this is a dense array
pub fn offsets(&self) -> Option<&ScalarBuffer<i32>> {
self.offsets.as_ref()
}

/// Returns the offset into the underlying values array for the array slot at `index`.
///
/// # Panics
///
/// Panics if `index` is greater than the length of the array.
pub fn value_offset(&self, index: usize) -> i32 {
pub fn value_offset(&self, index: usize) -> usize {
assert!(index < self.len());
if self.is_dense() {
self.data().buffers()[1].typed_data::<i32>()[self.offset() + index]
} else {
(self.offset() + index) as i32
match &self.offsets {
Some(offsets) => offsets[index] as usize,
None => self.offset() + index,
}
}

Expand All @@ -264,7 +274,7 @@ impl UnionArray {
/// Panics if index `i` is out of bounds
pub fn value(&self, i: usize) -> ArrayRef {
let type_id = self.type_id(i);
let value_offset = self.value_offset(i) as usize;
let value_offset = self.value_offset(i);
let child = self.child(type_id);
child.slice(value_offset, 1)
}
Expand All @@ -291,16 +301,36 @@ impl UnionArray {

impl From<ArrayData> for UnionArray {
fn from(data: ArrayData) -> Self {
let field_ids = match data.data_type() {
DataType::Union(_, ids, _) => ids,
let (field_ids, mode) = match data.data_type() {
DataType::Union(_, ids, mode) => (ids, *mode),
d => panic!("UnionArray expected ArrayData with type Union got {d}"),
};
let (type_ids, offsets) = match mode {
UnionMode::Sparse => (
ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len()),
None,
),
UnionMode::Dense => (
ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len()),
Some(ScalarBuffer::new(
data.buffers()[1].clone(),
data.offset(),
data.len(),
)),
),
};

let max_id = field_ids.iter().copied().max().unwrap_or_default() as usize;
let mut boxed_fields = vec![None; max_id + 1];
for (cd, field_id) in data.child_data().iter().zip(field_ids) {
boxed_fields[*field_id as usize] = Some(make_array(cd.clone()));
}
Self { data, boxed_fields }
Self {
data,
type_ids,
offsets,
boxed_fields,
}
}
}

Expand Down Expand Up @@ -364,16 +394,16 @@ impl std::fmt::Debug for UnionArray {
writeln!(f, "{header}")?;

writeln!(f, "-- type id buffer:")?;
writeln!(f, "{:?}", self.data().buffers()[0])?;
writeln!(f, "{:?}", self.type_ids)?;

let (fields, ids, mode) = match self.data_type() {
DataType::Union(f, ids, mode) => (f, ids, mode),
let (fields, ids) = match self.data_type() {
DataType::Union(f, ids, _) => (f, ids),
_ => unreachable!(),
};

if mode == &UnionMode::Dense {
if let Some(offsets) = &self.offsets {
writeln!(f, "-- offsets buffer:")?;
writeln!(f, "{:?}", self.data().buffers()[1])?;
writeln!(f, "{:?}", offsets)?;
}

assert_eq!(fields.len(), ids.len());
Expand Down Expand Up @@ -418,39 +448,33 @@ mod tests {
let union = builder.build().unwrap();

let expected_type_ids = vec![0_i8, 1, 2, 0, 2, 0, 1];
let expected_value_offsets = vec![0_i32, 0, 0, 1, 1, 2, 1];
let expected_offsets = vec![0_i32, 0, 0, 1, 1, 2, 1];
let expected_array_values = [1_i32, 2, 3, 4, 5, 6, 7];

// Check type ids
assert_eq!(
*union.data().buffers()[0],
Buffer::from_slice_ref(&expected_type_ids)
);
assert_eq!(*union.type_ids(), expected_type_ids);
for (i, id) in expected_type_ids.iter().enumerate() {
assert_eq!(id, &union.type_id(i));
}

// Check offsets
assert_eq!(
*union.data().buffers()[1],
Buffer::from_slice_ref(&expected_value_offsets)
);
for (i, id) in expected_value_offsets.iter().enumerate() {
assert_eq!(&union.value_offset(i), id);
assert_eq!(*union.offsets().unwrap(), expected_offsets);
for (i, id) in expected_offsets.iter().enumerate() {
assert_eq!(union.value_offset(i), *id as usize);
}

// Check data
assert_eq!(
*union.data().child_data()[0].buffers()[0],
Buffer::from_slice_ref([1_i32, 4, 6])
*union.child(0).as_primitive::<Int32Type>().values(),
[1_i32, 4, 6]
);
assert_eq!(
*union.data().child_data()[1].buffers()[0],
Buffer::from_slice_ref([2_i32, 7])
*union.child(1).as_primitive::<Int32Type>().values(),
[2_i32, 7]
);
assert_eq!(
*union.data().child_data()[2].buffers()[0],
Buffer::from_slice_ref([3_i32, 5]),
*union.child(2).as_primitive::<Int32Type>().values(),
[3_i32, 5]
);

assert_eq!(expected_array_values.len(), union.len());
Expand All @@ -470,7 +494,7 @@ mod tests {
let mut builder = UnionBuilder::new_dense();

let expected_type_ids = vec![0_i8; 1024];
let expected_value_offsets: Vec<_> = (0..1024).collect();
let expected_offsets: Vec<_> = (0..1024).collect();
let expected_array_values: Vec<_> = (1..=1024).collect();

expected_array_values
Expand All @@ -480,27 +504,21 @@ mod tests {
let union = builder.build().unwrap();

// Check type ids
assert_eq!(
*union.data().buffers()[0],
Buffer::from_slice_ref(&expected_type_ids)
);
assert_eq!(*union.type_ids(), expected_type_ids);
for (i, id) in expected_type_ids.iter().enumerate() {
assert_eq!(id, &union.type_id(i));
}

// Check offsets
assert_eq!(
*union.data().buffers()[1],
Buffer::from_slice_ref(&expected_value_offsets)
);
for (i, id) in expected_value_offsets.iter().enumerate() {
assert_eq!(&union.value_offset(i), id);
assert_eq!(*union.offsets().unwrap(), expected_offsets);
for (i, id) in expected_offsets.iter().enumerate() {
assert_eq!(union.value_offset(i), *id as usize);
}

for (i, expected_value) in expected_array_values.iter().enumerate() {
assert!(!union.is_null(i));
let slot = union.value(i);
let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
let slot = slot.as_primitive::<Int32Type>();
assert_eq!(slot.len(), 1);
let value = slot.value(0);
assert_eq!(expected_value, &value);
Expand Down Expand Up @@ -649,10 +667,10 @@ mod tests {
let float_array = Float64Array::from(vec![10.0]);

let type_ids = [1_i8, 0, 0, 2, 0, 1];
let value_offsets = [0_i32, 0, 1, 0, 2, 1];
let offsets = [0_i32, 0, 1, 0, 2, 1];

let type_id_buffer = Buffer::from_slice_ref(type_ids);
let value_offsets_buffer = Buffer::from_slice_ref(value_offsets);
let value_offsets_buffer = Buffer::from_slice_ref(offsets);

let children: Vec<(Field, Arc<dyn Array>)> = vec![
(
Expand All @@ -674,18 +692,15 @@ mod tests {
.unwrap();

// Check type ids
assert_eq!(Buffer::from_slice_ref(type_ids), *array.data().buffers()[0]);
assert_eq!(*array.type_ids(), type_ids);
for (i, id) in type_ids.iter().enumerate() {
assert_eq!(id, &array.type_id(i));
}

// Check offsets
assert_eq!(
Buffer::from_slice_ref(value_offsets),
*array.data().buffers()[1]
);
for (i, id) in value_offsets.iter().enumerate() {
assert_eq!(id, &array.value_offset(i));
assert_eq!(*array.offsets().unwrap(), offsets);
for (i, id) in offsets.iter().enumerate() {
assert_eq!(*id as usize, array.value_offset(i));
}

// Check values
Expand Down Expand Up @@ -748,29 +763,26 @@ mod tests {
let expected_array_values = [1_i32, 2, 3, 4, 5, 6, 7];

// Check type ids
assert_eq!(
Buffer::from_slice_ref(&expected_type_ids),
*union.data().buffers()[0]
);
assert_eq!(*union.type_ids(), expected_type_ids);
for (i, id) in expected_type_ids.iter().enumerate() {
assert_eq!(id, &union.type_id(i));
}

// Check offsets, sparse union should only have a single buffer
assert_eq!(union.data().buffers().len(), 1);
assert!(union.offsets().is_none());

// Check data
assert_eq!(
*union.data().child_data()[0].buffers()[0],
Buffer::from_slice_ref([1_i32, 0, 0, 4, 0, 6, 0]),
*union.child(0).as_primitive::<Int32Type>().values(),
[1_i32, 0, 0, 4, 0, 6, 0],
);
assert_eq!(
Buffer::from_slice_ref([0_i32, 2_i32, 0, 0, 0, 0, 7]),
*union.data().child_data()[1].buffers()[0]
*union.child(1).as_primitive::<Int32Type>().values(),
[0_i32, 2_i32, 0, 0, 0, 0, 7]
);
assert_eq!(
Buffer::from_slice_ref([0_i32, 0, 3_i32, 0, 5, 0, 0]),
*union.data().child_data()[2].buffers()[0]
*union.child(2).as_primitive::<Int32Type>().values(),
[0_i32, 0, 3_i32, 0, 5, 0, 0]
);

assert_eq!(expected_array_values.len(), union.len());
Expand All @@ -797,16 +809,13 @@ mod tests {
let expected_type_ids = vec![0_i8, 1, 0, 1, 0];

// Check type ids
assert_eq!(
Buffer::from_slice_ref(&expected_type_ids),
*union.data().buffers()[0]
);
assert_eq!(*union.type_ids(), expected_type_ids);
for (i, id) in expected_type_ids.iter().enumerate() {
assert_eq!(id, &union.type_id(i));
}

// Check offsets, sparse union should only have a single buffer, i.e. no offsets
assert_eq!(union.data().buffers().len(), 1);
assert!(union.offsets().is_none());

for i in 0..union.len() {
let slot = union.value(i);
Expand Down Expand Up @@ -859,16 +868,13 @@ mod tests {
let expected_type_ids = vec![0_i8, 0, 1, 0];

// Check type ids
assert_eq!(
Buffer::from_slice_ref(&expected_type_ids),
*union.data().buffers()[0]
);
assert_eq!(*union.type_ids(), expected_type_ids);
for (i, id) in expected_type_ids.iter().enumerate() {
assert_eq!(id, &union.type_id(i));
}

// Check offsets, sparse union should only have a single buffer, i.e. no offsets
assert_eq!(union.data().buffers().len(), 1);
assert!(union.offsets().is_none());

for i in 0..union.len() {
let slot = union.value(i);
Expand Down Expand Up @@ -919,15 +925,15 @@ mod tests {
match i {
0 => assert!(slot.is_null(0)),
1 => {
let slot = slot.as_any().downcast_ref::<Float64Array>().unwrap();
let slot = slot.as_primitive::<Float64Type>();
assert!(!slot.is_null(0));
assert_eq!(slot.len(), 1);
let value = slot.value(0);
assert_eq!(value, 3_f64);
}
2 => assert!(slot.is_null(0)),
3 => {
let slot = slot.as_any().downcast_ref::<Int32Array>().unwrap();
let slot = slot.as_primitive::<Int32Type>();
assert!(!slot.is_null(0));
assert_eq!(slot.len(), 1);
let value = slot.value(0);
Expand Down Expand Up @@ -1012,18 +1018,18 @@ mod tests {
assert_eq!(union_slice.type_id(2), 1);

let slot = union_slice.value(0);
let array = slot.as_any().downcast_ref::<Int32Array>().unwrap();
let array = slot.as_primitive::<Int32Type>();
assert_eq!(array.len(), 1);
assert!(array.is_null(0));

let slot = union_slice.value(1);
let array = slot.as_any().downcast_ref::<Float64Array>().unwrap();
let array = slot.as_primitive::<Float64Type>();
assert_eq!(array.len(), 1);
assert!(array.is_valid(0));
assert_eq!(array.value(0), 3.0);

let slot = union_slice.value(2);
let array = slot.as_any().downcast_ref::<Float64Array>().unwrap();
let array = slot.as_primitive::<Float64Type>();
assert_eq!(array.len(), 1);
assert!(array.is_null(0));
}
Expand Down Expand Up @@ -1059,8 +1065,8 @@ mod tests {
let int_array = Int32Array::from(vec![5, 6, 4]);
let float_array = Float64Array::from(vec![10.0]);

let type_ids = Buffer::from_iter([4_i8, 8, 4, 8, 9, 4, 8]);
let value_offsets = Buffer::from_iter([0_i32, 0, 1, 1, 0, 2, 2]);
let type_ids = Buffer::from_vec(vec![4_i8, 8, 4, 8, 9, 4, 8]);
let value_offsets = Buffer::from_vec(vec![0_i32, 0, 1, 1, 0, 2, 2]);

let data = ArrayData::builder(data_type)
.len(7)
Expand Down
Loading

0 comments on commit 1cbf082

Please sign in to comment.