diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 10f052b90923..1f302c750916 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -31,7 +31,6 @@ use crate::cast::{ use crate::error::{DataFusionError, Result, _internal_err, _not_impl_err}; use crate::hash_utils::create_hashes; use crate::utils::{array_into_large_list_array, array_into_list_array}; -use arrow::buffer::{NullBuffer, OffsetBuffer}; use arrow::compute::kernels::numeric::*; use arrow::datatypes::{i256, Fields, SchemaBuilder}; use arrow::util::display::{ArrayFormatter, FormatOptions}; @@ -39,12 +38,11 @@ use arrow::{ array::*, compute::kernels::cast::{cast_with_options, CastOptions}, datatypes::{ - ArrowDictionaryKeyType, ArrowNativeType, DataType, Field, Float32Type, - Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, IntervalDayTimeType, - IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType, TimeUnit, - TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, - TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, - DECIMAL128_MAX_PRECISION, + ArrowDictionaryKeyType, ArrowNativeType, DataType, Field, Float32Type, Int16Type, + Int32Type, Int64Type, Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType, + IntervalUnit, IntervalYearMonthType, TimeUnit, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, + UInt16Type, UInt32Type, UInt64Type, UInt8Type, DECIMAL128_MAX_PRECISION, }, }; use arrow_array::cast::as_list_array; @@ -1368,103 +1366,36 @@ impl ScalarValue { }}; } - macro_rules! build_array_list_primitive { - ($ARRAY_TY:ident, $SCALAR_TY:ident, $NATIVE_TYPE:ident, $LIST_TY:ident, $SCALAR_LIST:pat) => {{ - Ok::(Arc::new($LIST_TY::from_iter_primitive::<$ARRAY_TY, _, _>( - scalars.into_iter().map(|x| match x{ - ScalarValue::List(arr) if matches!(x, $SCALAR_LIST) => { - // `ScalarValue::List` contains a single element `ListArray`. - let list_arr = as_list_array(&arr); - if list_arr.is_null(0) { - Ok(None) - } else { - let primitive_arr = - list_arr.values().as_primitive::<$ARRAY_TY>(); - Ok(Some( - primitive_arr.into_iter().collect::>>(), - )) - } - } - ScalarValue::LargeList(arr) if matches!(x, $SCALAR_LIST) =>{ - // `ScalarValue::List` contains a single element `ListArray`. - let list_arr = as_large_list_array(&arr); - if list_arr.is_null(0) { - Ok(None) - } else { - let primitive_arr = - list_arr.values().as_primitive::<$ARRAY_TY>(); - Ok(Some( - primitive_arr.into_iter().collect::>>(), - )) - } - } - sv => _internal_err!( - "Inconsistent types in ScalarValue::iter_to_array. \ - Expected {:?}, got {:?}", - data_type, sv - ), - }) - .collect::>>()?, - ))) - }}; - } - - macro_rules! build_array_list_string { - ($BUILDER:ident, $STRING_ARRAY:ident,$LIST_BUILDER:ident,$SCALAR_LIST:pat) => {{ - let mut builder = $LIST_BUILDER::new($BUILDER::new()); - for scalar in scalars.into_iter() { - match scalar { - ScalarValue::List(arr) if matches!(scalar, $SCALAR_LIST) => { - // `ScalarValue::List` contains a single element `ListArray`. - let list_arr = as_list_array(&arr); - - if list_arr.is_null(0) { - builder.append(false); - continue; - } - - let string_arr = $STRING_ARRAY(list_arr.values()); - - for v in string_arr.iter() { - if let Some(v) = v { - builder.values().append_value(v); - } else { - builder.values().append_null(); - } - } - builder.append(true); - } - ScalarValue::LargeList(arr) if matches!(scalar, $SCALAR_LIST) => { - // `ScalarValue::List` contains a single element `ListArray`. - let list_arr = as_large_list_array(&arr); - - if list_arr.is_null(0) { - builder.append(false); - continue; - } - - let string_arr = $STRING_ARRAY(list_arr.values()); - - for v in string_arr.iter() { - if let Some(v) = v { - builder.values().append_value(v); - } else { - builder.values().append_null(); - } - } - builder.append(true); - } - sv => { - return _internal_err!( - "Inconsistent types in ScalarValue::iter_to_array. \ - Expected List, got {:?}", - sv - ) - } - } + fn build_list_array( + scalars: impl IntoIterator, + ) -> Result { + let arrays = scalars + .into_iter() + .map(|s| s.to_array()) + .collect::>>()?; + + let capacity = Capacities::Array(arrays.iter().map(|arr| arr.len()).sum()); + // ScalarValue::List contains a single element ListArray. + let nulls = arrays + .iter() + .map(|arr| arr.is_null(0)) + .collect::>(); + let arrays_data = arrays.iter().map(|arr| arr.to_data()).collect::>(); + + let arrays_ref = arrays_data.iter().collect::>(); + let mut mutable = + MutableArrayData::with_capacities(arrays_ref, true, capacity); + + // ScalarValue::List contains a single element ListArray. + for (index, is_null) in (0..arrays.len()).zip(nulls.into_iter()) { + if is_null { + mutable.extend_nulls(1) + } else { + mutable.extend(index, 0, 1); } - Arc::new(builder.finish()) - }}; + } + let data = mutable.freeze(); + Ok(arrow_array::make_array(data)) } let array: ArrayRef = match &data_type { @@ -1541,228 +1472,7 @@ impl ScalarValue { DataType::Interval(IntervalUnit::MonthDayNano) => { build_array_primitive!(IntervalMonthDayNanoArray, IntervalMonthDayNano) } - DataType::List(fields) if fields.data_type() == &DataType::Int8 => { - build_array_list_primitive!( - Int8Type, - Int8, - i8, - ListArray, - ScalarValue::List(_) - )? - } - DataType::List(fields) if fields.data_type() == &DataType::Int16 => { - build_array_list_primitive!( - Int16Type, - Int16, - i16, - ListArray, - ScalarValue::List(_) - )? - } - DataType::List(fields) if fields.data_type() == &DataType::Int32 => { - build_array_list_primitive!( - Int32Type, - Int32, - i32, - ListArray, - ScalarValue::List(_) - )? - } - DataType::List(fields) if fields.data_type() == &DataType::Int64 => { - build_array_list_primitive!( - Int64Type, - Int64, - i64, - ListArray, - ScalarValue::List(_) - )? - } - DataType::List(fields) if fields.data_type() == &DataType::UInt8 => { - build_array_list_primitive!( - UInt8Type, - UInt8, - u8, - ListArray, - ScalarValue::List(_) - )? - } - DataType::List(fields) if fields.data_type() == &DataType::UInt16 => { - build_array_list_primitive!( - UInt16Type, - UInt16, - u16, - ListArray, - ScalarValue::List(_) - )? - } - DataType::List(fields) if fields.data_type() == &DataType::UInt32 => { - build_array_list_primitive!( - UInt32Type, - UInt32, - u32, - ListArray, - ScalarValue::List(_) - )? - } - DataType::List(fields) if fields.data_type() == &DataType::UInt64 => { - build_array_list_primitive!( - UInt64Type, - UInt64, - u64, - ListArray, - ScalarValue::List(_) - )? - } - DataType::List(fields) if fields.data_type() == &DataType::Float32 => { - build_array_list_primitive!( - Float32Type, - Float32, - f32, - ListArray, - ScalarValue::List(_) - )? - } - DataType::List(fields) if fields.data_type() == &DataType::Float64 => { - build_array_list_primitive!( - Float64Type, - Float64, - f64, - ListArray, - ScalarValue::List(_) - )? - } - DataType::List(fields) if fields.data_type() == &DataType::Utf8 => { - build_array_list_string!( - StringBuilder, - as_string_array, - ListBuilder, - ScalarValue::List(_) - ) - } - DataType::List(fields) if fields.data_type() == &DataType::LargeUtf8 => { - build_array_list_string!( - LargeStringBuilder, - as_largestring_array, - ListBuilder, - ScalarValue::List(_) - ) - } - DataType::List(_) => { - // Fallback case handling homogeneous lists with any ScalarValue element type - let list_array = ScalarValue::iter_to_array_list(scalars)?; - Arc::new(list_array) - } - DataType::LargeList(fields) if fields.data_type() == &DataType::Int8 => { - build_array_list_primitive!( - Int8Type, - Int8, - i8, - LargeListArray, - ScalarValue::LargeList(_) - )? - } - DataType::LargeList(fields) if fields.data_type() == &DataType::Int16 => { - build_array_list_primitive!( - Int16Type, - Int16, - i16, - LargeListArray, - ScalarValue::LargeList(_) - )? - } - DataType::LargeList(fields) if fields.data_type() == &DataType::Int32 => { - build_array_list_primitive!( - Int32Type, - Int32, - i32, - LargeListArray, - ScalarValue::LargeList(_) - )? - } - DataType::LargeList(fields) if fields.data_type() == &DataType::Int64 => { - build_array_list_primitive!( - Int64Type, - Int64, - i64, - LargeListArray, - ScalarValue::LargeList(_) - )? - } - DataType::LargeList(fields) if fields.data_type() == &DataType::UInt8 => { - build_array_list_primitive!( - UInt8Type, - UInt8, - u8, - LargeListArray, - ScalarValue::LargeList(_) - )? - } - DataType::LargeList(fields) if fields.data_type() == &DataType::UInt16 => { - build_array_list_primitive!( - UInt16Type, - UInt16, - u16, - LargeListArray, - ScalarValue::LargeList(_) - )? - } - DataType::LargeList(fields) if fields.data_type() == &DataType::UInt32 => { - build_array_list_primitive!( - UInt32Type, - UInt32, - u32, - LargeListArray, - ScalarValue::LargeList(_) - )? - } - DataType::LargeList(fields) if fields.data_type() == &DataType::UInt64 => { - build_array_list_primitive!( - UInt64Type, - UInt64, - u64, - LargeListArray, - ScalarValue::LargeList(_) - )? - } - DataType::LargeList(fields) if fields.data_type() == &DataType::Float32 => { - build_array_list_primitive!( - Float32Type, - Float32, - f32, - LargeListArray, - ScalarValue::LargeList(_) - )? - } - DataType::LargeList(fields) if fields.data_type() == &DataType::Float64 => { - build_array_list_primitive!( - Float64Type, - Float64, - f64, - LargeListArray, - ScalarValue::LargeList(_) - )? - } - DataType::LargeList(fields) if fields.data_type() == &DataType::Utf8 => { - build_array_list_string!( - StringBuilder, - as_string_array, - LargeListBuilder, - ScalarValue::LargeList(_) - ) - } - DataType::LargeList(fields) if fields.data_type() == &DataType::LargeUtf8 => { - build_array_list_string!( - LargeStringBuilder, - as_largestring_array, - LargeListBuilder, - ScalarValue::LargeList(_) - ) - } - DataType::LargeList(_) => { - // Fallback case handling homogeneous lists with any ScalarValue element type - let list_array = ScalarValue::iter_to_large_array_list(scalars)?; - Arc::new(list_array) - } + DataType::List(_) | DataType::LargeList(_) => build_list_array(scalars)?, DataType::Struct(fields) => { // Initialize a Vector to store the ScalarValues for each column let mut columns: Vec> = @@ -1942,116 +1652,6 @@ impl ScalarValue { Ok(array) } - /// This function build ListArray with nulls with nulls buffer. - fn iter_to_array_list( - scalars: impl IntoIterator, - ) -> Result { - let mut elements: Vec = vec![]; - let mut valid = BooleanBufferBuilder::new(0); - let mut offsets = vec![]; - - for scalar in scalars { - if let ScalarValue::List(arr) = scalar { - // `ScalarValue::List` contains a single element `ListArray`. - let list_arr = as_list_array(&arr); - - if list_arr.is_null(0) { - // Repeat previous offset index - offsets.push(0); - - // Element is null - valid.append(false); - } else { - let arr = list_arr.values().to_owned(); - offsets.push(arr.len()); - elements.push(arr); - - // Element is valid - valid.append(true); - } - } else { - return _internal_err!( - "Expected ScalarValue::List element. Received {scalar:?}" - ); - } - } - - // Concatenate element arrays to create single flat array - let element_arrays: Vec<&dyn Array> = - elements.iter().map(|a| a.as_ref()).collect(); - - let flat_array = match arrow::compute::concat(&element_arrays) { - Ok(flat_array) => flat_array, - Err(err) => return Err(DataFusionError::ArrowError(err)), - }; - - let buffer = valid.finish(); - - let list_array = ListArray::new( - Arc::new(Field::new("item", flat_array.data_type().clone(), true)), - OffsetBuffer::from_lengths(offsets), - flat_array, - Some(NullBuffer::new(buffer)), - ); - - Ok(list_array) - } - - /// This function build LargeListArray with nulls with nulls buffer. - fn iter_to_large_array_list( - scalars: impl IntoIterator, - ) -> Result { - let mut elements: Vec = vec![]; - let mut valid = BooleanBufferBuilder::new(0); - let mut offsets = vec![]; - - for scalar in scalars { - if let ScalarValue::List(arr) = scalar { - // `ScalarValue::List` contains a single element `ListArray`. - let list_arr = as_list_array(&arr); - - if list_arr.is_null(0) { - // Repeat previous offset index - offsets.push(0); - - // Element is null - valid.append(false); - } else { - let arr = list_arr.values().to_owned(); - offsets.push(arr.len()); - elements.push(arr); - - // Element is valid - valid.append(true); - } - } else { - return _internal_err!( - "Expected ScalarValue::List element. Received {scalar:?}" - ); - } - } - - // Concatenate element arrays to create single flat array - let element_arrays: Vec<&dyn Array> = - elements.iter().map(|a| a.as_ref()).collect(); - - let flat_array = match arrow::compute::concat(&element_arrays) { - Ok(flat_array) => flat_array, - Err(err) => return Err(DataFusionError::ArrowError(err)), - }; - - let buffer = valid.finish(); - - let list_array = LargeListArray::new( - Arc::new(Field::new("item", flat_array.data_type().clone(), true)), - OffsetBuffer::from_lengths(offsets), - flat_array, - Some(NullBuffer::new(buffer)), - ); - - Ok(list_array) - } - fn build_decimal_array( value: Option, precision: u8, @@ -3520,21 +3120,23 @@ impl ScalarType for TimestampNanosecondType { #[cfg(test)] mod tests { + use super::*; + use std::cmp::Ordering; use std::sync::Arc; + use chrono::NaiveDate; + use rand::Rng; + + use arrow::buffer::OffsetBuffer; use arrow::compute::kernels; use arrow::compute::{concat, is_null}; use arrow::datatypes::ArrowPrimitiveType; use arrow::util::pretty::pretty_format_columns; use arrow_array::ArrowNumericType; - use chrono::NaiveDate; - use rand::Rng; use crate::cast::{as_string_array, as_uint32_array, as_uint64_array}; - use super::*; - #[test] fn test_to_array_of_size_for_list() { let arr = ListArray::from_iter_primitive::(vec![Some(vec![ @@ -3597,28 +3199,77 @@ mod tests { assert_eq!(result, &expected); } + fn build_list( + values: Vec>>>, + ) -> Vec { + values + .into_iter() + .map(|v| { + let arr = if v.is_some() { + Arc::new( + GenericListArray::::from_iter_primitive::( + vec![v], + ), + ) + } else if O::IS_LARGE { + new_null_array( + &DataType::LargeList(Arc::new(Field::new( + "item", + DataType::Int64, + true, + ))), + 1, + ) + } else { + new_null_array( + &DataType::List(Arc::new(Field::new( + "item", + DataType::Int64, + true, + ))), + 1, + ) + }; + + if O::IS_LARGE { + ScalarValue::LargeList(arr) + } else { + ScalarValue::List(arr) + } + }) + .collect() + } + #[test] fn iter_to_array_primitive_test() { - let scalars = vec![ - ScalarValue::List(Arc::new( - ListArray::from_iter_primitive::(vec![Some(vec![ - Some(1), - Some(2), - Some(3), - ])]), - )), - ScalarValue::List(Arc::new( - ListArray::from_iter_primitive::(vec![Some(vec![ - Some(4), - Some(5), - ])]), - )), - ]; + // List[[1,2,3]], List[null], List[[4,5]] + let scalars = build_list::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + None, + Some(vec![Some(4), Some(5)]), + ]); let array = ScalarValue::iter_to_array(scalars).unwrap(); let list_array = as_list_array(&array); + // List[[1,2,3], null, [4,5]] let expected = ListArray::from_iter_primitive::(vec![ Some(vec![Some(1), Some(2), Some(3)]), + None, + Some(vec![Some(4), Some(5)]), + ]); + assert_eq!(list_array, &expected); + + let scalars = build_list::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + None, + Some(vec![Some(4), Some(5)]), + ]); + + let array = ScalarValue::iter_to_array(scalars).unwrap(); + let list_array = as_large_list_array(&array); + let expected = LargeListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + None, Some(vec![Some(4), Some(5)]), ]); assert_eq!(list_array, &expected); @@ -5083,69 +4734,37 @@ mod tests { assert_eq!(array, &expected); } - #[test] - fn test_nested_lists() { - // Define inner list scalars - let a1 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(1), - Some(2), - Some(3), - ])]); - let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(4), - Some(5), - ])]); - let l1 = ListArray::new( - Arc::new(Field::new( - "item", - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), - true, - )), - OffsetBuffer::::from_lengths([1, 1]), - arrow::compute::concat(&[&a1, &a2]).unwrap(), - None, - ); - - let a1 = - ListArray::from_iter_primitive::(vec![Some(vec![Some(6)])]); - let a2 = ListArray::from_iter_primitive::(vec![Some(vec![ - Some(7), - Some(8), - ])]); - let l2 = ListArray::new( - Arc::new(Field::new( - "item", - DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), - true, - )), - OffsetBuffer::::from_lengths([1, 1]), - arrow::compute::concat(&[&a1, &a2]).unwrap(), - None, - ); - - let a1 = - ListArray::from_iter_primitive::(vec![Some(vec![Some(9)])]); - let l3 = ListArray::new( + fn build_2d_list(data: Vec>) -> ListArray { + let a1 = ListArray::from_iter_primitive::(vec![Some(data)]); + ListArray::new( Arc::new(Field::new( "item", DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), true, )), OffsetBuffer::::from_lengths([1]), - arrow::compute::concat(&[&a1]).unwrap(), + Arc::new(a1), None, - ); + ) + } + + #[test] + fn test_nested_lists() { + // Define inner list scalars + let arr1 = build_2d_list(vec![Some(1), Some(2), Some(3)]); + let arr2 = build_2d_list(vec![Some(4), Some(5)]); + let arr3 = build_2d_list(vec![Some(6)]); let array = ScalarValue::iter_to_array(vec![ - ScalarValue::List(Arc::new(l1)), - ScalarValue::List(Arc::new(l2)), - ScalarValue::List(Arc::new(l3)), + ScalarValue::List(Arc::new(arr1)), + ScalarValue::List(Arc::new(arr2)), + ScalarValue::List(Arc::new(arr3)), ]) .unwrap(); let array = as_list_array(&array); // Construct expected array with array builders - let inner_builder = Int32Array::builder(8); + let inner_builder = Int32Array::builder(6); let middle_builder = ListBuilder::new(inner_builder); let mut outer_builder = ListBuilder::new(middle_builder); @@ -5153,6 +4772,7 @@ mod tests { outer_builder.values().values().append_value(2); outer_builder.values().values().append_value(3); outer_builder.values().append(true); + outer_builder.append(true); outer_builder.values().values().append_value(4); outer_builder.values().values().append_value(5); @@ -5161,14 +4781,6 @@ mod tests { outer_builder.values().values().append_value(6); outer_builder.values().append(true); - - outer_builder.values().values().append_value(7); - outer_builder.values().values().append_value(8); - outer_builder.values().append(true); - outer_builder.append(true); - - outer_builder.values().values().append_value(9); - outer_builder.values().append(true); outer_builder.append(true); let expected = outer_builder.finish();