From b66a8c73646efeaa81f50a85a9e7796736af4ec2 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Tue, 25 Apr 2023 17:30:23 +0100 Subject: [PATCH] Update other constructors --- arrow-array/src/array/struct_array.rs | 137 +++++--------------------- 1 file changed, 27 insertions(+), 110 deletions(-) diff --git a/arrow-array/src/array/struct_array.rs b/arrow-array/src/array/struct_array.rs index 9f449307f81c..a18f38c082c9 100644 --- a/arrow-array/src/array/struct_array.rs +++ b/arrow-array/src/array/struct_array.rs @@ -16,7 +16,7 @@ // under the License. use crate::{make_array, new_null_array, Array, ArrayRef, RecordBatch}; -use arrow_buffer::{buffer_bin_or, BooleanBuffer, Buffer, NullBuffer}; +use arrow_buffer::{BooleanBuffer, Buffer, NullBuffer}; use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::{ArrowError, DataType, Field, Fields, SchemaBuilder}; use std::sync::Arc; @@ -309,66 +309,18 @@ impl TryFrom<Vec<(&str, ArrayRef)>> for StructArray { type Error = ArrowError; /// builds a StructArray from a vector of names and arrays. - /// This errors if the values have a different length. - /// An entry is set to Null when all values are null. fn try_from(values: Vec<(&str, ArrayRef)>) -> Result<Self, ArrowError> { - let values_len = values.len(); - - // these will be populated - let mut fields = Vec::with_capacity(values_len); - let mut child_data = Vec::with_capacity(values_len); - - // len: the size of the arrays. - let mut len: Option<usize> = None; - // null: the null mask of the arrays. - let mut null: Option<Buffer> = None; - for (field_name, array) in values { - let child_datum = array.to_data(); - let child_datum_len = child_datum.len(); - if let Some(len) = len { - if len != child_datum_len { - return Err(ArrowError::InvalidArgumentError( - format!("Array of field \"{field_name}\" has length {child_datum_len}, but previous elements have length {len}. 
- All arrays in every entry in a struct array must have the same length.") - )); - } - } else { - len = Some(child_datum_len) - } - fields.push(Arc::new(Field::new( - field_name, - array.data_type().clone(), - child_datum.nulls().is_some(), - ))); - - if let Some(child_nulls) = child_datum.nulls() { - null = Some(if let Some(null_buffer) = &null { - buffer_bin_or( - null_buffer, - 0, - child_nulls.buffer(), - child_nulls.offset(), - child_datum_len, - ) - } else { - child_nulls.inner().sliced() - }); - } else if null.is_some() { - // when one of the fields has no nulls, then there is no null in the array - null = None; - } - child_data.push(child_datum); - } - let len = len.unwrap(); - - let builder = ArrayData::builder(DataType::Struct(fields.into())) - .len(len) - .null_bit_buffer(null) - .child_data(child_data); - - let array_data = unsafe { builder.build_unchecked() }; - - Ok(StructArray::from(array_data)) + let (schema, arrays): (SchemaBuilder, _) = values + .into_iter() + .map(|(name, array)| { + ( + Field::new(name, array.data_type().clone(), array.nulls().is_some()), + array, + ) + }) + .unzip(); + + StructArray::try_new(schema.finish().fields, arrays, None) } } @@ -429,38 +381,8 @@ impl Array for StructArray { impl From<Vec<(Field, ArrayRef)>> for StructArray { fn from(v: Vec<(Field, ArrayRef)>) -> Self { - let iter = v.into_iter(); - let capacity = iter.size_hint().0; - - let mut len = None; - let mut schema = SchemaBuilder::with_capacity(capacity); - let mut child_data = Vec::with_capacity(capacity); - for (field, array) in iter { - // Check the length of the child arrays - assert_eq!( - *len.get_or_insert(array.len()), - array.len(), - "all child arrays of a StructArray must have the same length" - ); - // Check data types of child arrays - assert_eq!( - field.data_type(), - array.data_type(), - "the field data types must match the array data in a StructArray" - ); - schema.push(field); - child_data.push(array.to_data()); - } - let field_types = schema.finish().fields; - let 
array_data = ArrayData::builder(DataType::Struct(field_types)) - .child_data(child_data) - .len(len.unwrap_or_default()); - let array_data = unsafe { array_data.build_unchecked() }; - - // We must validate nullability - array_data.validate_nulls().unwrap(); - - Self::from(array_data) + let (schema, arrays): (SchemaBuilder, _) = v.into_iter().unzip(); + StructArray::new(schema.finish().fields, arrays, None) } } @@ -611,12 +533,7 @@ mod tests { let struct_data = arr.into_data(); assert_eq!(4, struct_data.len()); - assert_eq!(1, struct_data.null_count()); - assert_eq!( - // 00001011 - &[11_u8], - struct_data.nulls().unwrap().validity() - ); + assert_eq!(0, struct_data.null_count()); let expected_string_data = ArrayData::builder(DataType::Utf8) .len(4) @@ -648,20 +565,20 @@ mod tests { let ints: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), Some(2), None, Some(4)])); - let arr = - StructArray::try_from(vec![("f1", strings.clone()), ("f2", ints.clone())]); + let err = + StructArray::try_from(vec![("f1", strings.clone()), ("f2", ints.clone())]) + .unwrap_err() + .to_string(); - match arr { - Err(ArrowError::InvalidArgumentError(e)) => { - assert!(e.starts_with("Array of field \"f2\" has length 4, but previous elements have length 3.")); - } - _ => panic!("This test got an unexpected error type"), - }; + assert_eq!( + err, + "Invalid argument error: Incorrect array length for StructArray field \"f2\", expected 3 got 4" + ) } #[test] #[should_panic( - expected = "the field data types must match the array data in a StructArray" + expected = "Incorrect datatype for StructArray field \\\"b\\\", expected Int16 got Boolean" )] fn test_struct_array_from_mismatched_types_single() { drop(StructArray::from(vec![( @@ -673,7 +590,7 @@ mod tests { #[test] #[should_panic( - expected = "the field data types must match the array data in a StructArray" + expected = "Incorrect datatype for StructArray field \\\"b\\\", expected Int16 got Boolean" )] fn 
test_struct_array_from_mismatched_types_multiple() { drop(StructArray::from(vec![ @@ -778,7 +695,7 @@ mod tests { #[test] #[should_panic( - expected = "all child arrays of a StructArray must have the same length" + expected = "Incorrect array length for StructArray field \\\"c\\\", expected 1 got 2" )] fn test_invalid_struct_child_array_lengths() { drop(StructArray::from(vec![ @@ -801,7 +718,7 @@ mod tests { #[test] #[should_panic( - expected = "non-nullable child of type Int32 contains nulls not present in parent Struct" + expected = "Found unmasked nulls for non-nullable StructArray field \\\"c\\\"" )] fn test_struct_array_from_mismatched_nullability() { drop(StructArray::from(vec![(