Skip to content

Commit

Permalink
Update other constructors
Browse files (browse the repository at this point in the history)
  • Loading branch information
tustvold committed Apr 25, 2023
1 parent e2f3f06 commit b66a8c7
Showing 1 changed file with 27 additions and 110 deletions.
137 changes: 27 additions & 110 deletions arrow-array/src/array/struct_array.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -16,7 +16,7 @@
// under the License.

use crate::{make_array, new_null_array, Array, ArrayRef, RecordBatch};
use arrow_buffer::{buffer_bin_or, BooleanBuffer, Buffer, NullBuffer};
use arrow_buffer::{BooleanBuffer, Buffer, NullBuffer};
use arrow_data::{ArrayData, ArrayDataBuilder};
use arrow_schema::{ArrowError, DataType, Field, Fields, SchemaBuilder};
use std::sync::Arc;
Expand Down Expand Up @@ -309,66 +309,18 @@ impl TryFrom<Vec<(&str, ArrayRef)>> for StructArray {
type Error = ArrowError;

/// builds a StructArray from a vector of names and arrays.
/// This errors if the values have a different length.
/// An entry is set to Null when all values are null.
fn try_from(values: Vec<(&str, ArrayRef)>) -> Result<Self, ArrowError> {
let values_len = values.len();

// these will be populated
let mut fields = Vec::with_capacity(values_len);
let mut child_data = Vec::with_capacity(values_len);

// len: the size of the arrays.
let mut len: Option<usize> = None;
// null: the null mask of the arrays.
let mut null: Option<Buffer> = None;
for (field_name, array) in values {
let child_datum = array.to_data();
let child_datum_len = child_datum.len();
if let Some(len) = len {
if len != child_datum_len {
return Err(ArrowError::InvalidArgumentError(
format!("Array of field \"{field_name}\" has length {child_datum_len}, but previous elements have length {len}.
All arrays in every entry in a struct array must have the same length.")
));
}
} else {
len = Some(child_datum_len)
}
fields.push(Arc::new(Field::new(
field_name,
array.data_type().clone(),
child_datum.nulls().is_some(),
)));

if let Some(child_nulls) = child_datum.nulls() {
null = Some(if let Some(null_buffer) = &null {
buffer_bin_or(
null_buffer,
0,
child_nulls.buffer(),
child_nulls.offset(),
child_datum_len,
)
} else {
child_nulls.inner().sliced()
});
} else if null.is_some() {
// when one of the fields has no nulls, then there is no null in the array
null = None;
}
child_data.push(child_datum);
}
let len = len.unwrap();

let builder = ArrayData::builder(DataType::Struct(fields.into()))
.len(len)
.null_bit_buffer(null)
.child_data(child_data);

let array_data = unsafe { builder.build_unchecked() };

Ok(StructArray::from(array_data))
let (schema, arrays): (SchemaBuilder, _) = values
.into_iter()
.map(|(name, array)| {
(
Field::new(name, array.data_type().clone(), array.nulls().is_some()),
array,
)
})
.unzip();

StructArray::try_new(schema.finish().fields, arrays, None)
}
}

Expand Down Expand Up @@ -429,38 +381,8 @@ impl Array for StructArray {

impl From<Vec<(Field, ArrayRef)>> for StructArray {
fn from(v: Vec<(Field, ArrayRef)>) -> Self {
let iter = v.into_iter();
let capacity = iter.size_hint().0;

let mut len = None;
let mut schema = SchemaBuilder::with_capacity(capacity);
let mut child_data = Vec::with_capacity(capacity);
for (field, array) in iter {
// Check the length of the child arrays
assert_eq!(
*len.get_or_insert(array.len()),
array.len(),
"all child arrays of a StructArray must have the same length"
);
// Check data types of child arrays
assert_eq!(
field.data_type(),
array.data_type(),
"the field data types must match the array data in a StructArray"
);
schema.push(field);
child_data.push(array.to_data());
}
let field_types = schema.finish().fields;
let array_data = ArrayData::builder(DataType::Struct(field_types))
.child_data(child_data)
.len(len.unwrap_or_default());
let array_data = unsafe { array_data.build_unchecked() };

// We must validate nullability
array_data.validate_nulls().unwrap();

Self::from(array_data)
let (schema, arrays): (SchemaBuilder, _) = v.into_iter().unzip();
StructArray::new(schema.finish().fields, arrays, None)
}
}

Expand Down Expand Up @@ -611,12 +533,7 @@ mod tests {

let struct_data = arr.into_data();
assert_eq!(4, struct_data.len());
assert_eq!(1, struct_data.null_count());
assert_eq!(
// 00001011
&[11_u8],
struct_data.nulls().unwrap().validity()
);
assert_eq!(0, struct_data.null_count());

let expected_string_data = ArrayData::builder(DataType::Utf8)
.len(4)
Expand Down Expand Up @@ -648,20 +565,20 @@ mod tests {
let ints: ArrayRef =
Arc::new(Int32Array::from(vec![Some(1), Some(2), None, Some(4)]));

let arr =
StructArray::try_from(vec![("f1", strings.clone()), ("f2", ints.clone())]);
let err =
StructArray::try_from(vec![("f1", strings.clone()), ("f2", ints.clone())])
.unwrap_err()
.to_string();

match arr {
Err(ArrowError::InvalidArgumentError(e)) => {
assert!(e.starts_with("Array of field \"f2\" has length 4, but previous elements have length 3."));
}
_ => panic!("This test got an unexpected error type"),
};
assert_eq!(
err,
"Invalid argument error: Incorrect array length for StructArray field \"f2\", expected 3 got 4"
)
}

#[test]
#[should_panic(
expected = "the field data types must match the array data in a StructArray"
expected = "Incorrect datatype for StructArray field \\\"b\\\", expected Int16 got Boolean"
)]
fn test_struct_array_from_mismatched_types_single() {
drop(StructArray::from(vec![(
Expand All @@ -673,7 +590,7 @@ mod tests {

#[test]
#[should_panic(
expected = "the field data types must match the array data in a StructArray"
expected = "Incorrect datatype for StructArray field \\\"b\\\", expected Int16 got Boolean"
)]
fn test_struct_array_from_mismatched_types_multiple() {
drop(StructArray::from(vec![
Expand Down Expand Up @@ -778,7 +695,7 @@ mod tests {

#[test]
#[should_panic(
expected = "all child arrays of a StructArray must have the same length"
expected = "Incorrect array length for StructArray field \\\"c\\\", expected 1 got 2"
)]
fn test_invalid_struct_child_array_lengths() {
drop(StructArray::from(vec![
Expand All @@ -801,7 +718,7 @@ mod tests {

#[test]
#[should_panic(
expected = "non-nullable child of type Int32 contains nulls not present in parent Struct"
expected = "Found unmasked nulls for non-nullable StructArray field \\\"c\\\""
)]
fn test_struct_array_from_mismatched_nullability() {
drop(StructArray::from(vec![(
Expand Down

0 comments on commit b66a8c7

Please sign in to comment.