diff --git a/arrow-array/src/array/union_array.rs b/arrow-array/src/array/union_array.rs index fe227226f77d..9b3134bb7070 100644 --- a/arrow-array/src/array/union_array.rs +++ b/arrow-array/src/array/union_array.rs @@ -17,7 +17,7 @@ use crate::{make_array, Array, ArrayRef}; use arrow_buffer::buffer::NullBuffer; -use arrow_buffer::Buffer; +use arrow_buffer::{Buffer, ScalarBuffer}; use arrow_data::ArrayData; use arrow_schema::{ArrowError, DataType, Field, UnionMode}; /// Contains the `UnionArray` type. @@ -109,6 +109,8 @@ use std::sync::Arc; #[derive(Clone)] pub struct UnionArray { data: ArrayData, + type_ids: ScalarBuffer, + offsets: Option>, boxed_fields: Vec>, } @@ -241,8 +243,17 @@ impl UnionArray { /// /// Panics if `index` is greater than the length of the array. pub fn type_id(&self, index: usize) -> i8 { - assert!(index < self.len()); - self.data().buffers()[0].as_slice()[self.offset() + index] as i8 + self.type_ids[index] + } + + /// Returns the `type_ids` for this array + pub fn type_ids(&self) -> &ScalarBuffer { + &self.type_ids + } + + /// Returns the `offsets` buffer if this is a dense array + pub fn offsets(&self) -> Option<&ScalarBuffer> { + self.offsets.as_ref() } /// Returns the offset into the underlying values array for the array slot at `index`. @@ -250,12 +261,11 @@ impl UnionArray { /// # Panics /// /// Panics if `index` is greater than the length of the array. - pub fn value_offset(&self, index: usize) -> i32 { + pub fn value_offset(&self, index: usize) -> usize { assert!(index < self.len()); - if self.is_dense() { - self.data().buffers()[1].typed_data::()[self.offset() + index] - } else { - (self.offset() + index) as i32 + match &self.offsets { + Some(offsets) => offsets[index] as usize, + None => self.offset() + index, } } @@ -264,7 +274,7 @@ impl UnionArray { /// Panics if index `i` is out of bounds pub fn value(&self, i: usize) -> ArrayRef { let type_id = self.type_id(i); - let value_offset = self.value_offset(i) as usize; + let value_offset = self.value_offset(i); let child = self.child(type_id); child.slice(value_offset, 1) } @@ -291,16 +301,36 @@ impl UnionArray { impl From for UnionArray { fn from(data: ArrayData) -> Self { - let field_ids = match data.data_type() { - DataType::Union(_, ids, _) => ids, + let (field_ids, mode) = match data.data_type() { + DataType::Union(_, ids, mode) => (ids, *mode), d => panic!("UnionArray expected ArrayData with type Union got {d}"), }; + let (type_ids, offsets) = match mode { + UnionMode::Sparse => ( + ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len()), + None, + ), + UnionMode::Dense => ( + ScalarBuffer::new(data.buffers()[0].clone(), data.offset(), data.len()), + Some(ScalarBuffer::new( + data.buffers()[1].clone(), + data.offset(), + data.len(), + )), + ), + }; + let max_id = field_ids.iter().copied().max().unwrap_or_default() as usize; let mut boxed_fields = vec![None; max_id + 1]; for (cd, field_id) in data.child_data().iter().zip(field_ids) { boxed_fields[*field_id as usize] = Some(make_array(cd.clone())); } - Self { data, boxed_fields } + Self { + data, + type_ids, + offsets, + boxed_fields, + } } } @@ -364,16 +394,16 @@ impl std::fmt::Debug for UnionArray { writeln!(f, "{header}")?; writeln!(f, "-- type id buffer:")?; - writeln!(f, "{:?}", self.data().buffers()[0])?; + writeln!(f, "{:?}", self.type_ids)?; - let (fields, ids, mode) = match self.data_type() { - DataType::Union(f, ids, mode) => (f, ids, mode), + let (fields, ids) = match self.data_type() { + DataType::Union(f, ids, _) => (f, ids), _ => unreachable!(), }; - if mode == &UnionMode::Dense { + if let Some(offsets) = &self.offsets { writeln!(f, "-- offsets buffer:")?; - writeln!(f, "{:?}", self.data().buffers()[1])?; + writeln!(f, "{:?}", offsets)?; } assert_eq!(fields.len(), ids.len()); @@ -418,39 +448,33 @@ mod tests { let union = builder.build().unwrap(); let expected_type_ids = vec![0_i8, 1, 2, 0, 2, 0, 1]; - let expected_value_offsets = vec![0_i32, 0, 0, 1, 1, 2, 1]; + let expected_offsets = vec![0_i32, 0, 0, 1, 1, 2, 1]; let expected_array_values = [1_i32, 2, 3, 4, 5, 6, 7]; // Check type ids - assert_eq!( - *union.data().buffers()[0], - Buffer::from_slice_ref(&expected_type_ids) - ); + assert_eq!(*union.type_ids(), expected_type_ids); for (i, id) in expected_type_ids.iter().enumerate() { assert_eq!(id, &union.type_id(i)); } // Check offsets - assert_eq!( - *union.data().buffers()[1], - Buffer::from_slice_ref(&expected_value_offsets) - ); - for (i, id) in expected_value_offsets.iter().enumerate() { - assert_eq!(&union.value_offset(i), id); + assert_eq!(*union.offsets().unwrap(), expected_offsets); + for (i, id) in expected_offsets.iter().enumerate() { + assert_eq!(union.value_offset(i), *id as usize); } // Check data assert_eq!( - *union.data().child_data()[0].buffers()[0], - Buffer::from_slice_ref([1_i32, 4, 6]) + *union.child(0).as_primitive::().values(), + [1_i32, 4, 6] ); assert_eq!( - *union.data().child_data()[1].buffers()[0], - Buffer::from_slice_ref([2_i32, 7]) + *union.child(1).as_primitive::().values(), + [2_i32, 7] ); assert_eq!( - *union.data().child_data()[2].buffers()[0], - Buffer::from_slice_ref([3_i32, 5]), + *union.child(2).as_primitive::().values(), + [3_i32, 5] ); assert_eq!(expected_array_values.len(), union.len()); @@ -470,7 +494,7 @@ mod tests { let mut builder = UnionBuilder::new_dense(); let expected_type_ids = vec![0_i8; 1024]; - let expected_value_offsets: Vec<_> = (0..1024).collect(); + let expected_offsets: Vec<_> = (0..1024).collect(); let expected_array_values: Vec<_> = (1..=1024).collect(); expected_array_values @@ -480,27 +504,21 @@ mod tests { let union = builder.build().unwrap(); // Check type ids - assert_eq!( - *union.data().buffers()[0], - Buffer::from_slice_ref(&expected_type_ids) - ); + assert_eq!(*union.type_ids(), expected_type_ids); for (i, id) in expected_type_ids.iter().enumerate() { assert_eq!(id, &union.type_id(i)); } // Check offsets - assert_eq!( - *union.data().buffers()[1], - Buffer::from_slice_ref(&expected_value_offsets) - ); - for (i, id) in expected_value_offsets.iter().enumerate() { - assert_eq!(&union.value_offset(i), id); + assert_eq!(*union.offsets().unwrap(), expected_offsets); + for (i, id) in expected_offsets.iter().enumerate() { + assert_eq!(union.value_offset(i), *id as usize); } for (i, expected_value) in expected_array_values.iter().enumerate() { assert!(!union.is_null(i)); let slot = union.value(i); - let slot = slot.as_any().downcast_ref::().unwrap(); + let slot = slot.as_primitive::(); assert_eq!(slot.len(), 1); let value = slot.value(0); assert_eq!(expected_value, &value); @@ -649,10 +667,10 @@ mod tests { let float_array = Float64Array::from(vec![10.0]); let type_ids = [1_i8, 0, 0, 2, 0, 1]; - let value_offsets = [0_i32, 0, 1, 0, 2, 1]; + let offsets = [0_i32, 0, 1, 0, 2, 1]; let type_id_buffer = Buffer::from_slice_ref(type_ids); - let value_offsets_buffer = Buffer::from_slice_ref(value_offsets); + let value_offsets_buffer = Buffer::from_slice_ref(offsets); let children: Vec<(Field, Arc)> = vec![ ( @@ -674,18 +692,15 @@ mod tests { .unwrap(); // Check type ids - assert_eq!(Buffer::from_slice_ref(type_ids), *array.data().buffers()[0]); + assert_eq!(*array.type_ids(), type_ids); for (i, id) in type_ids.iter().enumerate() { assert_eq!(id, &array.type_id(i)); } // Check offsets - assert_eq!( - Buffer::from_slice_ref(value_offsets), - *array.data().buffers()[1] - ); - for (i, id) in value_offsets.iter().enumerate() { - assert_eq!(id, &array.value_offset(i)); + assert_eq!(*array.offsets().unwrap(), offsets); + for (i, id) in offsets.iter().enumerate() { + assert_eq!(*id as usize, array.value_offset(i)); } // Check values @@ -748,29 +763,26 @@ mod tests { let expected_array_values = [1_i32, 2, 3, 4, 5, 6, 7]; // Check type ids - assert_eq!( - Buffer::from_slice_ref(&expected_type_ids), - *union.data().buffers()[0] - ); + assert_eq!(*union.type_ids(), expected_type_ids); for (i, id) in expected_type_ids.iter().enumerate() { assert_eq!(id, &union.type_id(i)); } // Check offsets, sparse union should only have a single buffer - assert_eq!(union.data().buffers().len(), 1); + assert!(union.offsets().is_none()); // Check data assert_eq!( - *union.data().child_data()[0].buffers()[0], - Buffer::from_slice_ref([1_i32, 0, 0, 4, 0, 6, 0]), + *union.child(0).as_primitive::().values(), + [1_i32, 0, 0, 4, 0, 6, 0], ); assert_eq!( - Buffer::from_slice_ref([0_i32, 2_i32, 0, 0, 0, 0, 7]), - *union.data().child_data()[1].buffers()[0] + *union.child(1).as_primitive::().values(), + [0_i32, 2_i32, 0, 0, 0, 0, 7] ); assert_eq!( - Buffer::from_slice_ref([0_i32, 0, 3_i32, 0, 5, 0, 0]), - *union.data().child_data()[2].buffers()[0] + *union.child(2).as_primitive::().values(), + [0_i32, 0, 3_i32, 0, 5, 0, 0] ); assert_eq!(expected_array_values.len(), union.len()); @@ -797,16 +809,13 @@ mod tests { let expected_type_ids = vec![0_i8, 1, 0, 1, 0]; // Check type ids - assert_eq!( - Buffer::from_slice_ref(&expected_type_ids), - *union.data().buffers()[0] - ); + assert_eq!(*union.type_ids(), expected_type_ids); for (i, id) in expected_type_ids.iter().enumerate() { assert_eq!(id, &union.type_id(i)); } // Check offsets, sparse union should only have a single buffer, i.e. no offsets - assert_eq!(union.data().buffers().len(), 1); + assert!(union.offsets().is_none()); for i in 0..union.len() { let slot = union.value(i); @@ -859,16 +868,13 @@ mod tests { let expected_type_ids = vec![0_i8, 0, 1, 0]; // Check type ids - assert_eq!( - Buffer::from_slice_ref(&expected_type_ids), - *union.data().buffers()[0] - ); + assert_eq!(*union.type_ids(), expected_type_ids); for (i, id) in expected_type_ids.iter().enumerate() { assert_eq!(id, &union.type_id(i)); } // Check offsets, sparse union should only have a single buffer, i.e. no offsets - assert_eq!(union.data().buffers().len(), 1); + assert!(union.offsets().is_none()); for i in 0..union.len() { let slot = union.value(i); @@ -919,7 +925,7 @@ mod tests { match i { 0 => assert!(slot.is_null(0)), 1 => { - let slot = slot.as_any().downcast_ref::().unwrap(); + let slot = slot.as_primitive::(); assert!(!slot.is_null(0)); assert_eq!(slot.len(), 1); let value = slot.value(0); @@ -927,7 +933,7 @@ mod tests { } 2 => assert!(slot.is_null(0)), 3 => { - let slot = slot.as_any().downcast_ref::().unwrap(); + let slot = slot.as_primitive::(); assert!(!slot.is_null(0)); assert_eq!(slot.len(), 1); let value = slot.value(0); @@ -1012,18 +1018,18 @@ mod tests { assert_eq!(union_slice.type_id(2), 1); let slot = union_slice.value(0); - let array = slot.as_any().downcast_ref::().unwrap(); + let array = slot.as_primitive::(); assert_eq!(array.len(), 1); assert!(array.is_null(0)); let slot = union_slice.value(1); - let array = slot.as_any().downcast_ref::().unwrap(); + let array = slot.as_primitive::(); assert_eq!(array.len(), 1); assert!(array.is_valid(0)); assert_eq!(array.value(0), 3.0); let slot = union_slice.value(2); - let array = slot.as_any().downcast_ref::().unwrap(); + let array = slot.as_primitive::(); assert_eq!(array.len(), 1); assert!(array.is_null(0)); } @@ -1059,8 +1065,8 @@ mod tests { let int_array = Int32Array::from(vec![5, 6, 4]); let float_array = Float64Array::from(vec![10.0]); - let type_ids = Buffer::from_iter([4_i8, 8, 4, 8, 9, 4, 8]); - let value_offsets = Buffer::from_iter([0_i32, 0, 1, 1, 0, 2, 2]); + let type_ids = Buffer::from_vec(vec![4_i8, 8, 4, 8, 9, 4, 8]); + let value_offsets = Buffer::from_vec(vec![0_i32, 0, 1, 1, 0, 2, 2]); let data = ArrayData::builder(data_type) .len(7) diff --git a/arrow-array/src/builder/union_builder.rs b/arrow-array/src/builder/union_builder.rs index 28fb7e5d999a..8ca303da8cb4 100644 --- a/arrow-array/src/builder/union_builder.rs +++ b/arrow-array/src/builder/union_builder.rs @@ -113,13 +113,13 @@ impl FieldData { /// builder.append::("a", 4).unwrap(); /// let union = builder.build().unwrap(); /// -/// assert_eq!(union.type_id(0), 0_i8); -/// assert_eq!(union.type_id(1), 1_i8); -/// assert_eq!(union.type_id(2), 0_i8); +/// assert_eq!(union.type_id(0), 0); +/// assert_eq!(union.type_id(1), 1); +/// assert_eq!(union.type_id(2), 0); /// -/// assert_eq!(union.value_offset(0), 0_i32); -/// assert_eq!(union.value_offset(1), 0_i32); -/// assert_eq!(union.value_offset(2), 1_i32); +/// assert_eq!(union.value_offset(0), 0); +/// assert_eq!(union.value_offset(1), 0); +/// assert_eq!(union.value_offset(2), 1); /// ``` /// /// Example: **Sparse Memory Layout** @@ -133,13 +133,13 @@ impl FieldData { /// builder.append::("a", 4).unwrap(); /// let union = builder.build().unwrap(); /// -/// assert_eq!(union.type_id(0), 0_i8); -/// assert_eq!(union.type_id(1), 1_i8); -/// assert_eq!(union.type_id(2), 0_i8); +/// assert_eq!(union.type_id(0), 0); +/// assert_eq!(union.type_id(1), 1); +/// assert_eq!(union.type_id(2), 0); /// -/// assert_eq!(union.value_offset(0), 0_i32); -/// assert_eq!(union.value_offset(1), 1_i32); -/// assert_eq!(union.value_offset(2), 2_i32); +/// assert_eq!(union.value_offset(0), 0); +/// assert_eq!(union.value_offset(1), 1); +/// assert_eq!(union.value_offset(2), 2); /// ``` #[derive(Debug)] pub struct UnionBuilder { diff --git a/arrow/src/ffi.rs b/arrow/src/ffi.rs index 9d0ed0b85fb6..07084f128acb 100644 --- a/arrow/src/ffi.rs +++ b/arrow/src/ffi.rs @@ -496,8 +496,9 @@ mod tests { use crate::compute::kernels; use crate::datatypes::{Field, Int8Type}; use arrow_array::builder::UnionBuilder; + use arrow_array::cast::AsArray; use arrow_array::types::{Float64Type, Int32Type}; - use arrow_array::{Float64Array, UnionArray}; + use arrow_array::UnionArray; use std::convert::TryFrom; use std::mem::ManuallyDrop; use std::ptr::addr_of_mut; @@ -1113,22 +1114,19 @@ mod tests { let expected_type_ids = vec![0_i8, 0, 1, 0]; // Check type ids - assert_eq!( - Buffer::from_slice_ref(&expected_type_ids), - *array.data().buffers()[0] - ); + assert_eq!(*array.type_ids(), expected_type_ids); for (i, id) in expected_type_ids.iter().enumerate() { assert_eq!(id, &array.type_id(i)); } // Check offsets, sparse union should only have a single buffer, i.e. no offsets - assert_eq!(array.data().buffers().len(), 1); + assert!(array.offsets().is_none()); for i in 0..array.len() { let slot = array.value(i); match i { 0 => { - let slot = slot.as_any().downcast_ref::().unwrap(); + let slot = slot.as_primitive::(); assert!(!slot.is_null(0)); assert_eq!(slot.len(), 1); let value = slot.value(0); @@ -1136,14 +1134,14 @@ mod tests { } 1 => assert!(slot.is_null(0)), 2 => { - let slot = slot.as_any().downcast_ref::().unwrap(); + let slot = slot.as_primitive::(); assert!(!slot.is_null(0)); assert_eq!(slot.len(), 1); let value = slot.value(0); assert_eq!(value, 3_f64); } 3 => { - let slot = slot.as_any().downcast_ref::().unwrap(); + let slot = slot.as_primitive::(); assert!(!slot.is_null(0)); assert_eq!(slot.len(), 1); let value = slot.value(0); @@ -1170,28 +1168,23 @@ mod tests { // (simulate consumer) import it let data = ArrayData::try_from(array)?; - let array = make_array(data); - - let array = array.as_any().downcast_ref::().unwrap(); + let array = UnionArray::from(data); let expected_type_ids = vec![0_i8, 0, 1, 0]; // Check type ids - assert_eq!( - Buffer::from_slice_ref(&expected_type_ids), - *array.data().buffers()[0] - ); + assert_eq!(*array.type_ids(), expected_type_ids); for (i, id) in expected_type_ids.iter().enumerate() { assert_eq!(id, &array.type_id(i)); } - assert_eq!(array.data().buffers().len(), 2); + assert!(array.offsets().is_some()); for i in 0..array.len() { let slot = array.value(i); match i { 0 => { - let slot = slot.as_any().downcast_ref::().unwrap(); + let slot = slot.as_primitive::(); assert!(!slot.is_null(0)); assert_eq!(slot.len(), 1); let value = slot.value(0); @@ -1199,14 +1192,14 @@ mod tests { } 1 => assert!(slot.is_null(0)), 2 => { - let slot = slot.as_any().downcast_ref::().unwrap(); + let slot = slot.as_primitive::(); assert!(!slot.is_null(0)); assert_eq!(slot.len(), 1); let value = slot.value(0); assert_eq!(value, 3_f64); } 3 => { - let slot = slot.as_any().downcast_ref::().unwrap(); + let slot = slot.as_primitive::(); assert!(!slot.is_null(0)); assert_eq!(slot.len(), 1); let value = slot.value(0);