Skip to content

Commit

Permalink
Use concat to simplify Nested Scalar creation (#9174)
Browse files Browse the repository at this point in the history
* replace with concat

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* rewrite

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* remove map_err

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

---------

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
  • Loading branch information
jayzhan211 authored Feb 10, 2024
1 parent ae88235 commit f97a208
Showing 1 changed file with 65 additions and 120 deletions.
185 changes: 65 additions & 120 deletions datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ use arrow::{
};
use arrow_array::cast::as_list_array;
use arrow_array::{ArrowNativeTypeOp, Scalar};
use arrow_buffer::{Buffer, NullBuffer};
use arrow_buffer::NullBuffer;

/// A dynamically typed, nullable single value, (the single-valued counter-part
/// to arrow's [`Array`])
Expand Down Expand Up @@ -1402,121 +1402,6 @@ impl ScalarValue {
}};
}

/// Builds a single [`StructArray`] out of a sequence of struct scalars.
///
/// Every scalar is first converted to an array via `to_array`; only row 0 of
/// each resulting struct array is read (a `ScalarValue::Struct` contains a
/// single element in each column). The per-column values are gathered and
/// turned into child arrays with `ScalarValue::iter_to_array`, and the
/// validity of row 0 of each input becomes the validity bitmap of the result.
///
/// # Errors
///
/// Returns an internal error when
/// * `scalars` is empty,
/// * any converted array is not a struct array, or
/// * a struct has a different number of columns than the first one.
///
/// Errors from `to_array`, `try_from_array` and `iter_to_array` are
/// propagated unchanged.
fn build_struct_array(
    scalars: impl IntoIterator<Item = ScalarValue>,
) -> Result<ArrayRef> {
    let arrays = scalars
        .into_iter()
        .map(|s| s.to_array())
        .collect::<Result<Vec<_>>>()?;

    // The first element defines the schema (fields and column count); an
    // empty input is reported as an error instead of panicking on `arrays[0]`.
    let first = match arrays.first() {
        Some(first) => first,
        None => {
            return _internal_err!(
                "Empty iterator passed to ScalarValue::iter_to_array"
            );
        }
    };
    let first_struct = match first.as_struct_opt() {
        Some(struct_array) => struct_array,
        None => {
            return _internal_err!(
                "Inconsistent types in ScalarValue::iter_to_array. \
                 Expected ScalarValue::Struct, got {:?}",
                first
            );
        }
    };

    let num_columns = first_struct.num_columns();
    let mut valid = BooleanBufferBuilder::new(arrays.len());

    // Allocate each column separately: cloning an empty `Vec` (as `vec![..; n]`
    // does) would not carry the reserved capacity over to the copies.
    let mut column_values: Vec<Vec<ScalarValue>> = (0..num_columns)
        .map(|_| Vec::with_capacity(arrays.len()))
        .collect();

    for arr in &arrays {
        let struct_array = match arr.as_struct_opt() {
            Some(struct_array) => struct_array,
            None => {
                return _internal_err!(
                    "Inconsistent types in ScalarValue::iter_to_array. \
                     Expected ScalarValue::Struct, got {arr:?}"
                );
            }
        };

        // Guard against ragged input: without this check extra columns would
        // index out of bounds and missing columns would leave child arrays
        // with differing lengths.
        if struct_array.num_columns() != num_columns {
            return _internal_err!(
                "Inconsistent types in ScalarValue::iter_to_array. \
                 Expected a struct with {num_columns} columns, got {arr:?}"
            );
        }

        valid.append(struct_array.is_valid(0));

        for (values, column) in column_values.iter_mut().zip(struct_array.columns()) {
            // ScalarValue::Struct contains a single element in each column.
            values.push(ScalarValue::try_from_array(column, 0)?);
        }
    }

    // Pair every field of the first struct with the array built from the
    // values gathered for that column.
    let data = first_struct
        .fields()
        .iter()
        .zip(column_values)
        .map(|(field, values)| {
            ScalarValue::iter_to_array(values).map(|array| (Arc::clone(field), array))
        })
        .collect::<Result<Vec<_>>>()?;

    let bool_buffer = valid.finish();
    let buffer: Buffer = bool_buffer.values().into();
    Ok(Arc::new(StructArray::from((data, buffer))))
}

