Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 168 additions & 15 deletions vortex-array/src/arrays/chunked/compute/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use num_traits::PrimInt;
use vortex_dtype::{NativePType, PType, match_each_native_ptype};
use vortex_error::{VortexExpect, VortexResult, vortex_err};
use vortex_scalar::{FromPrimitiveOrF16, Scalar};
use vortex_dtype::Nullability::Nullable;
use vortex_dtype::{DType, DecimalDType, NativePType, match_each_native_ptype};
use vortex_error::{VortexResult, vortex_bail, vortex_err};
use vortex_scalar::{DecimalScalar, DecimalValue, FromPrimitiveOrF16, Scalar, i256};

use crate::arrays::{ChunkedArray, ChunkedVTable};
use crate::compute::{SumKernel, SumKernelAdapter, sum};
Expand All @@ -16,16 +17,23 @@ impl SumKernel for ChunkedVTable {
let sum_dtype = Stat::Sum
.dtype(array.dtype())
.ok_or_else(|| vortex_err!("Sum not supported for dtype {}", array.dtype()))?;
let sum_ptype = PType::try_from(&sum_dtype).vortex_expect("sum dtype must be primitive");

let scalar_value = match_each_native_ptype!(
sum_ptype,
unsigned: |T| { sum_int::<u64>(array.chunks())?.into() },
signed: |T| { sum_int::<i64>(array.chunks())?.into() },
floating: |T| { sum_float(array.chunks())?.into() }
);
match sum_dtype {
DType::Decimal(decimal_dtype, _) => sum_decimal(array.chunks(), decimal_dtype),
DType::Primitive(sum_ptype, _) => {
let scalar_value = match_each_native_ptype!(
sum_ptype,
unsigned: |T| { sum_int::<u64>(array.chunks())?.into() },
signed: |T| { sum_int::<i64>(array.chunks())?.into() },
floating: |T| { sum_float(array.chunks())?.into() }
);

Ok(Scalar::new(sum_dtype, scalar_value))
Ok(Scalar::new(sum_dtype, scalar_value))
}
_ => {
vortex_bail!("Sum not supported for dtype {}", sum_dtype);
}
}
}
}

