From d1ee07c5f8ecb0ebf8e62d92c3742e6da10a8af3 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Sat, 17 Sep 2022 00:22:03 -0700 Subject: [PATCH 1/4] Add divide_dyn_opt kernel --- arrow/src/compute/kernels/arithmetic.rs | 78 +++++++++++++++++++++++++ arrow/src/compute/kernels/arity.rs | 49 ++++++---------- 2 files changed, 97 insertions(+), 30 deletions(-) diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index b44cb8b947e2..8024831f74ec 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -697,6 +697,39 @@ where ) } +#[cfg(feature = "dyn_arith_dict")] +fn math_safe_divide_op_dict( + left: &DictionaryArray, + right: &DictionaryArray, + op: F, +) -> Result +where + K: ArrowNumericType, + T: ArrowNumericType, + T::Native: One + Zero, + F: Fn(T::Native, T::Native) -> Option, +{ + let left = left.downcast_dict::>().unwrap(); + let right = right.downcast_dict::>().unwrap(); + let array: PrimitiveArray = binary_opt::<_, _, _, T>(left, right, op)?; + Ok(Arc::new(array) as ArrayRef) +} + +fn math_safe_divide_op( + left: &PrimitiveArray, + right: &PrimitiveArray, + op: F, +) -> Result +where + LT: ArrowNumericType, + RT: ArrowNumericType, + RT::Native: One + Zero, + F: Fn(LT::Native, RT::Native) -> Option, +{ + let array: PrimitiveArray = binary_opt::<_, _, _, LT>(left, right, op)?; + Ok(Arc::new(array) as ArrayRef) +} + /// Perform `left + right` operation on two arrays. If either left or right value is null /// then the result is also null. /// @@ -1406,6 +1439,51 @@ pub fn divide_dyn_checked(left: &dyn Array, right: &dyn Array) -> Result Result { + match left.data_type() { + DataType::Dictionary(_, _) => { + typed_dict_math_op!( + left, + right, + |a, b| { + if b.is_zero() { + None + } else { + Some(a.div_wrapping(b)) + } + }, + math_safe_divide_op_dict + ) + } + _ => { + downcast_primitive_array!( + (left, right) => { + math_safe_divide_op(left, right, |a, b| { + if b.is_zero() { + None + } else { + Some(a.div_wrapping(b)) + } + }) + } + _ => Err(ArrowError::CastError(format!( + "Unsupported data type {}, {}", + left.data_type(), right.data_type() + ))) + ) + } + } +} + /// Perform `left / right` operation on two arrays without checking for division by zero. /// For floating point types, the result of dividing by zero follows normal floating point /// rules. For other numeric types, dividing by zero will panic, diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs index bf10289683f1..ce02180f2515 100644 --- a/arrow/src/compute/kernels/arity.rs +++ b/arrow/src/compute/kernels/arity.rs @@ -367,16 +367,14 @@ where /// # Error /// /// This function gives error if the arrays have different lengths -pub(crate) fn binary_opt( - a: &PrimitiveArray, - b: &PrimitiveArray, +pub(crate) fn binary_opt( + a: A, + b: B, op: F, ) -> Result> where - A: ArrowPrimitiveType, - B: ArrowPrimitiveType, O: ArrowPrimitiveType, - F: Fn(A::Native, B::Native) -> Option, + F: Fn(A::Item, B::Item) -> Option, { if a.len() != b.len() { return Err(ArrowError::ComputeError( @@ -388,30 +386,21 @@ where return Ok(PrimitiveArray::from(ArrayData::new_empty(&O::DATA_TYPE))); } - if a.null_count() == 0 && b.null_count() == 0 { - Ok(a.values() - .iter() - .zip(b.values().iter()) - .map(|(a, b)| op(*a, *b)) - .collect()) - } else { - let iter_a = ArrayIter::new(a); - let iter_b = ArrayIter::new(b); - - let values = - iter_a - .into_iter() - .zip(iter_b.into_iter()) - .map(|(item_a, item_b)| { - if let (Some(a), Some(b)) = (item_a, item_b) { - op(a, b) - } else { - None - } - }); - - Ok(values.collect()) - } + let iter_a = ArrayIter::new(a); + let iter_b = ArrayIter::new(b); + + let values = iter_a + .into_iter() + .zip(iter_b.into_iter()) + .map(|(item_a, item_b)| { + if let (Some(a), Some(b)) = (item_a, item_b) { + op(a, b) + } else { + None + } + }); + + Ok(values.collect()) } #[cfg(test)] From 4637a48dafa945580369dc84528e3c7ab3ddabfe Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 20 Sep 2022 22:55:50 -0700 Subject: [PATCH 2/4] Add test --- arrow/src/compute/kernels/arithmetic.rs | 24 ++++++++++++++++++++++++ arrow/src/compute/kernels/arity.rs | 24 ++++++++++++++++++++++++ 2 files changed, 48 insertions(+) diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index 8024831f74ec..ca5e7a8279be 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -2830,4 +2830,28 @@ mod tests { let overflow = divide_dyn_checked(&a, &b); overflow.expect_err("overflow should be detected"); } + + #[test] + #[cfg(feature = "dyn_arith_dict")] + fn test_div_dyn_opt_overflow_division_by_zero() { + let a = Int32Array::from(vec![i32::MIN]); + let b = Int32Array::from(vec![0]); + + let division_by_zero = divide_dyn_opt(&a, &b); + let expected = Arc::new(Int32Array::from(vec![None])) as ArrayRef; + assert_eq!(&expected, &division_by_zero.unwrap()); + + let mut builder = + PrimitiveDictionaryBuilder::::with_capacity(1, 1); + builder.append(i32::MIN).unwrap(); + let a = builder.finish(); + + let mut builder = + PrimitiveDictionaryBuilder::::with_capacity(1, 1); + builder.append(0).unwrap(); + let b = builder.finish(); + + let division_by_zero = divide_dyn_opt(&a, &b); + assert_eq!(&expected, &division_by_zero.unwrap()); + } } diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs index ce02180f2515..751eedeb5b4b 100644 --- a/arrow/src/compute/kernels/arity.rs +++ b/arrow/src/compute/kernels/arity.rs @@ -357,6 +357,26 @@ where Ok(unsafe { build_primitive_array(len, buffer.into(), 0, None) }) } +#[inline(never)] +fn try_binary_opt_no_nulls( + len: usize, + a: A, + b: B, + op: F, +) -> Result> +where + O: ArrowPrimitiveType, + F: Fn(A::Item, B::Item) -> Option, +{ + let mut buffer = Vec::with_capacity(10); + for idx in 0..len { + unsafe { + buffer.push(op(a.value_unchecked(idx), b.value_unchecked(idx))); + }; + } + Ok(buffer.iter().collect()) +} + /// Applies the provided binary operation across `a` and `b`, collecting the optional results /// into a [`PrimitiveArray`]. If any index is null in either `a` or `b`, the corresponding /// index in the result will also be null. The binary operation could return `None` which @@ -386,6 +406,10 @@ where return Ok(PrimitiveArray::from(ArrayData::new_empty(&O::DATA_TYPE))); } + if a.null_count() == 0 && b.null_count() == 0 { + return Ok(try_binary_opt_no_nulls(a.len(), a, b, op)?); + } + let iter_a = ArrayIter::new(a); let iter_b = ArrayIter::new(b); From 4047fc1d79036fb5e68a78ddb12f0d3dc06d9783 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 21 Sep 2022 00:29:49 -0700 Subject: [PATCH 3/4] Fix clippy --- arrow/src/compute/kernels/arity.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arrow/src/compute/kernels/arity.rs b/arrow/src/compute/kernels/arity.rs index 751eedeb5b4b..5f875e6ddf29 100644 --- a/arrow/src/compute/kernels/arity.rs +++ b/arrow/src/compute/kernels/arity.rs @@ -407,7 +407,7 @@ where } if a.null_count() == 0 && b.null_count() == 0 { - return Ok(try_binary_opt_no_nulls(a.len(), a, b, op)?); + return try_binary_opt_no_nulls(a.len(), a, b, op); } let iter_a = ArrayIter::new(a); From 8bcc797d3019318e69786264c8f5bb470cf48dae Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Wed, 21 Sep 2022 18:32:32 -0700 Subject: [PATCH 4/4] Rename function --- arrow/src/compute/kernels/arithmetic.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/arrow/src/compute/kernels/arithmetic.rs b/arrow/src/compute/kernels/arithmetic.rs index ca5e7a8279be..d33827594af5 100644 --- a/arrow/src/compute/kernels/arithmetic.rs +++ b/arrow/src/compute/kernels/arithmetic.rs @@ -698,7 +698,7 @@ where } #[cfg(feature = "dyn_arith_dict")] -fn math_safe_divide_op_dict( +fn math_divide_safe_op_dict( left: &DictionaryArray, right: &DictionaryArray, op: F, @@ -1461,7 +1461,7 @@ pub fn divide_dyn_opt(left: &dyn Array, right: &dyn Array) -> Result { Some(a.div_wrapping(b)) } }, - math_safe_divide_op_dict + math_divide_safe_op_dict ) } _ => {