/// Assembles one list array from a sequence of list scalars.
///
/// Each scalar is converted with `to_array`; row 0 of every converted array
/// decides whether the corresponding output slot is null. Only the non-null
/// arrays are handed to `MutableArrayData` as sources, so null slots are
/// written with `extend_nulls` and non-null slots copy row 0 of the next
/// source in order.
fn build_list_array(
    scalars: impl IntoIterator<Item = ScalarValue>,
) -> Result<ArrayRef> {
    let mut list_arrays = Vec::new();
    for scalar in scalars {
        list_arrays.push(scalar.to_array()?);
    }

    // ScalarValue::List contains a single element ListArray, so row 0 tells
    // whether the whole scalar is null.
    let null_mask: Vec<bool> = list_arrays.iter().map(|arr| arr.is_null(0)).collect();

    // Gather the data of the non-null arrays together with their total
    // length, which is used as the capacity hint below.
    let mut non_null_len: usize = 0;
    let mut non_null_data = Vec::new();
    for (arr, &is_null) in list_arrays.iter().zip(&null_mask) {
        if !is_null {
            non_null_len += arr.len();
            non_null_data.push(arr.to_data());
        }
    }

    let sources = non_null_data.iter().collect::<Vec<_>>();
    let mut builder = MutableArrayData::with_capacities(
        sources,
        true,
        Capacities::Array(non_null_len),
    );

    // `next_source` indexes into the non-null sources only, hence it advances
    // solely when a non-null slot is emitted.
    let mut next_source = 0;
    for is_null in null_mask {
        if is_null {
            builder.extend_nulls(1);
            continue;
        }
        builder.extend(next_source, 0, 1);
        next_source += 1;
    }

    Ok(arrow_array::make_array(builder.freeze()))
}

let array: ArrayRef = match &data_type {
DataType::Decimal128(precision, scale) => {
let decimal_array =
Expand Down Expand Up @@ -1591,10 +1476,32 @@ impl ScalarValue {
DataType::Interval(IntervalUnit::MonthDayNano) => {
build_array_primitive!(IntervalMonthDayNanoArray, IntervalMonthDayNano)
}
DataType::Struct(_) => build_struct_array(scalars)?,
DataType::List(_)
| DataType::LargeList(_)
| DataType::FixedSizeList(_, _) => build_list_array(scalars)?,
DataType::FixedSizeList(_, _) => {
// arrow::compute::concat does not allow inconsistent types including the size of FixedSizeList.
// The length of nulls here we got is 1, so we need to resize the length of nulls to
// the length of non-nulls.
let mut arrays =
scalars.map(|s| s.to_array()).collect::<Result<Vec<_>>>()?;
let first_non_null_data_type = arrays
.iter()
.find(|sv| !sv.is_null(0))
.map(|sv| sv.data_type().to_owned());
if let Some(DataType::FixedSizeList(f, l)) = first_non_null_data_type {
for array in arrays.iter_mut() {
if array.is_null(0) {
*array =
Arc::new(FixedSizeListArray::new_null(f.clone(), l, 1));
}
}
}
let arrays = arrays.iter().map(|a| a.as_ref()).collect::<Vec<_>>();
arrow::compute::concat(arrays.as_slice())?
}
DataType::List(_) | DataType::LargeList(_) | DataType::Struct(_) => {
let arrays = scalars.map(|s| s.to_array()).collect::<Result<Vec<_>>>()?;
let arrays = arrays.iter().map(|a| a.as_ref()).collect::<Vec<_>>();
arrow::compute::concat(arrays.as_slice())?
}
DataType::Dictionary(key_type, value_type) => {
// create the values array
let value_scalars = scalars
Expand Down Expand Up @@ -3529,6 +3436,44 @@ mod tests {
.collect()
}

/// `iter_to_array` over FixedSizeList scalars where the first and last
/// entries are null and the middle two hold three values each.
#[test]
fn test_iter_to_array_fixed_size_list() {
    let item_field = Arc::new(Field::new("item", DataType::Int32, true));

    // Wraps three values into a one-row FixedSizeList scalar of list size 3.
    let list_of_three = |values: Vec<i32>| {
        ScalarValue::FixedSizeList(Arc::new(FixedSizeListArray::new(
            item_field.clone(),
            3,
            Arc::new(Int32Array::from(values)),
            None,
        )))
    };
    let first = list_of_three(vec![1, 2, 3]);
    let second = list_of_three(vec![4, 5, 6]);

    // The null entries are created with list size 1 while the non-null ones
    // use 3; the expected array below has list size 3.
    let null_list = Arc::new(FixedSizeListArray::new_null(item_field, 1, 1));
    let leading_null = ScalarValue::FixedSizeList(null_list.clone());
    let trailing_null = ScalarValue::FixedSizeList(null_list);

    let array =
        ScalarValue::iter_to_array(vec![leading_null, first, second, trailing_null])
            .unwrap();

    let expected_rows = vec![
        None,
        Some(vec![Some(1), Some(2), Some(3)]),
        Some(vec![Some(4), Some(5), Some(6)]),
        None,
    ];
    let expected =
        FixedSizeListArray::from_iter_primitive::<Int32Type, _, _>(expected_rows, 3);
    assert_eq!(array.as_ref(), &expected);
}

#[test]
fn test_iter_to_array_struct() {
let s1 = StructArray::from(vec![
Expand Down

0 comments on commit f97a208

Please sign in to comment.