Expand All @@ -39,7 +47,7 @@ fn sum_int<T: NativePType + PrimInt + FromPrimitiveOrF16>(
let chunk_sum = sum(chunk)?;

let Some(chunk_sum) = chunk_sum.as_primitive().as_::<T>() else {
// Bail out on overflow
// Bail out missing statistic
return Ok(None);
};

Expand All @@ -63,14 +71,46 @@ fn sum_float(chunks: &[ArrayRef]) -> VortexResult<f64> {
Ok(result)
}

fn sum_decimal(chunks: &[ArrayRef], result_decimal_type: DecimalDType) -> VortexResult<Scalar> {
let mut result = DecimalValue::I256(i256::ZERO);

let null = || Scalar::null(DType::Decimal(result_decimal_type, Nullable));

for chunk in chunks {
let chunk_sum = sum(chunk)?;

let chunk_decimal = DecimalScalar::try_from(&chunk_sum)?;
let Some(chunk_value) = chunk_decimal.decimal_value() else {
// skips all null chunks
continue;
};

// Perform checked addition with current result
let Some(r) = result.checked_add(&chunk_value).filter(|sum_value| {
sum_value
.fits_in_precision(result_decimal_type)
.unwrap_or(false)
}) else {
// Overflow
return Ok(null());
};

result = r;
}

Ok(Scalar::decimal(result, result_decimal_type, Nullable))
}

#[cfg(test)]
mod tests {
use vortex_dtype::Nullability;
use vortex_scalar::Scalar;
use vortex_buffer::buffer;
use vortex_dtype::{DType, DecimalDType, Nullability};
use vortex_scalar::{DecimalValue, Scalar, i256};

use crate::array::IntoArray;
use crate::arrays::{ChunkedArray, ConstantArray, PrimitiveArray};
use crate::arrays::{ChunkedArray, ConstantArray, DecimalArray, PrimitiveArray};
use crate::compute::sum;
use crate::validity::Validity;

#[test]
fn test_sum_chunked_floats_with_nulls() {
Expand Down Expand Up @@ -138,4 +178,117 @@ mod tests {
let result = sum(chunked.as_ref()).unwrap();
assert_eq!(result.as_primitive().as_::<f64>(), Some(36.0));
}

#[test]
fn test_sum_chunked_decimals() {
// Create decimal chunks with precision=10, scale=2
let decimal_dtype = DecimalDType::new(10, 2);
let chunk1 = DecimalArray::new(
buffer![100i32, 100i32, 100i32, 100i32, 100i32],
decimal_dtype,
Validity::AllValid,
);
let chunk2 = DecimalArray::new(
buffer![200i32, 200i32, 200i32],
decimal_dtype,
Validity::AllValid,
);
let chunk3 = DecimalArray::new(buffer![300i32, 300i32], decimal_dtype, Validity::AllValid);

let dtype = chunk1.dtype().clone();
let chunked = ChunkedArray::try_new(
vec![
chunk1.into_array(),
chunk2.into_array(),
chunk3.into_array(),
],
dtype,
)
.unwrap();

// Compute sum: 5*100 + 3*200 + 2*300 = 500 + 600 + 600 = 1700 (represents 17.00)
let result = sum(chunked.as_ref()).unwrap();
let decimal_result = result.as_decimal();
assert_eq!(
decimal_result.decimal_value(),
Some(DecimalValue::I256(i256::from_i128(1700)))
);
}

#[test]
fn test_sum_chunked_decimals_with_nulls() {
let decimal_dtype = DecimalDType::new(10, 2);

// Create chunks with some nulls - all must have same nullability
let chunk1 = DecimalArray::new(
buffer![100i32, 100i32, 100i32],
decimal_dtype,
Validity::AllValid,
);
let chunk2 = DecimalArray::new(
buffer![0i32, 0i32],
decimal_dtype,
Validity::from_iter([false, false]),
);
let chunk3 = DecimalArray::new(buffer![200i32, 200i32], decimal_dtype, Validity::AllValid);

let dtype = chunk1.dtype().clone();
let chunked = ChunkedArray::try_new(
vec![
chunk1.into_array(),
chunk2.into_array(),
chunk3.into_array(),
],
dtype,
)
.unwrap();

// Compute sum: 3*100 + 2*200 = 300 + 400 = 700 (nulls ignored)
let result = sum(chunked.as_ref()).unwrap();
let decimal_result = result.as_decimal();
assert_eq!(
decimal_result.decimal_value(),
Some(DecimalValue::I256(i256::from_i128(700)))
);
}

#[test]
fn test_sum_chunked_decimals_large() {
// Create decimals with precision 3 (max value 999)
// Sum will be 500 + 600 = 1100, which fits in result precision 13 (3+10)
let decimal_dtype = DecimalDType::new(3, 0);
let chunk1 = ConstantArray::new(
Scalar::decimal(
DecimalValue::I16(500),
decimal_dtype,
Nullability::NonNullable,
),
1,
);
let chunk2 = ConstantArray::new(
Scalar::decimal(
DecimalValue::I16(600),
decimal_dtype,
Nullability::NonNullable,
),
1,
);

let dtype = chunk1.dtype().clone();
let chunked =
ChunkedArray::try_new(vec![chunk1.into_array(), chunk2.into_array()], dtype).unwrap();

// Compute sum: 500 + 600 = 1100
// Result should have precision 13 (3+10), scale 0
let result = sum(chunked.as_ref()).unwrap();
let decimal_result = result.as_decimal();
assert_eq!(
decimal_result.decimal_value(),
Some(DecimalValue::I256(i256::from_i128(1100)))
);
assert_eq!(
result.dtype(),
&DType::Decimal(DecimalDType::new(13, 0), Nullability::Nullable)
);
}
}
108 changes: 102 additions & 6 deletions vortex-array/src/arrays/constant/compute/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use num_traits::{CheckedMul, ToPrimitive};
use vortex_dtype::{DType, NativePType, match_each_native_ptype};
use vortex_error::{VortexResult, vortex_bail, vortex_err};
use vortex_scalar::{FromPrimitiveOrF16, PrimitiveScalar, Scalar, ScalarValue};
use vortex_dtype::{DType, DecimalDType, NativePType, Nullability, match_each_native_ptype};
use vortex_error::{VortexExpect, VortexResult, vortex_bail, vortex_err};
use vortex_scalar::{
DecimalScalar, DecimalValue, FromPrimitiveOrF16, PrimitiveScalar, Scalar, ScalarValue, i256,
};

use crate::arrays::{ConstantArray, ConstantVTable};
use crate::compute::{SumKernel, SumKernelAdapter};
Expand Down Expand Up @@ -36,11 +38,47 @@ fn sum_scalar(scalar: &Scalar, len: usize) -> VortexResult<ScalarValue> {
signed: |T| { sum_integral::<i64>(scalar.as_primitive(), len)?.into() },
floating: |T| { sum_float(scalar.as_primitive(), len)?.into() }
)),
DType::Decimal(decimal_dtype, _) => sum_decimal(scalar.as_decimal(), len, *decimal_dtype),
DType::Extension(_) => sum_scalar(&scalar.as_extension().storage(), len),
dtype => vortex_bail!("Unsupported dtype for sum: {}", dtype),
}
}

