Skip to content

Commit

Permalink
Make arithmetic kernels supports dictionary of decimal array (#3255)
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya authored Dec 2, 2022
1 parent 2da6aab commit 9abdb55
Showing 1 changed file with 78 additions and 2 deletions.
80 changes: 78 additions & 2 deletions arrow/src/compute/kernels/arithmetic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ use crate::datatypes::{
};
#[cfg(feature = "dyn_arith_dict")]
use crate::datatypes::{
Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type,
UInt32Type, UInt64Type, UInt8Type,
Decimal128Type, Decimal256Type, Float32Type, Float64Type, Int16Type, Int32Type,
Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
};
use crate::error::{ArrowError, Result};
use crate::{datatypes, downcast_primitive_array};
Expand Down Expand Up @@ -461,6 +461,14 @@ macro_rules! typed_dict_op {
let array = $MATH_OP::<$KT, Float64Type, _>($LEFT, $RIGHT, $OP)?;
Ok(Arc::new(array))
}
(DataType::Decimal128(_, s1), DataType::Decimal128(_, s2)) if s1 == s2 => {
let array = $MATH_OP::<$KT, Decimal128Type, _>($LEFT, $RIGHT, $OP)?;
Ok(Arc::new(array))
}
(DataType::Decimal256(_, s1), DataType::Decimal256(_, s2)) if s1 == s2 => {
let array = $MATH_OP::<$KT, Decimal256Type, _>($LEFT, $RIGHT, $OP)?;
Ok(Arc::new(array))
}
(t1, t2) => Err(ArrowError::CastError(format!(
"Cannot perform arithmetic operation on two dictionary arrays of different value types ({} and {})",
t1, t2
Expand Down Expand Up @@ -3150,4 +3158,72 @@ mod tests {
let overflow = try_unary_mut(a, |value| value.add_checked(1));
let _ = overflow.unwrap().expect_err("overflow should be detected");
}

#[test]
#[cfg(feature = "dyn_arith_dict")]
fn test_dict_decimal() {
let values = Decimal128Array::from_iter_values([0, 1, 2, 3, 4, 5]);
let keys = Int8Array::from_iter_values([1_i8, 2, 5, 4, 3, 0]);
let array1 = DictionaryArray::try_new(&keys, &values).unwrap();

let values = Decimal128Array::from_iter_values([7, -3, 4, 3, 5]);
let keys = Int8Array::from_iter_values([0_i8, 0, 1, 2, 3, 4]);
let array2 = DictionaryArray::try_new(&keys, &values).unwrap();

let result = add_dyn(&array1, &array2).unwrap();
let expected =
Arc::new(Decimal128Array::from(vec![8, 9, 2, 8, 6, 5])) as ArrayRef;
assert_eq!(&result, &expected);

let result = subtract_dyn(&array1, &array2).unwrap();
let expected =
Arc::new(Decimal128Array::from(vec![-6, -5, 8, 0, 0, -5])) as ArrayRef;
assert_eq!(&result, &expected);

let values = Decimal256Array::from_iter_values([
i256::from_i128(0),
i256::from_i128(1),
i256::from_i128(2),
i256::from_i128(3),
i256::from_i128(4),
i256::from_i128(5),
]);
let keys =
Int8Array::from(vec![Some(1_i8), None, Some(5), Some(4), Some(3), None]);
let array1 = DictionaryArray::try_new(&keys, &values).unwrap();

let values = Decimal256Array::from_iter_values([
i256::from_i128(7),
i256::from_i128(-3),
i256::from_i128(4),
i256::from_i128(3),
i256::from_i128(5),
]);
let keys =
Int8Array::from(vec![Some(0_i8), Some(0), None, Some(2), Some(3), Some(4)]);
let array2 = DictionaryArray::try_new(&keys, &values).unwrap();

let result = add_dyn(&array1, &array2).unwrap();
let expected = Arc::new(Decimal256Array::from(vec![
Some(i256::from_i128(8)),
None,
None,
Some(i256::from_i128(8)),
Some(i256::from_i128(6)),
None,
])) as ArrayRef;

assert_eq!(&result, &expected);

let result = subtract_dyn(&array1, &array2).unwrap();
let expected = Arc::new(Decimal256Array::from(vec![
Some(i256::from_i128(-6)),
None,
None,
Some(i256::from_i128(0)),
Some(i256::from_i128(0)),
None,
])) as ArrayRef;
assert_eq!(&result, &expected);
}
}

0 comments on commit 9abdb55

Please sign in to comment.