diff --git a/vortex-array/src/arrays/chunked/compute/sum.rs b/vortex-array/src/arrays/chunked/compute/sum.rs index a849e7e5fb3..d2872f40894 100644 --- a/vortex-array/src/arrays/chunked/compute/sum.rs +++ b/vortex-array/src/arrays/chunked/compute/sum.rs @@ -2,9 +2,10 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use num_traits::PrimInt; -use vortex_dtype::{NativePType, PType, match_each_native_ptype}; -use vortex_error::{VortexExpect, VortexResult, vortex_err}; -use vortex_scalar::{FromPrimitiveOrF16, Scalar}; +use vortex_dtype::Nullability::Nullable; +use vortex_dtype::{DType, DecimalDType, NativePType, match_each_native_ptype}; +use vortex_error::{VortexResult, vortex_bail, vortex_err}; +use vortex_scalar::{DecimalScalar, DecimalValue, FromPrimitiveOrF16, Scalar, i256}; use crate::arrays::{ChunkedArray, ChunkedVTable}; use crate::compute::{SumKernel, SumKernelAdapter, sum}; @@ -16,16 +17,23 @@ impl SumKernel for ChunkedVTable { let sum_dtype = Stat::Sum .dtype(array.dtype()) .ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?; - let sum_ptype = PType::try_from(&sum_dtype).vortex_expect("sum dtype must be primitive"); - let scalar_value = match_each_native_ptype!( - sum_ptype, - unsigned: |T| { sum_int::(array.chunks())?.into() }, - signed: |T| { sum_int::(array.chunks())?.into() }, - floating: |T| { sum_float(array.chunks())?.into() } - ); + match sum_dtype { + DType::Decimal(decimal_dtype, _) => sum_decimal(array.chunks(), decimal_dtype), + DType::Primitive(sum_ptype, _) => { + let scalar_value = match_each_native_ptype!( + sum_ptype, + unsigned: |T| { sum_int::(array.chunks())?.into() }, + signed: |T| { sum_int::(array.chunks())?.into() }, + floating: |T| { sum_float(array.chunks())?.into() } + ); - Ok(Scalar::new(sum_dtype, scalar_value)) + Ok(Scalar::new(sum_dtype, scalar_value)) + } + _ => { + vortex_bail!("Sum not supported for dtype {}", sum_dtype); + } + } } } @@ -39,7 +47,7 @@ fn sum_int( let chunk_sum = 
sum(chunk)?; let Some(chunk_sum) = chunk_sum.as_primitive().as_::() else { - // Bail out on overflow + // Bail out on missing statistic return Ok(None); }; @@ -63,14 +71,46 @@ fn sum_float(chunks: &[ArrayRef]) -> VortexResult { Ok(result) } +fn sum_decimal(chunks: &[ArrayRef], result_decimal_type: DecimalDType) -> VortexResult { + let mut result = DecimalValue::I256(i256::ZERO); + + let null = || Scalar::null(DType::Decimal(result_decimal_type, Nullable)); + + for chunk in chunks { + let chunk_sum = sum(chunk)?; + + let chunk_decimal = DecimalScalar::try_from(&chunk_sum)?; + let Some(chunk_value) = chunk_decimal.decimal_value() else { + // skips all null chunks + continue; + }; + + // Perform checked addition with current result + let Some(r) = result.checked_add(&chunk_value).filter(|sum_value| { + sum_value + .fits_in_precision(result_decimal_type) + .unwrap_or(false) + }) else { + // Overflow + return Ok(null()); + }; + + result = r; + } + + Ok(Scalar::decimal(result, result_decimal_type, Nullable)) +} + #[cfg(test)] mod tests { - use vortex_dtype::Nullability; - use vortex_scalar::Scalar; + use vortex_buffer::buffer; + use vortex_dtype::{DType, DecimalDType, Nullability}; + use vortex_scalar::{DecimalValue, Scalar, i256}; use crate::array::IntoArray; - use crate::arrays::{ChunkedArray, ConstantArray, PrimitiveArray}; + use crate::arrays::{ChunkedArray, ConstantArray, DecimalArray, PrimitiveArray}; use crate::compute::sum; + use crate::validity::Validity; #[test] fn test_sum_chunked_floats_with_nulls() { @@ -138,4 +178,117 @@ mod tests { let result = sum(chunked.as_ref()).unwrap(); assert_eq!(result.as_primitive().as_::(), Some(36.0)); } + + #[test] + fn test_sum_chunked_decimals() { + // Create decimal chunks with precision=10, scale=2 + let decimal_dtype = DecimalDType::new(10, 2); + let chunk1 = DecimalArray::new( + buffer![100i32, 100i32, 100i32, 100i32, 100i32], + decimal_dtype, + Validity::AllValid, + ); + let chunk2 = DecimalArray::new( + buffer![200i32, 
200i32, 200i32], + decimal_dtype, + Validity::AllValid, + ); + let chunk3 = DecimalArray::new(buffer![300i32, 300i32], decimal_dtype, Validity::AllValid); + + let dtype = chunk1.dtype().clone(); + let chunked = ChunkedArray::try_new( + vec![ + chunk1.into_array(), + chunk2.into_array(), + chunk3.into_array(), + ], + dtype, + ) + .unwrap(); + + // Compute sum: 5*100 + 3*200 + 2*300 = 500 + 600 + 600 = 1700 (represents 17.00) + let result = sum(chunked.as_ref()).unwrap(); + let decimal_result = result.as_decimal(); + assert_eq!( + decimal_result.decimal_value(), + Some(DecimalValue::I256(i256::from_i128(1700))) + ); + } + + #[test] + fn test_sum_chunked_decimals_with_nulls() { + let decimal_dtype = DecimalDType::new(10, 2); + + // Create chunks with some nulls - all must have same nullability + let chunk1 = DecimalArray::new( + buffer![100i32, 100i32, 100i32], + decimal_dtype, + Validity::AllValid, + ); + let chunk2 = DecimalArray::new( + buffer![0i32, 0i32], + decimal_dtype, + Validity::from_iter([false, false]), + ); + let chunk3 = DecimalArray::new(buffer![200i32, 200i32], decimal_dtype, Validity::AllValid); + + let dtype = chunk1.dtype().clone(); + let chunked = ChunkedArray::try_new( + vec![ + chunk1.into_array(), + chunk2.into_array(), + chunk3.into_array(), + ], + dtype, + ) + .unwrap(); + + // Compute sum: 3*100 + 2*200 = 300 + 400 = 700 (nulls ignored) + let result = sum(chunked.as_ref()).unwrap(); + let decimal_result = result.as_decimal(); + assert_eq!( + decimal_result.decimal_value(), + Some(DecimalValue::I256(i256::from_i128(700))) + ); + } + + #[test] + fn test_sum_chunked_decimals_large() { + // Create decimals with precision 3 (max value 999) + // Sum will be 500 + 600 = 1100, which fits in result precision 13 (3+10) + let decimal_dtype = DecimalDType::new(3, 0); + let chunk1 = ConstantArray::new( + Scalar::decimal( + DecimalValue::I16(500), + decimal_dtype, + Nullability::NonNullable, + ), + 1, + ); + let chunk2 = ConstantArray::new( + 
Scalar::decimal( + DecimalValue::I16(600), + decimal_dtype, + Nullability::NonNullable, + ), + 1, + ); + + let dtype = chunk1.dtype().clone(); + let chunked = + ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype).unwrap(); + + // Compute sum: 500 + 600 = 1100 + // Result should have precision 13 (3+10), scale 0 + let result = sum(chunked.as_ref()).unwrap(); + let decimal_result = result.as_decimal(); + assert_eq!( + decimal_result.decimal_value(), + Some(DecimalValue::I256(i256::from_i128(1100))) + ); + assert_eq!( + result.dtype(), + &DType::Decimal(DecimalDType::new(13, 0), Nullability::Nullable) + ); + } } diff --git a/vortex-array/src/arrays/constant/compute/sum.rs b/vortex-array/src/arrays/constant/compute/sum.rs index c99718dd738..9bb5f9f4496 100644 --- a/vortex-array/src/arrays/constant/compute/sum.rs +++ b/vortex-array/src/arrays/constant/compute/sum.rs @@ -2,9 +2,11 @@ // SPDX-FileCopyrightText: Copyright the Vortex contributors use num_traits::{CheckedMul, ToPrimitive}; -use vortex_dtype::{DType, NativePType, match_each_native_ptype}; -use vortex_error::{VortexResult, vortex_bail, vortex_err}; -use vortex_scalar::{FromPrimitiveOrF16, PrimitiveScalar, Scalar, ScalarValue}; +use vortex_dtype::{DType, DecimalDType, NativePType, Nullability, match_each_native_ptype}; +use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err}; +use vortex_scalar::{ + DecimalScalar, DecimalValue, FromPrimitiveOrF16, PrimitiveScalar, Scalar, ScalarValue, i256, +}; use crate::arrays::{ConstantArray, ConstantVTable}; use crate::compute::{SumKernel, SumKernelAdapter}; @@ -36,11 +38,47 @@ fn sum_scalar(scalar: &Scalar, len: usize) -> VortexResult { signed: |T| { sum_integral::(scalar.as_primitive(), len)?.into() }, floating: |T| { sum_float(scalar.as_primitive(), len)?.into() } )), + DType::Decimal(decimal_dtype, _) => sum_decimal(scalar.as_decimal(), len, *decimal_dtype), DType::Extension(_) => sum_scalar(&scalar.as_extension().storage(), 
len), dtype => vortex_bail!("Unsupported dtype for sum: {}", dtype), } } +fn sum_decimal( + decimal_scalar: DecimalScalar, + array_len: usize, + decimal_dtype: DecimalDType, +) -> VortexResult { + let result_dtype = Stat::Sum + .dtype(&DType::Decimal(decimal_dtype, Nullability::Nullable)) + .vortex_expect("decimal supports sum"); + let result_decimal_type = result_dtype + .as_decimal_opt() + .vortex_expect("must be decimal"); + + let Some(value) = decimal_scalar.decimal_value() else { + // Null value: return null + return Ok(ScalarValue::null()); + }; + + // Convert array_len to DecimalValue for multiplication + let len_value = DecimalValue::I256(i256::from_i128(array_len as i128)); + + // Multiply value * len + let sum = value.checked_mul(&len_value).and_then(|result| { + // Check if result fits in the precision + result + .fits_in_precision(*result_decimal_type) + .unwrap_or(false) + .then_some(result) + }); + + match sum { + Some(result_value) => Ok(ScalarValue::from(result_value)), + None => Ok(ScalarValue::null()), // Overflow + } +} + fn sum_integral( primitive_scalar: PrimitiveScalar<'_>, array_len: usize, @@ -70,12 +108,13 @@ register_kernel!(SumKernelAdapter(ConstantVTable).lift()); #[cfg(test)] mod tests { - use vortex_dtype::{DType, Nullability, PType}; - use vortex_scalar::Scalar; + use vortex_dtype::{DType, DecimalDType, Nullability, PType}; + use vortex_scalar::{DecimalValue, Scalar}; - use crate::IntoArray; use crate::arrays::ConstantArray; use crate::compute::sum; + use crate::stats::Stat; + use crate::{Array, IntoArray}; #[test] fn test_sum_unsigned() { @@ -123,4 +162,61 @@ mod tests { let result = sum(&array).unwrap(); assert!(result.is_null()); } + + #[test] + fn test_sum_decimal() { + let decimal_dtype = DecimalDType::new(10, 2); + let array = ConstantArray::new( + Scalar::decimal( + DecimalValue::I64(100), + decimal_dtype, + Nullability::NonNullable, + ), + 5, + ) + .into_array(); + + let result = sum(&array).unwrap(); + + assert_eq!( + 
result.as_decimal().decimal_value(), + Some(DecimalValue::I256(vortex_scalar::i256::from_i128(500))) + ); + assert_eq!(result.dtype(), &Stat::Sum.dtype(array.dtype()).unwrap()); + } + + #[test] + fn test_sum_decimal_null() { + let decimal_dtype = DecimalDType::new(10, 2); + let array = ConstantArray::new( + Scalar::null(DType::Decimal(decimal_dtype, Nullability::Nullable)), + 10, + ) + .into_array(); + + let result = sum(&array).unwrap(); + assert!(result.is_null()); + } + + #[test] + fn test_sum_decimal_large_value() { + let decimal_dtype = DecimalDType::new(10, 2); + let array = ConstantArray::new( + Scalar::decimal( + DecimalValue::I64(999_999_999), + decimal_dtype, + Nullability::NonNullable, + ), + 100, + ) + .into_array(); + + let result = sum(&array).unwrap(); + assert_eq!( + result.as_decimal().decimal_value(), + Some(DecimalValue::I256(vortex_scalar::i256::from_i128( + 99_999_999_900 + ))) + ); + } } diff --git a/vortex-array/src/arrays/decimal/compute/sum.rs b/vortex-array/src/arrays/decimal/compute/sum.rs index b9efbf2fd9b..adb81da461b 100644 --- a/vortex-array/src/arrays/decimal/compute/sum.rs +++ b/vortex-array/src/arrays/decimal/compute/sum.rs @@ -4,6 +4,7 @@ use arrow_schema::DECIMAL256_MAX_PRECISION; use num_traits::AsPrimitive; use vortex_dtype::DecimalDType; +use vortex_dtype::Nullability::Nullable; use vortex_error::{VortexResult, vortex_bail}; use vortex_mask::Mask; use vortex_scalar::{DecimalValue, Scalar, match_each_decimal_value_type}; @@ -40,7 +41,6 @@ impl SumKernel for DecimalVTable { #[allow(clippy::cognitive_complexity)] fn sum(&self, array: &DecimalArray) -> VortexResult { let decimal_dtype = array.decimal_dtype(); - let nullability = array.dtype().nullability(); // Both Spark and DataFusion use this heuristic. 
// - https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 @@ -60,7 +60,7 @@ impl SumKernel for DecimalVTable { Ok(Scalar::decimal( DecimalValue::from(sum_decimal!(O, array.buffer::())), return_dtype, - nullability, + Nullable, )) }) }) @@ -76,7 +76,7 @@ impl SumKernel for DecimalVTable { mask_values.boolean_buffer() )), return_dtype, - nullability, + Nullable, )) }) }) diff --git a/vortex-array/src/stats/mod.rs b/vortex-array/src/stats/mod.rs index 9093ed91166..204f22a52d0 100644 --- a/vortex-array/src/stats/mod.rs +++ b/vortex-array/src/stats/mod.rs @@ -205,7 +205,7 @@ impl Stat { } }, DType::Extension(ext_dtype) => self.dtype(ext_dtype.storage_dtype())?, - DType::Decimal(decimal_dtype, nullability) => { + DType::Decimal(decimal_dtype, _) => { // Both Spark and DataFusion use this heuristic. // - https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 // - https://github.com/apache/datafusion/blob/4153adf2c0f6e317ef476febfdc834208bd46622/datafusion/functions-aggregate/src/sum.rs#L188 @@ -213,7 +213,7 @@ impl Stat { u8::min(DECIMAL256_MAX_PRECISION, decimal_dtype.precision() + 10); DType::Decimal( DecimalDType::new(precision, decimal_dtype.scale()), - *nullability, + Nullable, ) } // Unsupported types diff --git a/vortex-scalar/src/bigint/mod.rs b/vortex-scalar/src/bigint/mod.rs index 3b566c0342d..cccffad0272 100644 --- a/vortex-scalar/src/bigint/mod.rs +++ b/vortex-scalar/src/bigint/mod.rs @@ -4,10 +4,12 @@ mod bigcast; use std::fmt::Display; -use std::ops::{Add, AddAssign, BitOr, Div, Mul, Rem, Shl, Shr, Sub}; +use std::ops::{Add, AddAssign, BitOr, Div, Mul, Neg, Rem, Shl, Shr, Sub}; pub use bigcast::*; -use num_traits::{CheckedAdd, CheckedSub, ConstZero, One, WrappingAdd, WrappingSub, Zero}; +use num_traits::{ + CheckedAdd, 
CheckedDiv, CheckedMul, CheckedSub, ConstZero, One, WrappingAdd, WrappingSub, Zero, +}; use vortex_error::VortexExpect; /// Signed 256-bit integer type. @@ -69,6 +71,11 @@ impl i256 { Self(self.0.wrapping_pow(exp)) } + /// Raises self to the power of `exp`, returning `None` on overflow. + pub fn checked_pow(&self, exp: u32) -> Option { + self.0.checked_pow(exp).map(Self) + } + /// Wrapping (modular) addition. Computes `self + other`, wrapping around at the boundary. pub fn wrapping_add(&self, other: Self) -> Self { Self(self.0.wrapping_add(other.0)) @@ -121,6 +128,14 @@ impl Sub for i256 { } } +impl Neg for i256 { + type Output = Self; + + fn neg(self) -> Self::Output { + Self(self.0.neg()) + } +} + impl Mul for i256 { type Output = Self; @@ -189,6 +204,18 @@ impl WrappingSub for i256 { } } +impl CheckedMul for i256 { + fn checked_mul(&self, v: &Self) -> Option { + self.0.checked_mul(v.0).map(Self) + } +} + +impl CheckedDiv for i256 { + fn checked_div(&self, v: &Self) -> Option { + self.0.checked_div(v.0).map(Self) + } +} + impl Shr for i256 { type Output = Self; diff --git a/vortex-scalar/src/decimal/scalar.rs b/vortex-scalar/src/decimal/scalar.rs index 38a93561b46..6a890449468 100644 --- a/vortex-scalar/src/decimal/scalar.rs +++ b/vortex-scalar/src/decimal/scalar.rs @@ -6,9 +6,11 @@ use std::fmt; use num_traits::ToPrimitive as NumToPrimitive; use vortex_dtype::{DType, DecimalDType, PType}; -use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err}; +use vortex_error::{VortexError, VortexResult, vortex_bail, vortex_err, vortex_panic}; -use crate::{DecimalValue, InnerScalarValue, Scalar, ScalarValue, match_each_decimal_value}; +use crate::{ + DecimalValue, InnerScalarValue, NumericOperator, Scalar, ScalarValue, match_each_decimal_value, +}; /// A scalar value representing a decimal number with fixed precision and scale. 
#[derive(Debug, Clone, Copy, Hash)] @@ -164,6 +166,68 @@ impl<'a> DecimalScalar<'a> { ), } } + + /// Apply the (checked) operator to self and other using SQL-style null semantics. + /// + /// If the operation overflows, divides by zero, or the result does not fit the precision, None is returned. + /// + /// If the decimal types (precision/scale) of the two operands differ, this method panics. + /// + /// If either value is null, the result is null. + /// + /// The result will have the same decimal type (precision/scale) as `self`, and the result + /// is checked to ensure it fits within the precision constraints. + pub fn checked_binary_numeric( + &self, + other: &DecimalScalar<'a>, + op: NumericOperator, + ) -> Option> { + // We could have ops between different types but need to add rules for type inference. + if self.decimal_type != other.decimal_type { + vortex_panic!( + "decimal types must match: {} vs {}", + self.decimal_type, + other.decimal_type + ); + } + + // Use the more nullable dtype as the result type + let result_dtype = if self.dtype.is_nullable() { + self.dtype + } else { + other.dtype + }; + + // Handle null cases using SQL semantics + let result_value = match (self.value, other.value) { + (None, _) | (_, None) => None, + (Some(lhs), Some(rhs)) => { + // Perform the operation + let operation_result = match op { + NumericOperator::Add => lhs.checked_add(&rhs), + NumericOperator::Sub => lhs.checked_sub(&rhs), + NumericOperator::RSub => rhs.checked_sub(&lhs), + NumericOperator::Mul => lhs.checked_mul(&rhs), + NumericOperator::Div => lhs.checked_div(&rhs), + NumericOperator::RDiv => rhs.checked_div(&lhs), + }?; + + // Check if the result fits within the precision constraints + if operation_result.fits_in_precision(self.decimal_type)? 
{ + Some(operation_result) + } else { + // Result exceeds precision, return None (overflow) + return None; + } + } + }; + + Some(DecimalScalar { + dtype: result_dtype, + decimal_type: self.decimal_type, + value: result_value, + }) + } } impl<'a> TryFrom<&'a Scalar> for DecimalScalar<'a> { diff --git a/vortex-scalar/src/decimal/tests.rs b/vortex-scalar/src/decimal/tests.rs index 68ed5b0641a..ced0b5817b1 100644 --- a/vortex-scalar/src/decimal/tests.rs +++ b/vortex-scalar/src/decimal/tests.rs @@ -783,3 +783,218 @@ fn test_decimal_i256_overflow_cast() { let result = decimal.cast(&DType::Primitive(PType::I64, Nullability::NonNullable)); assert!(result.is_err()); } + +// Tests for checked_binary_numeric +#[test] +fn test_decimal_scalar_checked_add() { + use crate::NumericOperator; + + let decimal1 = Scalar::decimal( + DecimalValue::I64(100), + DecimalDType::new(10, 2), + Nullability::NonNullable, + ); + let scalar1 = DecimalScalar::try_from(&decimal1).unwrap(); + + let decimal2 = Scalar::decimal( + DecimalValue::I64(200), + DecimalDType::new(10, 2), + Nullability::NonNullable, + ); + let scalar2 = DecimalScalar::try_from(&decimal2).unwrap(); + + let result = scalar1 + .checked_binary_numeric(&scalar2, NumericOperator::Add) + .unwrap(); + assert_eq!( + result.decimal_value(), + Some(DecimalValue::I256(i256::from_i128(300))) + ); +} + +#[test] +fn test_decimal_scalar_checked_sub() { + use crate::NumericOperator; + + let decimal1 = Scalar::decimal( + DecimalValue::I64(500), + DecimalDType::new(10, 2), + Nullability::NonNullable, + ); + let scalar1 = DecimalScalar::try_from(&decimal1).unwrap(); + + let decimal2 = Scalar::decimal( + DecimalValue::I64(200), + DecimalDType::new(10, 2), + Nullability::NonNullable, + ); + let scalar2 = DecimalScalar::try_from(&decimal2).unwrap(); + + let result = scalar1 + .checked_binary_numeric(&scalar2, NumericOperator::Sub) + .unwrap(); + assert_eq!( + result.decimal_value(), + Some(DecimalValue::I256(i256::from_i128(300))) + ); +} + +#[test] 
+fn test_decimal_scalar_checked_mul() { + use crate::NumericOperator; + + let decimal1 = Scalar::decimal( + DecimalValue::I32(50), + DecimalDType::new(10, 2), + Nullability::NonNullable, + ); + let scalar1 = DecimalScalar::try_from(&decimal1).unwrap(); + + let decimal2 = Scalar::decimal( + DecimalValue::I32(10), + DecimalDType::new(10, 2), + Nullability::NonNullable, + ); + let scalar2 = DecimalScalar::try_from(&decimal2).unwrap(); + + let result = scalar1 + .checked_binary_numeric(&scalar2, NumericOperator::Mul) + .unwrap(); + assert_eq!( + result.decimal_value(), + Some(DecimalValue::I256(i256::from_i128(500))) + ); +} + +#[test] +fn test_decimal_scalar_checked_div() { + use crate::NumericOperator; + + let decimal1 = Scalar::decimal( + DecimalValue::I64(1000), + DecimalDType::new(10, 2), + Nullability::NonNullable, + ); + let scalar1 = DecimalScalar::try_from(&decimal1).unwrap(); + + let decimal2 = Scalar::decimal( + DecimalValue::I64(10), + DecimalDType::new(10, 2), + Nullability::NonNullable, + ); + let scalar2 = DecimalScalar::try_from(&decimal2).unwrap(); + + let result = scalar1 + .checked_binary_numeric(&scalar2, NumericOperator::Div) + .unwrap(); + assert_eq!( + result.decimal_value(), + Some(DecimalValue::I256(i256::from_i128(100))) + ); +} + +#[test] +fn test_decimal_scalar_checked_div_by_zero() { + use crate::NumericOperator; + + let decimal1 = Scalar::decimal( + DecimalValue::I64(1000), + DecimalDType::new(10, 2), + Nullability::NonNullable, + ); + let scalar1 = DecimalScalar::try_from(&decimal1).unwrap(); + + let decimal2 = Scalar::decimal( + DecimalValue::I64(0), + DecimalDType::new(10, 2), + Nullability::NonNullable, + ); + let scalar2 = DecimalScalar::try_from(&decimal2).unwrap(); + + let result = scalar1.checked_binary_numeric(&scalar2, NumericOperator::Div); + assert_eq!(result, None); +} + +#[test] +fn test_decimal_scalar_null_handling() { + use crate::NumericOperator; + + let decimal1 = Scalar::null(DType::Decimal( + DecimalDType::new(10, 2), + 
Nullability::Nullable, + )); + let scalar1 = DecimalScalar::try_from(&decimal1).unwrap(); + + let decimal2 = Scalar::decimal( + DecimalValue::I64(200), + DecimalDType::new(10, 2), + Nullability::NonNullable, + ); + let scalar2 = DecimalScalar::try_from(&decimal2).unwrap(); + + let result = scalar1 + .checked_binary_numeric(&scalar2, NumericOperator::Add) + .unwrap(); + assert_eq!(result.decimal_value(), None); +} + +#[test] +fn test_decimal_scalar_precision_overflow() { + use crate::NumericOperator; + + // Create decimals with precision 3 (max value 999) + let decimal1 = Scalar::decimal( + DecimalValue::I16(999), + DecimalDType::new(3, 0), + Nullability::NonNullable, + ); + let scalar1 = DecimalScalar::try_from(&decimal1).unwrap(); + + let decimal2 = Scalar::decimal( + DecimalValue::I16(2), + DecimalDType::new(3, 0), + Nullability::NonNullable, + ); + let scalar2 = DecimalScalar::try_from(&decimal2).unwrap(); + + // 999 + 2 = 1001 which exceeds precision 3 + let result = scalar1.checked_binary_numeric(&scalar2, NumericOperator::Add); + assert_eq!(result, None); +} + +#[test] +fn test_decimal_scalar_rsub_and_rdiv() { + use crate::NumericOperator; + + let decimal1 = Scalar::decimal( + DecimalValue::I64(100), + DecimalDType::new(10, 2), + Nullability::NonNullable, + ); + let scalar1 = DecimalScalar::try_from(&decimal1).unwrap(); + + let decimal2 = Scalar::decimal( + DecimalValue::I64(300), + DecimalDType::new(10, 2), + Nullability::NonNullable, + ); + let scalar2 = DecimalScalar::try_from(&decimal2).unwrap(); + + // RSub: 300 - 100 = 200 + let result = scalar1 + .checked_binary_numeric(&scalar2, NumericOperator::RSub) + .unwrap(); + assert_eq!( + result.decimal_value(), + Some(DecimalValue::I256(i256::from_i128(200))) + ); + + // RDiv: 300 / 100 = 3 + let result = scalar1 + .checked_binary_numeric(&scalar2, NumericOperator::RDiv) + .unwrap(); + assert_eq!( + result.decimal_value(), + Some(DecimalValue::I256(i256::from_i128(3))) + ); +} diff --git 
a/vortex-scalar/src/decimal/value.rs b/vortex-scalar/src/decimal/value.rs index 3528bb5ac9d..d4a62913858 100644 --- a/vortex-scalar/src/decimal/value.rs +++ b/vortex-scalar/src/decimal/value.rs @@ -7,6 +7,7 @@ use std::cmp::Ordering; use std::fmt; use std::hash::Hash; +use num_traits::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub}; use vortex_dtype::{DType, DecimalDType, Nullability}; use vortex_error::{VortexError, VortexExpect, vortex_err}; @@ -77,6 +78,71 @@ impl DecimalValue { pub fn cast(&self) -> Option { match_each_decimal_value!(self, |value| { T::from(*value) }) } + + /// Check if this decimal value fits within the precision constraints of the given decimal type. + /// + /// The precision defines the total number of significant digits that can be represented. + /// The stored value (regardless of scale) must fit within the range defined by precision. + /// For precision P, the maximum absolute stored value is 10^P - 1. + /// + /// Returns `Some(true)` if the value fits and `Some(false)` if it is too large for the precision. + pub fn fits_in_precision(&self, decimal_type: DecimalDType) -> Option { + // Convert to i256 for comparison + let value_i256 = match_each_decimal_value!(self, |v| { + v.to_i256() + .vortex_expect("upcast to i256 must always succeed") + }); + + // Calculate the maximum stored value that can be represented with this precision + // For precision P, the max stored value is 10^P - 1 + // This is independent of scale - scale only affects how we interpret the value + let ten = i256::from_i128(10); + let max_value = ten + .checked_pow(decimal_type.precision() as _) + .vortex_expect("precision must exist in i256"); + let min_value = -max_value; + + Some(value_i256 > min_value && value_i256 < max_value) + } + + /// Helper function to perform a checked binary operation on two decimal values. + /// + /// Both values are upcast to i256 before the operation, and the result is returned as I256. 
+ fn checked_binary_op(&self, other: &Self, op: F) -> Option + where + F: FnOnce(i256, i256) -> Option, + { + let self_upcast = match_each_decimal_value!(self, |v| { + v.to_i256() + .vortex_expect("upcast to i256 must always succeed") + }); + let other_upcast = match_each_decimal_value!(other, |v| { + v.to_i256() + .vortex_expect("upcast to i256 must always succeed") + }); + + op(self_upcast, other_upcast).map(DecimalValue::I256) + } + + /// Checked addition. Returns `None` on overflow. + pub fn checked_add(&self, other: &Self) -> Option { + self.checked_binary_op(other, |a, b| a.checked_add(&b)) + } + + /// Checked subtraction. Returns `None` on overflow. + pub fn checked_sub(&self, other: &Self) -> Option { + self.checked_binary_op(other, |a, b| a.checked_sub(&b)) + } + + /// Checked multiplication. Returns `None` on overflow. + pub fn checked_mul(&self, other: &Self) -> Option { + self.checked_binary_op(other, |a, b| a.checked_mul(&b)) + } + + /// Checked division. Returns `None` on overflow or division by zero. + pub fn checked_div(&self, other: &Self) -> Option { + self.checked_binary_op(other, |a, b| a.checked_div(&b)) + } } // Comparisons between DecimalValue types should upcast to i256 and operate in the upcast space. 
@@ -315,4 +381,227 @@ mod tests { set.insert(DecimalValue::I256(i256::from_i128(100))); assert_eq!(set.len(), 1); } + + #[test] + fn test_decimal_value_checked_add() { + let a = DecimalValue::I64(100); + let b = DecimalValue::I64(200); + let result = a.checked_add(&b).unwrap(); + assert_eq!(result, DecimalValue::I256(i256::from_i128(300))); + } + + #[test] + fn test_decimal_value_checked_sub() { + let a = DecimalValue::I64(500); + let b = DecimalValue::I64(200); + let result = a.checked_sub(&b).unwrap(); + assert_eq!(result, DecimalValue::I256(i256::from_i128(300))); + } + + #[test] + fn test_decimal_value_checked_mul() { + let a = DecimalValue::I32(50); + let b = DecimalValue::I32(10); + let result = a.checked_mul(&b).unwrap(); + assert_eq!(result, DecimalValue::I256(i256::from_i128(500))); + } + + #[test] + fn test_decimal_value_checked_div() { + let a = DecimalValue::I64(1000); + let b = DecimalValue::I64(10); + let result = a.checked_div(&b).unwrap(); + assert_eq!(result, DecimalValue::I256(i256::from_i128(100))); + } + + #[test] + fn test_decimal_value_checked_div_by_zero() { + let a = DecimalValue::I64(1000); + let b = DecimalValue::I64(0); + let result = a.checked_div(&b); + assert_eq!(result, None); + } + + #[test] + fn test_decimal_value_mixed_types() { + // Test operations with different underlying types + let a = DecimalValue::I8(10); + let b = DecimalValue::I128(20); + let result = a.checked_add(&b).unwrap(); + assert_eq!(result, DecimalValue::I256(i256::from_i128(30))); + } + + #[test] + fn test_fits_in_precision_exact_boundary() { + use vortex_dtype::DecimalDType; + + // Precision 3 means max value is 10^3 - 1 = 999 + let dtype = DecimalDType::new(3, 0); + + // Test exact upper boundary: 999 should fit + let value = DecimalValue::I16(999); + assert_eq!(value.fits_in_precision(dtype), Some(true)); + + // Test just beyond upper boundary: 1000 should NOT fit + let value = DecimalValue::I16(1000); + assert_eq!(value.fits_in_precision(dtype), Some(false)); 
+ + // Test exact lower boundary: -999 should fit + let value = DecimalValue::I16(-999); + assert_eq!(value.fits_in_precision(dtype), Some(true)); + + // Test just beyond lower boundary: -1000 should NOT fit + let value = DecimalValue::I16(-1000); + assert_eq!(value.fits_in_precision(dtype), Some(false)); + } + + #[test] + fn test_fits_in_precision_zero() { + use vortex_dtype::DecimalDType; + + let dtype = DecimalDType::new(5, 2); + + // Zero should always fit + let value = DecimalValue::I8(0); + assert_eq!(value.fits_in_precision(dtype), Some(true)); + } + + #[test] + fn test_fits_in_precision_small_precision() { + use vortex_dtype::DecimalDType; + + // Precision 1 means max value is 10^1 - 1 = 9 + let dtype = DecimalDType::new(1, 0); + + // Test values within range + for i in -9..=9 { + let value = DecimalValue::I8(i); + assert_eq!( + value.fits_in_precision(dtype), + Some(true), + "value {} should fit in precision 1", + i + ); + } + + // Test values outside range + let value = DecimalValue::I8(10); + assert_eq!(value.fits_in_precision(dtype), Some(false)); + let value = DecimalValue::I8(-10); + assert_eq!(value.fits_in_precision(dtype), Some(false)); + } + + #[test] + fn test_fits_in_precision_large_precision() { + use vortex_dtype::DecimalDType; + + // Precision 38 means max value is 10^38 - 1 + let dtype = DecimalDType::new(38, 0); + + // Test i128::MAX which is approximately 1.7e38 + // This should NOT fit because 10^38 - 1 < i128::MAX + let value = DecimalValue::I128(i128::MAX); + assert_eq!(value.fits_in_precision(dtype), Some(false)); + + // Test a large value that should fit: 10^37 + let value = DecimalValue::I128(10_i128.pow(37)); + assert_eq!(value.fits_in_precision(dtype), Some(true)); + + // Test 10^38 - 1 (the exact maximum) + let max_val = i256::from_i128(10).wrapping_pow(38) - i256::from_i128(1); + let value = DecimalValue::I256(max_val); + assert_eq!(value.fits_in_precision(dtype), Some(true)); + + // Test 10^38 (just over the maximum) + let 
over_max = i256::from_i128(10).wrapping_pow(38); + let value = DecimalValue::I256(over_max); + assert_eq!(value.fits_in_precision(dtype), Some(false)); + } + + #[test] + fn test_fits_in_precision_max_precision() { + use vortex_dtype::DecimalDType; + + // Maximum precision is 76 + let dtype = DecimalDType::new(76, 0); + + // Test that reasonable i256 values fit + let value = DecimalValue::I256(i256::from_i128(i128::MAX)); + assert_eq!(value.fits_in_precision(dtype), Some(true)); + + // Test negative + let value = DecimalValue::I256(i256::from_i128(i128::MIN)); + assert_eq!(value.fits_in_precision(dtype), Some(true)); + } + + #[test] + fn test_fits_in_precision_different_scales() { + use vortex_dtype::DecimalDType; + + // Scale doesn't affect the precision check - it's only about the stored value + let value = DecimalValue::I32(12345); + + // Precision 5 with different scales + assert_eq!(value.fits_in_precision(DecimalDType::new(5, 0)), Some(true)); + assert_eq!(value.fits_in_precision(DecimalDType::new(5, 2)), Some(true)); + assert_eq!( + value.fits_in_precision(DecimalDType::new(5, -2)), + Some(true) + ); + + // Precision 4 should fail (max value 9999, we have 12345) + assert_eq!( + value.fits_in_precision(DecimalDType::new(4, 0)), + Some(false) + ); + assert_eq!( + value.fits_in_precision(DecimalDType::new(4, 2)), + Some(false) + ); + } + + #[test] + fn test_fits_in_precision_negative_values() { + use vortex_dtype::DecimalDType; + + let dtype = DecimalDType::new(4, 2); + + // Test negative values at boundaries + // Precision 4 means max magnitude is 9999 + let value = DecimalValue::I16(-9999); + assert_eq!(value.fits_in_precision(dtype), Some(true)); + + let value = DecimalValue::I16(-10000); + assert_eq!(value.fits_in_precision(dtype), Some(false)); + + let value = DecimalValue::I16(-1); + assert_eq!(value.fits_in_precision(dtype), Some(true)); + } + + #[test] + fn test_fits_in_precision_mixed_decimal_value_types() { + use vortex_dtype::DecimalDType; + + let 
dtype = DecimalDType::new(5, 0); + + // Test that different DecimalValue types work correctly + assert_eq!(DecimalValue::I8(99).fits_in_precision(dtype), Some(true)); + assert_eq!(DecimalValue::I16(9999).fits_in_precision(dtype), Some(true)); + assert_eq!( + DecimalValue::I32(99999).fits_in_precision(dtype), + Some(true) + ); + assert_eq!( + DecimalValue::I64(100000).fits_in_precision(dtype), + Some(false) + ); + assert_eq!( + DecimalValue::I128(99999).fits_in_precision(dtype), + Some(true) + ); + assert_eq!( + DecimalValue::I256(i256::from_i128(100000)).fits_in_precision(dtype), + Some(false) + ); + } }