fn sum_decimal(
decimal_scalar: DecimalScalar,
array_len: usize,
decimal_dtype: DecimalDType,
) -> VortexResult<ScalarValue> {
let result_dtype = Stat::Sum
.dtype(&DType::Decimal(decimal_dtype, Nullability::Nullable))
.vortex_expect("decimal supports sum");
let result_decimal_type = result_dtype
.as_decimal_opt()
.vortex_expect("must be decimal");

let Some(value) = decimal_scalar.decimal_value() else {
// Null value: return null
return Ok(ScalarValue::null());
};

// Convert array_len to DecimalValue for multiplication
let len_value = DecimalValue::I256(i256::from_i128(array_len as i128));

// Multiply value * len
let sum = value.checked_mul(&len_value).and_then(|result| {
// Check if result fits in the precision
result
.fits_in_precision(*result_decimal_type)
.unwrap_or(false)
.then_some(result)
});

match sum {
Some(result_value) => Ok(ScalarValue::from(result_value)),
None => Ok(ScalarValue::null()), // Overflow
}
}

fn sum_integral<T>(
primitive_scalar: PrimitiveScalar<'_>,
array_len: usize,
Expand Down Expand Up @@ -70,12 +108,13 @@ register_kernel!(SumKernelAdapter(ConstantVTable).lift());

#[cfg(test)]
mod tests {
use vortex_dtype::{DType, Nullability, PType};
use vortex_scalar::Scalar;
use vortex_dtype::{DType, DecimalDType, Nullability, PType};
use vortex_scalar::{DecimalValue, Scalar};

use crate::IntoArray;
use crate::arrays::ConstantArray;
use crate::compute::sum;
use crate::stats::Stat;
use crate::{Array, IntoArray};

#[test]
fn test_sum_unsigned() {
Expand Down Expand Up @@ -123,4 +162,61 @@ mod tests {
let result = sum(&array).unwrap();
assert!(result.is_null());
}

#[test]
fn test_sum_decimal() {
let decimal_dtype = DecimalDType::new(10, 2);
let array = ConstantArray::new(
Scalar::decimal(
DecimalValue::I64(100),
decimal_dtype,
Nullability::NonNullable,
),
5,
)
.into_array();

let result = sum(&array).unwrap();

assert_eq!(
result.as_decimal().decimal_value(),
Some(DecimalValue::I256(vortex_scalar::i256::from_i128(500)))
);
assert_eq!(result.dtype(), &Stat::Sum.dtype(array.dtype()).unwrap());
}

#[test]
fn test_sum_decimal_null() {
let decimal_dtype = DecimalDType::new(10, 2);
let array = ConstantArray::new(
Scalar::null(DType::Decimal(decimal_dtype, Nullability::Nullable)),
10,
)
.into_array();

let result = sum(&array).unwrap();
assert!(result.is_null());
}

#[test]
fn test_sum_decimal_large_value() {
let decimal_dtype = DecimalDType::new(10, 2);
let array = ConstantArray::new(
Scalar::decimal(
DecimalValue::I64(999_999_999),
decimal_dtype,
Nullability::NonNullable,
),
100,
)
.into_array();

let result = sum(&array).unwrap();
assert_eq!(
result.as_decimal().decimal_value(),
Some(DecimalValue::I256(vortex_scalar::i256::from_i128(
99_999_999_900
)))
);
}
}
6 changes: 3 additions & 3 deletions vortex-array/src/arrays/decimal/compute/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
use arrow_schema::DECIMAL256_MAX_PRECISION;
use num_traits::AsPrimitive;
use vortex_dtype::DecimalDType;
use vortex_dtype::Nullability::Nullable;
use vortex_error::{VortexResult, vortex_bail};
use vortex_mask::Mask;
use vortex_scalar::{DecimalValue, Scalar, match_each_decimal_value_type};
Expand Down Expand Up @@ -40,7 +41,6 @@ impl SumKernel for DecimalVTable {
#[allow(clippy::cognitive_complexity)]
fn sum(&self, array: &DecimalArray) -> VortexResult<Scalar> {
let decimal_dtype = array.decimal_dtype();
let nullability = array.dtype().nullability();

// Both Spark and DataFusion use this heuristic.
// - https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
Expand All @@ -60,7 +60,7 @@ impl SumKernel for DecimalVTable {
Ok(Scalar::decimal(
DecimalValue::from(sum_decimal!(O, array.buffer::<I>())),
return_dtype,
nullability,
Nullable,
))
})
})
Expand All @@ -76,7 +76,7 @@ impl SumKernel for DecimalVTable {
mask_values.boolean_buffer()
)),
return_dtype,
nullability,
Nullable,
))
})
})
Expand Down
Loading
Loading