From 29ac9364c402a2dae4de4e3beabf479dacd3e752 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 9 Dec 2023 20:01:56 +0800 Subject: [PATCH] fix offset trait Signed-off-by: jayzhan211 --- datafusion/common/src/scalar.rs | 54 +++++++++---------- datafusion/core/tests/sql/aggregates.rs | 8 ++- .../src/aggregate/array_agg_distinct.rs | 2 +- .../src/aggregate/array_agg_ordered.rs | 4 +- .../src/aggregate/count_distinct.rs | 2 +- 5 files changed, 37 insertions(+), 33 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index b2589f46bb8fd..4c08538843cb2 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -2071,13 +2071,13 @@ impl ScalarValue { /// /// assert_eq!(scalar_vec, expected); /// ``` - pub fn convert_list_array_to_scalar_vec(array: &dyn Array) -> Result>> { - - - if as_list_array(array).is_ok() { - Self::convert_list_array_to_scalar_vec_internal(array) + pub fn convert_list_array_to_scalar_vec( + array: &dyn Array, + ) -> Result>> { + if array.as_list_opt::().is_some() { + Self::convert_list_array_to_scalar_vec_internal::(array) } else { - _internal_err!("Expected ListArray but found: {array:?}") + _internal_err!("Expected GenericListArray but found: {array:?}") } } @@ -2086,18 +2086,18 @@ impl ScalarValue { ) -> Result>> { let mut scalars_vec = Vec::with_capacity(array.len()); - let list_arr = as_generic_list_array::(array); - - if let Ok(list_arr) = as_list_array(array) { + if let Some(list_arr) = array.as_list_opt::() { for index in 0..list_arr.len() { let scalars = match list_arr.is_null(index) { true => Vec::new(), false => { let nested_array = list_arr.value(index); - Self::convert_list_array_to_scalar_vec_internal(&nested_array)? - .into_iter() - .flatten() - .collect() + Self::convert_list_array_to_scalar_vec_internal::( + &nested_array, + )? + .into_iter() + .flatten() + .collect() } }; scalars_vec.push(scalars); @@ -2106,6 +2106,7 @@ impl ScalarValue { let scalars = ScalarValue::convert_non_list_array_to_scalars(array)?; scalars_vec.push(scalars); } + Ok(scalars_vec) } @@ -2134,16 +2135,16 @@ impl ScalarValue { /// assert_eq!(scalar_vec, expected); /// ``` pub fn convert_non_list_array_to_scalars(array: &dyn Array) -> Result> { - if as_list_array(array).is_ok() { - _internal_err!("Expected non-ListArray but found: {array:?}") - } else { - let mut scalars = Vec::with_capacity(array.len()); - for index in 0..array.len() { - let scalar = ScalarValue::try_from_array(array, index)?; - scalars.push(scalar); - } - Ok(scalars) + if array.as_list_opt::().is_some() || array.as_list_opt::().is_some() { + return _internal_err!("Expected non ListArray but found: {array:?}"); } + + let mut scalars = Vec::with_capacity(array.len()); + for index in 0..array.len() { + let scalar = ScalarValue::try_from_array(array, index)?; + scalars.push(scalar); + } + Ok(scalars) } // TODO: Support more types after other ScalarValue is wrapped with ArrayRef @@ -2194,7 +2195,7 @@ impl ScalarValue { typed_cast!(array, index, LargeStringArray, LargeUtf8)? } DataType::List(_) => { - let list_array = as_list_array(array)?; + let list_array = as_list_array(array); let nested_array = list_array.value(index); // Produces a single element `ListArray` with the value at `index`. let arr = Arc::new(array_into_list_array(nested_array)); @@ -3163,7 +3164,6 @@ impl ScalarType for TimestampNanosecondType { } #[cfg(test)] -#[cfg(feature = "parquet")] mod tests { use super::*; @@ -3202,7 +3202,7 @@ mod tests { let l12 = arrays_into_list_array([l1, l2]).unwrap(); let arr = Arc::new(l12) as ArrayRef; - let actual = ScalarValue::convert_list_array_to_scalar_vec(&arr).unwrap(); + let actual = ScalarValue::convert_list_array_to_scalar_vec::(&arr).unwrap(); let expected = vec![ vec![ ScalarValue::Int32(Some(1)), @@ -3232,7 +3232,7 @@ mod tests { let actual_arr = sv .to_array_of_size(2) .expect("Failed to convert to array of size"); - let actual_list_arr = as_list_array(&actual_arr).unwrap(); + let actual_list_arr = as_list_array(&actual_arr); let arr = ListArray::from_iter_primitive::(vec![ Some(vec![Some(1), None, Some(2)]), @@ -3272,7 +3272,7 @@ mod tests { ]; let array = ScalarValue::new_list(scalars.as_slice(), &DataType::Utf8); - let result = as_list_array(&array).unwrap(); + let result = as_list_array(&array); let expected = array_into_list_array(Arc::new(StringArray::from(vec![ "rust", diff --git a/datafusion/core/tests/sql/aggregates.rs b/datafusion/core/tests/sql/aggregates.rs index 5773792488456..68c58df41cb6c 100644 --- a/datafusion/core/tests/sql/aggregates.rs +++ b/datafusion/core/tests/sql/aggregates.rs @@ -45,13 +45,17 @@ async fn csv_query_array_agg_distinct() -> Result<()> { let column = actual[0].column(0); assert_eq!(column.len(), 1); - let scalar_vec = ScalarValue::convert_list_array_to_scalar_vec(&column)?; - let mut scalars = scalar_vec[0].clone(); + // 1 row + let scalar_vec = ScalarValue::convert_list_array_to_scalar_vec::(&column)?; + // workaround lack of Ord of ScalarValue let cmp = |a: &ScalarValue, b: &ScalarValue| { a.partial_cmp(b).expect("Can compare ScalarValues") }; + + let mut scalars = scalar_vec.first().unwrap().to_owned(); scalars.sort_by(cmp); + assert_eq!( scalars, vec![ diff --git a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs index f905e63728a04..f6eaee58f50e0 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_distinct.rs @@ -139,7 +139,7 @@ impl Accumulator for DistinctArrayAggAccumulator { let array = &values[0]; match array.data_type() { DataType::List(_) => { - let scalar_vec = ScalarValue::convert_list_array_to_scalar_vec(array)?; + let scalar_vec = ScalarValue::convert_list_array_to_scalar_vec::(array)?; for scalars in scalar_vec { self.values.extend(scalars); } diff --git a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs index 81287ba5ec568..729220b2fff0d 100644 --- a/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs +++ b/datafusion/physical-expr/src/aggregate/array_agg_ordered.rs @@ -225,13 +225,13 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator { partition_ordering_values.push(self.ordering_values.clone()); let array_agg_res = - ScalarValue::convert_list_array_to_scalar_vec(array_agg_values)?; + ScalarValue::convert_list_array_to_scalar_vec::(array_agg_values)?; for v in array_agg_res.into_iter() { partition_values.push(v); } - let orderings = ScalarValue::convert_list_array_to_scalar_vec(agg_orderings)?; + let orderings = ScalarValue::convert_list_array_to_scalar_vec::(agg_orderings)?; // Ordering requirement expression values for each entry in the ARRAY_AGG list let other_ordering_values = self.convert_array_agg_to_orderings(orderings)?; for v in other_ordering_values.into_iter() { diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index 1f38a8875aa30..332e1b690b49e 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -167,7 +167,7 @@ impl Accumulator for DistinctCountAccumulator { return Ok(()); } assert_eq!(states.len(), 1, "array_agg states must be singleton!"); - let scalar_vec = ScalarValue::convert_list_array_to_scalar_vec(&states[0])?; + let scalar_vec = ScalarValue::convert_list_array_to_scalar_vec::(&states[0])?; for scalars in scalar_vec.into_iter() { self.values.extend(scalars) }