Skip to content

Commit

Permalink
address comment
Browse files Browse the repository at this point in the history
Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
  • Loading branch information
jayzhan211 committed Oct 31, 2023
1 parent 0d4dc36 commit 30282cb
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 19 deletions.
22 changes: 11 additions & 11 deletions datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use crate::cast::{
};
use crate::error::{DataFusionError, Result, _internal_err, _not_impl_err};
use crate::hash_utils::create_hashes;
use crate::utils::wrap_into_list_array;
use crate::utils::array_into_list_array;
use arrow::buffer::{NullBuffer, OffsetBuffer};
use arrow::compute::kernels::numeric::*;
use arrow::datatypes::{i256, FieldRef, Fields, SchemaBuilder};
Expand Down Expand Up @@ -1840,12 +1840,12 @@ impl ScalarValue {
let arr = Decimal128Array::from(vals)
.with_precision_and_scale(*precision, *scale)
.unwrap();
wrap_into_list_array(Arc::new(arr))
array_into_list_array(Arc::new(arr))
}

DataType::Null => {
let arr = new_null_array(&DataType::Null, values.len());
wrap_into_list_array(arr)
array_into_list_array(arr)
}
_ => panic!(
"Unsupported data type {:?} for ScalarValue::list_to_array",
Expand Down Expand Up @@ -2242,7 +2242,7 @@ impl ScalarValue {
let list_array = as_list_array(array);
let nested_array = list_array.value(index);
// Produces a single element `ListArray` with the value at `index`.
let arr = Arc::new(wrap_into_list_array(nested_array));
let arr = Arc::new(array_into_list_array(nested_array));

ScalarValue::List(arr)
}
Expand All @@ -2251,7 +2251,7 @@ impl ScalarValue {
let list_array = as_fixed_size_list_array(array)?;
let nested_array = list_array.value(index);
// Produces a single element `ListArray` with the value at `index`.
let arr = Arc::new(wrap_into_list_array(nested_array));
let arr = Arc::new(array_into_list_array(nested_array));

ScalarValue::List(arr)
}
Expand Down Expand Up @@ -3236,7 +3236,7 @@ mod tests {

let array = ScalarValue::new_list(scalars.as_slice(), &DataType::Utf8);

let expected = wrap_into_list_array(Arc::new(StringArray::from(vec![
let expected = array_into_list_array(Arc::new(StringArray::from(vec![
"rust",
"arrow",
"data-fusion",
Expand Down Expand Up @@ -3275,9 +3275,9 @@ mod tests {
#[test]
fn iter_to_array_string_test() {
let arr1 =
wrap_into_list_array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"])));
array_into_list_array(Arc::new(StringArray::from(vec!["foo", "bar", "baz"])));
let arr2 =
wrap_into_list_array(Arc::new(StringArray::from(vec!["rust", "world"])));
array_into_list_array(Arc::new(StringArray::from(vec!["rust", "world"])));

let scalars = vec![
ScalarValue::List(Arc::new(arr1)),
Expand Down Expand Up @@ -4519,13 +4519,13 @@ mod tests {
// Define list-of-structs scalars

let nl0_array = ScalarValue::iter_to_array(vec![s0.clone(), s1.clone()]).unwrap();
let nl0 = ScalarValue::List(Arc::new(wrap_into_list_array(nl0_array)));
let nl0 = ScalarValue::List(Arc::new(array_into_list_array(nl0_array)));

let nl1_array = ScalarValue::iter_to_array(vec![s2.clone()]).unwrap();
let nl1 = ScalarValue::List(Arc::new(wrap_into_list_array(nl1_array)));
let nl1 = ScalarValue::List(Arc::new(array_into_list_array(nl1_array)));

let nl2_array = ScalarValue::iter_to_array(vec![s1.clone()]).unwrap();
let nl2 = ScalarValue::List(Arc::new(wrap_into_list_array(nl2_array)));
let nl2 = ScalarValue::List(Arc::new(array_into_list_array(nl2_array)));

// iter_to_array for list-of-struct
let array = ScalarValue::iter_to_array(vec![nl0, nl1, nl2]).unwrap();
Expand Down
46 changes: 44 additions & 2 deletions datafusion/common/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@

//! This module provides the bisect function, which implements binary search.
use crate::error::_internal_err;
use crate::{DataFusionError, Result, ScalarValue};
use arrow::array::{ArrayRef, PrimitiveArray};
use arrow::buffer::OffsetBuffer;
use arrow::compute;
use arrow::compute::{partition, SortColumn, SortOptions};
use arrow::datatypes::{Field, SchemaRef, UInt32Type};
use arrow::record_batch::RecordBatch;
use arrow_array::ListArray;
use arrow_array::{Array, ListArray};
use sqlparser::ast::Ident;
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;
Expand Down Expand Up @@ -338,7 +339,7 @@ pub fn longest_consecutive_prefix<T: Borrow<usize>>(

/// Wrap an array into a single element `ListArray`.
/// For example `[1, 2, 3]` would be converted into `[[1, 2, 3]]`
pub fn wrap_into_list_array(arr: ArrayRef) -> ListArray {
pub fn array_into_list_array(arr: ArrayRef) -> ListArray {
let offsets = OffsetBuffer::from_lengths([arr.len()]);
ListArray::new(
Arc::new(Field::new("item", arr.data_type().to_owned(), true)),
Expand All @@ -348,6 +349,47 @@ pub fn wrap_into_list_array(arr: ArrayRef) -> ListArray {
)
}

/// Wrap arrays into a single element `ListArray`.
///
/// Example:
/// ```
/// use arrow::array::{Int32Array, ListArray, ArrayRef};
/// use arrow::datatypes::{Int32Type, Field};
/// use std::sync::Arc;
///
/// let arr1 = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef;
/// let arr2 = Arc::new(Int32Array::from(vec![4, 5, 6])) as ArrayRef;
///
/// let list_arr = datafusion_common::utils::arrays_into_list_array([arr1, arr2]).unwrap();
///
/// let expected = ListArray::from_iter_primitive::<Int32Type, _, _>(
/// vec![
/// Some(vec![Some(1), Some(2), Some(3)]),
/// Some(vec![Some(4), Some(5), Some(6)]),
/// ]
/// );
///
/// assert_eq!(list_arr, expected);
pub fn arrays_into_list_array(
arr: impl IntoIterator<Item = ArrayRef>,
) -> Result<ListArray> {
let arr = arr.into_iter().collect::<Vec<_>>();
if arr.is_empty() {
return _internal_err!("Cannot wrap empty array into list array");
}

let lens = arr.iter().map(|x| x.len()).collect::<Vec<_>>();
// Assume data type is consistent
let data_type = arr[0].data_type().to_owned();
let values = arr.iter().map(|x| x.as_ref()).collect::<Vec<_>>();
Ok(ListArray::new(
Arc::new(Field::new("item", data_type, true)),
OffsetBuffer::from_lengths(lens),
arrow::compute::concat(values.as_slice())?,
None,
))
}

/// An extension trait for smart pointers. Provides an interface to get a
/// raw pointer to the data (with metadata stripped away).
///
Expand Down
4 changes: 2 additions & 2 deletions datafusion/physical-expr/src/aggregate/array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use arrow::array::ArrayRef;
use arrow::datatypes::{DataType, Field};
use arrow_array::Array;
use datafusion_common::cast::as_list_array;
use datafusion_common::utils::wrap_into_list_array;
use datafusion_common::utils::array_into_list_array;
use datafusion_common::Result;
use datafusion_common::ScalarValue;
use datafusion_expr::Accumulator;
Expand Down Expand Up @@ -161,7 +161,7 @@ impl Accumulator for ArrayAggAccumulator {
}

let concated_array = arrow::compute::concat(&element_arrays)?;
let list_array = wrap_into_list_array(concated_array);
let list_array = array_into_list_array(concated_array);

Ok(ScalarValue::List(Arc::new(list_array)))
}
Expand Down
4 changes: 2 additions & 2 deletions datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ mod tests {
use arrow_array::types::Int32Type;
use arrow_array::{Array, ListArray};
use arrow_buffer::OffsetBuffer;
use datafusion_common::utils::wrap_into_list_array;
use datafusion_common::utils::array_into_list_array;
use datafusion_common::{internal_err, DataFusionError};

// arrow::compute::sort cann't sort ListArray directly, so we need to sort the inner primitive array and wrap it back into ListArray.
Expand All @@ -201,7 +201,7 @@ mod tests {
};

let arr = arrow::compute::sort(&arr, None).unwrap();
let list_arr = wrap_into_list_array(arr);
let list_arr = array_into_list_array(arr);
ScalarValue::List(Arc::new(list_arr))
}

Expand Down
4 changes: 2 additions & 2 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use arrow_buffer::NullBuffer;
use datafusion_common::cast::{
as_generic_string_array, as_int64_array, as_list_array, as_string_array,
};
use datafusion_common::utils::wrap_into_list_array;
use datafusion_common::utils::array_into_list_array;
use datafusion_common::{
exec_err, internal_err, not_impl_err, plan_err, DataFusionError, Result,
};
Expand Down Expand Up @@ -412,7 +412,7 @@ pub fn make_array(arrays: &[ArrayRef]) -> Result<ArrayRef> {
// Either an empty array or all nulls:
DataType::Null => {
let array = new_null_array(&DataType::Null, arrays.len());
Ok(Arc::new(wrap_into_list_array(array)))
Ok(Arc::new(array_into_list_array(array)))
}
data_type => array_array(arrays, data_type),
}
Expand Down

0 comments on commit 30282cb

Please sign in to comment.