From 7a2851123fd2009a63c71ba70edf18a278a08fe5 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 24 Aug 2022 14:53:53 -0700 Subject: [PATCH 1/5] Add max_dyn and min_dyn --- arrow/src/compute/kernels/aggregate.rs | 95 ++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) diff --git a/arrow/src/compute/kernels/aggregate.rs b/arrow/src/compute/kernels/aggregate.rs index c8d0443c4706..5358bd566faa 100644 --- a/arrow/src/compute/kernels/aggregate.rs +++ b/arrow/src/compute/kernels/aggregate.rs @@ -215,6 +215,70 @@ where } } +/// Returns the min of values in the array. +pub fn min_dyn>(array: A) -> Option +where + T: ArrowNumericType, + T::Native: Add, +{ + match array.data_type() { + DataType::Dictionary(_, _) => { + let null_count = array.null_count(); + + if null_count == array.len() { + return None; + } + + let mut has_value = false; + let mut n = T::default_value(); + let iter = ArrayIter::new(array); + iter.into_iter().for_each(|value| { + if let Some(value) = value { + if !has_value || value < n { + has_value = true; + n = value; + } + } + }); + + Some(n) + } + _ => min::(as_primitive_array(&array)), + } +} + +/// Returns the max of values in the array. +pub fn max_dyn>(array: A) -> Option +where + T: ArrowNumericType, + T::Native: Add, +{ + match array.data_type() { + DataType::Dictionary(_, _) => { + let null_count = array.null_count(); + + if null_count == array.len() { + return None; + } + + let mut has_value = false; + let mut n = T::default_value(); + let iter = ArrayIter::new(array); + iter.into_iter().for_each(|value| { + if let Some(value) = value { + if !has_value || value > n { + has_value = true; + n = value; + } + } + }); + + Some(n) + } + _ => max::(as_primitive_array(&array)), + } +} + /// Returns the sum of values in the primitive array. /// /// Returns `None` if the array is empty or only contains null values. @@ -1058,4 +1122,35 @@ mod tests { let array = dict_array.downcast_dict::().unwrap(); assert!(sum_dyn::(array).is_none()); } + + #[test] + fn test_max_min_dyn() { + let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]); + let keys = Int8Array::from_iter_values([2_i8, 3, 4]); + + let dict_array = DictionaryArray::try_new(&keys, &values).unwrap(); + let array = dict_array.downcast_dict::().unwrap(); + assert_eq!(14, max_dyn::(array).unwrap()); + + let array = dict_array.downcast_dict::().unwrap(); + assert_eq!(12, min_dyn::(array).unwrap()); + + let a = Int32Array::from(vec![1, 2, 3, 4, 5]); + assert_eq!(5, max_dyn::(&a).unwrap()); + assert_eq!(1, min_dyn::(&a).unwrap()); + + let keys = Int8Array::from(vec![Some(2_i8), None, Some(7)]); + let dict_array = DictionaryArray::try_new(&keys, &values).unwrap(); + let array = dict_array.downcast_dict::().unwrap(); + assert_eq!(17, max_dyn::(array).unwrap()); + let array = dict_array.downcast_dict::().unwrap(); + assert_eq!(12, min_dyn::(array).unwrap()); + + let keys = Int8Array::from(vec![None, None, None]); + let dict_array = DictionaryArray::try_new(&keys, &values).unwrap(); + let array = dict_array.downcast_dict::().unwrap(); + assert!(max_dyn::(array).is_none()); + let array = dict_array.downcast_dict::().unwrap(); + assert!(min_dyn::(array).is_none()); + } } From 755b5ec0c7944e639a83444a9fb9f9862f4215af Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 24 Aug 2022 18:23:37 -0700 Subject: [PATCH 2/5] Add a helper function --- arrow/src/compute/kernels/aggregate.rs | 46 ++++++++++---------------- 1 file changed, 18 insertions(+), 28 deletions(-) diff --git a/arrow/src/compute/kernels/aggregate.rs b/arrow/src/compute/kernels/aggregate.rs index 5358bd566faa..5f4af6f33436 100644 --- a/arrow/src/compute/kernels/aggregate.rs +++ b/arrow/src/compute/kernels/aggregate.rs @@ -219,39 +219,29 @@ where pub fn min_dyn>(array: A) -> Option where T: ArrowNumericType, - T::Native: Add, + T::Native: ArrowNativeType, { - match array.data_type() { - DataType::Dictionary(_, _) => { - let null_count = array.null_count(); - - if null_count == array.len() { - return None; - } - - let mut has_value = false; - let mut n = T::default_value(); - let iter = ArrayIter::new(array); - iter.into_iter().for_each(|value| { - if let Some(value) = value { - if !has_value || value < n { - has_value = true; - n = value; - } - } - }); - - Some(n) - } - _ => min::(as_primitive_array(&array)), - } + min_max_dyn_helper::(array, |a, b| a < b, min) } /// Returns the max of values in the array. pub fn max_dyn>(array: A) -> Option where T: ArrowNumericType, - T::Native: Add, + T::Native: ArrowNativeType, +{ + min_max_dyn_helper::(array, |a, b| a > b, max) +} + +fn min_max_dyn_helper, F, M>( + array: A, + cmp: F, + m: M, +) -> Option +where + T: ArrowNumericType, + F: Fn(&T::Native, &T::Native) -> bool, + M: Fn(&PrimitiveArray) -> Option, { match array.data_type() { DataType::Dictionary(_, _) => { @@ -266,7 +256,7 @@ where let iter = ArrayIter::new(array); iter.into_iter().for_each(|value| { if let Some(value) = value { - if !has_value || value > n { + if !has_value || cmp(&value, &n) { has_value = true; n = value; } @@ -275,7 +265,7 @@ where Some(n) } - _ => max::(as_primitive_array(&array)), + _ => m(as_primitive_array(&array)), } } From eb95d21f93716a45d51d4e9fd81f36149e46ad41 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 24 Aug 2022 18:34:02 -0700 Subject: [PATCH 3/5] Add NaN handling and test --- arrow/src/compute/kernels/aggregate.rs | 27 +++++++++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/arrow/src/compute/kernels/aggregate.rs b/arrow/src/compute/kernels/aggregate.rs index 5f4af6f33436..fc7914c92bc9 100644 --- a/arrow/src/compute/kernels/aggregate.rs +++ b/arrow/src/compute/kernels/aggregate.rs @@ -221,7 +221,11 @@ where T: ArrowNumericType, T::Native: ArrowNativeType, { - min_max_dyn_helper::(array, |a, b| a < b, min) + min_max_dyn_helper::( + array, + |a, b| (!is_nan(*a) & is_nan(*b)) || a < b, + min, + ) } /// Returns the max of values in the array. @@ -230,7 +234,11 @@ where T: ArrowNumericType, T::Native: ArrowNativeType, { - min_max_dyn_helper::(array, |a, b| a > b, max) + min_max_dyn_helper::( + array, + |a, b| (is_nan(*a) & !is_nan(*b)) || a > b, + max, + ) } fn min_max_dyn_helper, F, M>( @@ -710,7 +718,7 @@ mod tests { use super::*; use crate::array::*; use crate::compute::add; - use crate::datatypes::{Int32Type, Int8Type}; + use crate::datatypes::{Float32Type, Int32Type, Int8Type}; #[test] fn test_primitive_array_sum() { @@ -1143,4 +1151,17 @@ mod tests { let array = dict_array.downcast_dict::().unwrap(); assert!(min_dyn::(array).is_none()); } + + #[test] + fn test_max_min_dyn_nan() { + let values = Float32Array::from(vec![5.0_f32, 2.0_f32, f32::NAN]); + let keys = Int8Array::from_iter_values([0_i8, 1, 2]); + + let dict_array = DictionaryArray::try_new(&keys, &values).unwrap(); + let array = dict_array.downcast_dict::().unwrap(); + assert!(max_dyn::(array).unwrap().is_nan()); + + let array = dict_array.downcast_dict::().unwrap(); + assert_eq!(2.0_f32, min_dyn::(array).unwrap()); + } } From 675965f74d2b9e69ab67168c3d1c8ba43298f774 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 25 Aug 2022 00:20:25 -0700 Subject: [PATCH 4/5] Rename to min_array, max_array and sum_array --- arrow/src/compute/kernels/aggregate.rs | 34 +++++++++++++------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/arrow/src/compute/kernels/aggregate.rs b/arrow/src/compute/kernels/aggregate.rs index fc7914c92bc9..b930405b6fb7 100644 --- a/arrow/src/compute/kernels/aggregate.rs +++ b/arrow/src/compute/kernels/aggregate.rs @@ -185,7 +185,7 @@ pub fn min_string(array: &GenericStringArray) -> Option<& } /// Returns the sum of values in the array. -pub fn sum_dyn>(array: A) -> Option +pub fn sum_array>(array: A) -> Option where T: ArrowNumericType, T::Native: Add, @@ -216,7 +216,7 @@ where } /// Returns the min of values in the array. -pub fn min_dyn>(array: A) -> Option +pub fn min_array>(array: A) -> Option where T: ArrowNumericType, T::Native: ArrowNativeType, @@ -229,7 +229,7 @@ where } /// Returns the max of values in the array. -pub fn max_dyn>(array: A) -> Option +pub fn max_array>(array: A) -> Option where T: ArrowNumericType, T::Native: ArrowNativeType, @@ -1105,20 +1105,20 @@ mod tests { let dict_array = DictionaryArray::try_new(&keys, &values).unwrap(); let array = dict_array.downcast_dict::().unwrap(); - assert_eq!(39, sum_dyn::(array).unwrap()); + assert_eq!(39, sum_array::(array).unwrap()); let a = Int32Array::from(vec![1, 2, 3, 4, 5]); - assert_eq!(15, sum_dyn::(&a).unwrap()); + assert_eq!(15, sum_array::(&a).unwrap()); let keys = Int8Array::from(vec![Some(2_i8), None, Some(4)]); let dict_array = DictionaryArray::try_new(&keys, &values).unwrap(); let array = dict_array.downcast_dict::().unwrap(); - assert_eq!(26, sum_dyn::(array).unwrap()); + assert_eq!(26, sum_array::(array).unwrap()); let keys = Int8Array::from(vec![None, None, None]); let dict_array = DictionaryArray::try_new(&keys, &values).unwrap(); let array = dict_array.downcast_dict::().unwrap(); - assert!(sum_dyn::(array).is_none()); + assert!(sum_array::(array).is_none()); } #[test] @@ -1128,28 +1128,28 @@ mod tests { let dict_array = DictionaryArray::try_new(&keys, &values).unwrap(); let array = dict_array.downcast_dict::().unwrap(); - assert_eq!(14, max_dyn::(array).unwrap()); + assert_eq!(14, max_array::(array).unwrap()); let array = dict_array.downcast_dict::().unwrap(); - assert_eq!(12, min_dyn::(array).unwrap()); + assert_eq!(12, min_array::(array).unwrap()); let a = Int32Array::from(vec![1, 2, 3, 4, 5]); - assert_eq!(5, max_dyn::(&a).unwrap()); - assert_eq!(1, min_dyn::(&a).unwrap()); + assert_eq!(5, max_array::(&a).unwrap()); + assert_eq!(1, min_array::(&a).unwrap()); let keys = Int8Array::from(vec![Some(2_i8), None, Some(7)]); let dict_array = DictionaryArray::try_new(&keys, &values).unwrap(); let array = dict_array.downcast_dict::().unwrap(); - assert_eq!(17, max_dyn::(array).unwrap()); + assert_eq!(17, max_array::(array).unwrap()); let array = dict_array.downcast_dict::().unwrap(); - assert_eq!(12, min_dyn::(array).unwrap()); + assert_eq!(12, min_array::(array).unwrap()); let keys = Int8Array::from(vec![None, None, None]); let dict_array = DictionaryArray::try_new(&keys, &values).unwrap(); let array = dict_array.downcast_dict::().unwrap(); - assert!(max_dyn::(array).is_none()); + assert!(max_array::(array).is_none()); let array = dict_array.downcast_dict::().unwrap(); - assert!(min_dyn::(array).is_none()); + assert!(min_array::(array).is_none()); } #[test] @@ -1159,9 +1159,9 @@ mod tests { let dict_array = DictionaryArray::try_new(&keys, &values).unwrap(); let array = dict_array.downcast_dict::().unwrap(); - assert!(max_dyn::(array).unwrap().is_nan()); + assert!(max_array::(array).unwrap().is_nan()); let array = dict_array.downcast_dict::().unwrap(); - assert_eq!(2.0_f32, min_dyn::(array).unwrap()); + assert_eq!(2.0_f32, min_array::(array).unwrap()); } } From c886a077d5ff2e5d2379db07ecc68e885a6718d4 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 25 Aug 2022 00:40:54 -0700 Subject: [PATCH 5/5] Rename min_max_dyn_helper --- arrow/src/compute/kernels/aggregate.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/arrow/src/compute/kernels/aggregate.rs b/arrow/src/compute/kernels/aggregate.rs index b930405b6fb7..fb2f55582d65 100644 --- a/arrow/src/compute/kernels/aggregate.rs +++ b/arrow/src/compute/kernels/aggregate.rs @@ -221,7 +221,7 @@ where T: ArrowNumericType, T::Native: ArrowNativeType, { - min_max_dyn_helper::( + min_max_array_helper::( array, |a, b| (!is_nan(*a) & is_nan(*b)) || a < b, min, @@ -234,14 +234,14 @@ where T: ArrowNumericType, T::Native: ArrowNativeType, { - min_max_dyn_helper::( + min_max_array_helper::( array, |a, b| (is_nan(*a) & !is_nan(*b)) || a > b, max, ) } -fn min_max_dyn_helper, F, M>( +fn min_max_array_helper, F, M>( array: A, cmp: F, m: M,