Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Decimal256 on AVG aggregate expression #7853

Merged
merged 5 commits into from
Oct 19, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 72 additions & 22 deletions datafusion/physical-expr/src/aggregate/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use arrow::array::{AsArray, PrimitiveBuilder};
use log::debug;

use std::any::Any;
use std::fmt::Debug;
use std::sync::Arc;

use crate::aggregate::groups_accumulator::accumulate::NullState;
Expand All @@ -33,15 +34,17 @@ use arrow::{
array::{ArrayRef, UInt64Array},
datatypes::Field,
};
use arrow_array::types::{Decimal256Type, DecimalType};
use arrow_array::{
Array, ArrowNativeTypeOp, ArrowNumericType, ArrowPrimitiveType, PrimitiveArray,
};
use arrow_buffer::{i256, ArrowNativeType};
use datafusion_common::{not_impl_err, DataFusionError, Result, ScalarValue};
use datafusion_expr::type_coercion::aggregates::avg_return_type;
use datafusion_expr::Accumulator;

use super::groups_accumulator::EmitTo;
use super::utils::Decimal128Averager;
use super::utils::DecimalAverager;

/// AVG aggregate expression
#[derive(Debug, Clone)]
Expand Down Expand Up @@ -88,7 +91,19 @@ impl AggregateExpr for Avg {
(
Decimal128(sum_precision, sum_scale),
Decimal128(target_precision, target_scale),
) => Ok(Box::new(DecimalAvgAccumulator {
) => Ok(Box::new(DecimalAvgAccumulator::<Decimal128Type> {
sum: None,
count: 0,
sum_scale: *sum_scale,
sum_precision: *sum_precision,
target_precision: *target_precision,
target_scale: *target_scale,
})),

(
Decimal256(sum_precision, sum_scale),
Decimal256(target_precision, target_scale),
) => Ok(Box::new(DecimalAvgAccumulator::<Decimal256Type> {
sum: None,
count: 0,
sum_scale: *sum_scale,
Expand Down Expand Up @@ -156,7 +171,7 @@ impl AggregateExpr for Avg {
Decimal128(_sum_precision, sum_scale),
Decimal128(target_precision, target_scale),
) => {
let decimal_averager = Decimal128Averager::try_new(
let decimal_averager = DecimalAverager::<Decimal128Type>::try_new(
*sum_scale,
*target_precision,
*target_scale,
Expand All @@ -172,6 +187,27 @@ impl AggregateExpr for Avg {
)))
}

(
Decimal256(_sum_precision, sum_scale),
Decimal256(target_precision, target_scale),
) => {
let decimal_averager = DecimalAverager::<Decimal256Type>::try_new(
*sum_scale,
*target_precision,
*target_scale,
)?;

let avg_fn = move |sum: i256, count: u64| {
decimal_averager.avg(sum, i256::from_usize(count as usize).unwrap())
};

Ok(Box::new(AvgGroupsAccumulator::<Decimal256Type, _>::new(
&self.input_data_type,
&self.result_data_type,
avg_fn,
)))
}

_ => not_impl_err!(
"AvgGroupsAccumulator for ({} --> {})",
self.input_data_type,
Expand Down Expand Up @@ -256,40 +292,55 @@ impl Accumulator for AvgAccumulator {
}

/// An accumulator to compute the average for decimals
#[derive(Debug)]
struct DecimalAvgAccumulator {
sum: Option<i128>,
struct DecimalAvgAccumulator<T: DecimalType + ArrowNumericType> {
sum: Option<T::Native>,
count: u64,
sum_scale: i8,
sum_precision: u8,
target_precision: u8,
target_scale: i8,
}

impl Accumulator for DecimalAvgAccumulator {
impl<T: DecimalType + ArrowNumericType> Debug for DecimalAvgAccumulator<T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("DecimalAvgAccumulator")
.field("sum", &self.sum)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this needed because `#[derive(Debug)]` doesn't know how to handle `Option<T::Native>`?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. `#[derive(Debug)]` doesn't work here.

.field("count", &self.count)
.field("sum_scale", &self.sum_scale)
.field("sum_precision", &self.sum_precision)
.field("target_precision", &self.target_precision)
.field("target_scale", &self.target_scale)
.finish()
}
}

impl<T: DecimalType + ArrowNumericType> Accumulator for DecimalAvgAccumulator<T> {
fn state(&self) -> Result<Vec<ScalarValue>> {
Ok(vec![
ScalarValue::from(self.count),
ScalarValue::Decimal128(self.sum, self.sum_precision, self.sum_scale),
ScalarValue::new_primitive::<T>(
self.sum,
&T::TYPE_CONSTRUCTOR(self.sum_precision, self.sum_scale),
)?,
])
}

fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let values = values[0].as_primitive::<Decimal128Type>();
let values = values[0].as_primitive::<T>();

self.count += (values.len() - values.null_count()) as u64;
if let Some(x) = sum(values) {
let v = self.sum.get_or_insert(0);
*v += x;
let v = self.sum.get_or_insert(T::Native::default());
self.sum = Some(v.add_wrapping(x));
}
Ok(())
}

fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let values = values[0].as_primitive::<Decimal128Type>();
let values = values[0].as_primitive::<T>();
self.count -= (values.len() - values.null_count()) as u64;
if let Some(x) = sum(values) {
self.sum = Some(self.sum.unwrap() - x);
self.sum = Some(self.sum.unwrap().sub_wrapping(x));
}
Ok(())
}
Expand All @@ -299,9 +350,9 @@ impl Accumulator for DecimalAvgAccumulator {
self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();

// sums are summed
if let Some(x) = sum(states[1].as_primitive::<Decimal128Type>()) {
let v = self.sum.get_or_insert(0);
*v += x;
if let Some(x) = sum(states[1].as_primitive::<T>()) {
let v = self.sum.get_or_insert(T::Native::default());
self.sum = Some(v.add_wrapping(x));
}
Ok(())
}
Expand All @@ -310,20 +361,19 @@ impl Accumulator for DecimalAvgAccumulator {
let v = self
.sum
.map(|v| {
Decimal128Averager::try_new(
DecimalAverager::<T>::try_new(
self.sum_scale,
self.target_precision,
self.target_scale,
)?
.avg(v, self.count as _)
.avg(v, T::Native::from_usize(self.count as usize).unwrap())
})
.transpose()?;

Ok(ScalarValue::Decimal128(
ScalarValue::new_primitive::<T>(
v,
self.target_precision,
self.target_scale,
))
&T::TYPE_CONSTRUCTOR(self.target_precision, self.target_scale),
)
}
fn supports_retract_batch(&self) -> bool {
true
Expand Down
59 changes: 34 additions & 25 deletions datafusion/physical-expr/src/aggregate/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@

use crate::{AggregateExpr, PhysicalSortExpr};
use arrow::array::ArrayRef;
use arrow::datatypes::{MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION};
use arrow_array::cast::AsArray;
use arrow_array::types::{
Decimal128Type, TimestampMicrosecondType, TimestampMillisecondType,
Decimal128Type, DecimalType, TimestampMicrosecondType, TimestampMillisecondType,
TimestampNanosecondType, TimestampSecondType,
};
use arrow_array::ArrowNativeTypeOp;
use arrow_buffer::ArrowNativeType;
use arrow_schema::{DataType, Field};
use datafusion_common::{exec_err, DataFusionError, Result};
use datafusion_expr::Accumulator;
Expand All @@ -42,27 +43,25 @@ pub fn get_accum_scalar_values_as_arrays(
.collect::<Vec<_>>())
}

/// Computes averages for `Decimal128` values, checking for overflow
/// Computes averages for `Decimal128`/`Decimal256` values, checking for overflow
///
/// This is needed because different precisions for Decimal128 can
/// This is needed because different precisions for Decimal128/Decimal256 can
/// store different ranges of values and thus sum/count may not fit in
/// the target type.
///
/// For example, with a precision of 3, the maximum value is `999` and the
/// minimum value is `-999`
pub(crate) struct Decimal128Averager {
pub(crate) struct DecimalAverager<T: DecimalType> {
/// scale factor for sum values (10^sum_scale)
sum_mul: i128,
sum_mul: T::Native,
/// scale factor for target (10^target_scale)
target_mul: i128,
/// The minimum output value possible to represent with the target precision
target_min: i128,
/// The maximum output value possible to represent with the target precision
target_max: i128,
target_mul: T::Native,
/// the output precision
target_precision: u8,
}

impl Decimal128Averager {
/// Create a new `Decimal128Averager`:
impl<T: DecimalType> DecimalAverager<T> {
/// Create a new `DecimalAverager`:
///
/// * sum_scale: the scale of `sum` values passed to [`Self::avg`]
/// * target_precision: the output precision
Expand All @@ -74,35 +73,45 @@ impl Decimal128Averager {
target_precision: u8,
target_scale: i8,
) -> Result<Self> {
let sum_mul = 10_i128.pow(sum_scale as u32);
let target_mul = 10_i128.pow(target_scale as u32);
let target_min = MIN_DECIMAL_FOR_EACH_PRECISION[target_precision as usize - 1];
let target_max = MAX_DECIMAL_FOR_EACH_PRECISION[target_precision as usize - 1];
let sum_mul = T::Native::from_usize(10 as usize)
.map(|b| b.pow_wrapping(sum_scale as u32))
.ok_or(DataFusionError::Internal(
"Failed to compute sum_mul in DecimalAverager".to_string(),
))?;

let target_mul = T::Native::from_usize(10 as usize)
.map(|b| b.pow_wrapping(target_scale as u32))
.ok_or(DataFusionError::Internal(
"Failed to compute target_mul in DecimalAverager".to_string(),
))?;

if target_mul >= sum_mul {
Ok(Self {
sum_mul,
target_mul,
target_min,
target_max,
target_precision,
})
} else {
// can't convert the lit decimal to the returned data type
exec_err!("Arithmetic Overflow in AvgAccumulator")
}
}

/// Returns the `sum`/`count` as a i128 Decimal128 with
/// Returns the `sum`/`count` as an i128/i256 Decimal128/Decimal256 with
/// target_scale and target_precision and reporting overflow.
///
/// * sum: The total sum value stored as Decimal128/Decimal256 with sum_scale
/// (passed to `Self::try_new`)
/// * count: total count, stored as a i128 (*NOT* a Decimal128 value)
/// * count: total count, stored as an i128/i256 (*NOT* a Decimal128/Decimal256 value)
#[inline(always)]
pub fn avg(&self, sum: i128, count: i128) -> Result<i128> {
if let Some(value) = sum.checked_mul(self.target_mul / self.sum_mul) {
let new_value = value / count;
if new_value >= self.target_min && new_value <= self.target_max {
pub fn avg(&self, sum: T::Native, count: T::Native) -> Result<T::Native> {
if let Ok(value) = sum.mul_checked(self.target_mul.div_wrapping(self.sum_mul)) {
let new_value = value.div_wrapping(count);

let validate =
T::validate_decimal_precision(new_value, self.target_precision);

if validate.is_ok() {
Ok(new_value)
} else {
exec_err!("Arithmetic Overflow in AvgAccumulator")
Expand Down
4 changes: 3 additions & 1 deletion datafusion/sqllogictest/test_files/decimal.slt
Original file line number Diff line number Diff line change
Expand Up @@ -622,8 +622,10 @@ create table t as values (arrow_cast(123, 'Decimal256(5,2)'));
statement ok
set datafusion.execution.target_partitions = 1;

query error DataFusion error: This feature is not implemented: AvgAccumulator for \(Decimal256\(5, 2\) --> Decimal256\(9, 6\)\)
query R
select AVG(column1) from t;
----
123

statement ok
drop table t;