diff --git a/datafusion/functions-aggregate-common/src/aggregate/avg_distinct.rs b/datafusion/functions-aggregate-common/src/aggregate/avg_distinct.rs index 3d6889431d61..56cdaf6618de 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/avg_distinct.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/avg_distinct.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +mod decimal; mod numeric; +pub use decimal::DecimalDistinctAvgAccumulator; pub use numeric::Float64DistinctAvgAccumulator; diff --git a/datafusion/functions-aggregate-common/src/aggregate/avg_distinct/decimal.rs b/datafusion/functions-aggregate-common/src/aggregate/avg_distinct/decimal.rs new file mode 100644 index 000000000000..a71871b9b41e --- /dev/null +++ b/datafusion/functions-aggregate-common/src/aggregate/avg_distinct/decimal.rs @@ -0,0 +1,192 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::{ + array::{ArrayRef, ArrowNumericType}, + datatypes::{i256, Decimal128Type, Decimal256Type, DecimalType}, +}; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr_common::accumulator::Accumulator; +use std::fmt::Debug; +use std::mem::size_of_val; + +use crate::aggregate::sum_distinct::DistinctSumAccumulator; +use crate::utils::DecimalAverager; + +/// Generic implementation of `AVG DISTINCT` for Decimal types. +/// Handles both Decimal128Type and Decimal256Type. +#[derive(Debug)] +pub struct DecimalDistinctAvgAccumulator { + sum_accumulator: DistinctSumAccumulator, + sum_scale: i8, + target_precision: u8, + target_scale: i8, +} + +impl DecimalDistinctAvgAccumulator { + pub fn with_decimal_params( + sum_scale: i8, + target_precision: u8, + target_scale: i8, + ) -> Self { + let data_type = T::TYPE_CONSTRUCTOR(T::MAX_PRECISION, sum_scale); + + Self { + sum_accumulator: DistinctSumAccumulator::new(&data_type), + sum_scale, + target_precision, + target_scale, + } + } +} + +impl Accumulator + for DecimalDistinctAvgAccumulator +{ + fn state(&mut self) -> Result> { + self.sum_accumulator.state() + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.sum_accumulator.update_batch(values) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.sum_accumulator.merge_batch(states) + } + + fn evaluate(&mut self) -> Result { + if self.sum_accumulator.distinct_count() == 0 { + return ScalarValue::new_primitive::( + None, + &T::TYPE_CONSTRUCTOR(self.target_precision, self.target_scale), + ); + } + + let sum_scalar = self.sum_accumulator.evaluate()?; + + match sum_scalar { + ScalarValue::Decimal128(Some(sum), _, _) => { + let decimal_averager = DecimalAverager::::try_new( + self.sum_scale, + self.target_precision, + self.target_scale, + )?; + let avg = decimal_averager + .avg(sum, self.sum_accumulator.distinct_count() as i128)?; + Ok(ScalarValue::Decimal128( + Some(avg), + self.target_precision, + self.target_scale, + )) + } + ScalarValue::Decimal256(Some(sum), _, _) => { + let decimal_averager = DecimalAverager::::try_new( + self.sum_scale, + self.target_precision, + self.target_scale, + )?; + // `distinct_count` returns `u64`, but `avg` expects `i256` + // first convert `u64` to `i128`, then convert `i128` to `i256` to avoid overflow + let distinct_cnt: i128 = self.sum_accumulator.distinct_count() as i128; + let count: i256 = i256::from_i128(distinct_cnt); + let avg = decimal_averager.avg(sum, count)?; + Ok(ScalarValue::Decimal256( + Some(avg), + self.target_precision, + self.target_scale, + )) + } + + _ => unreachable!("Unsupported decimal type: {:?}", sum_scalar), + } + } + + fn size(&self) -> usize { + let fixed_size = size_of_val(self); + + // Account for the size of the sum_accumulator with its contained values + fixed_size + self.sum_accumulator.size() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::{Decimal128Array, Decimal256Array}; + use std::sync::Arc; + + #[test] + fn test_decimal128_distinct_avg_accumulator() -> Result<()> { + let precision = 10_u8; + let scale = 4_i8; + let array = Decimal128Array::from(vec![ + Some(100_0000), + Some(125_0000), + Some(175_0000), + Some(200_0000), + Some(200_0000), + Some(300_0000), + None, + None, + ]) + .with_precision_and_scale(precision, scale)?; + + let mut accumulator = + DecimalDistinctAvgAccumulator::::with_decimal_params( + scale, 14, 8, + ); + accumulator.update_batch(&[Arc::new(array)])?; + + let result = accumulator.evaluate()?; + let expected_result = ScalarValue::Decimal128(Some(180_00000000), 14, 8); + assert_eq!(result, expected_result); + + Ok(()) + } + + #[test] + fn test_decimal256_distinct_avg_accumulator() -> Result<()> { + let precision = 50_u8; + let scale = 2_i8; + + let array = Decimal256Array::from(vec![ + Some(i256::from_i128(10_000)), + Some(i256::from_i128(12_500)), + Some(i256::from_i128(17_500)), + Some(i256::from_i128(20_000)), + Some(i256::from_i128(20_000)), + Some(i256::from_i128(30_000)), + None, + None, + ]) + .with_precision_and_scale(precision, scale)?; + + let mut accumulator = + DecimalDistinctAvgAccumulator::::with_decimal_params( + scale, 54, 6, + ); + accumulator.update_batch(&[Arc::new(array)])?; + + let result = accumulator.evaluate()?; + let expected_result = + ScalarValue::Decimal256(Some(i256::from_i128(180_000000)), 54, 6); + assert_eq!(result, expected_result); + + Ok(()) + } +} diff --git a/datafusion/functions-aggregate/src/average.rs b/datafusion/functions-aggregate/src/average.rs index f7cb74fd55a2..5223ef533603 100644 --- a/datafusion/functions-aggregate/src/average.rs +++ b/datafusion/functions-aggregate/src/average.rs @@ -27,6 +27,7 @@ use arrow::datatypes::{ i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, DurationSecondType, Field, FieldRef, Float64Type, TimeUnit, UInt64Type, + DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, }; use datafusion_common::{ exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue, @@ -40,7 +41,9 @@ use datafusion_expr::{ ReversedUDAF, Signature, }; -use datafusion_functions_aggregate_common::aggregate::avg_distinct::Float64DistinctAvgAccumulator; +use datafusion_functions_aggregate_common::aggregate::avg_distinct::{ + DecimalDistinctAvgAccumulator, Float64DistinctAvgAccumulator, +}; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::accumulate::NullState; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::{ filtered_null_mask, set_nulls, @@ -120,13 +123,36 @@ impl AggregateUDFImpl for Avg { // instantiate specialized accumulator based for the type if acc_args.is_distinct { - match &data_type { + match (&data_type, acc_args.return_type()) { // Numeric types are converted to Float64 via `coerce_avg_type` during logical plan creation - Float64 => Ok(Box::new(Float64DistinctAvgAccumulator::default())), - _ => exec_err!("AVG(DISTINCT) for {} not supported", data_type), + (Float64, _) => Ok(Box::new(Float64DistinctAvgAccumulator::default())), + + ( + Decimal128(_, scale), + Decimal128(target_precision, target_scale), + ) => Ok(Box::new(DecimalDistinctAvgAccumulator::::with_decimal_params( + *scale, + *target_precision, + *target_scale, + ))), + + ( + Decimal256(_, scale), + Decimal256(target_precision, target_scale), + ) => Ok(Box::new(DecimalDistinctAvgAccumulator::::with_decimal_params( + *scale, + *target_precision, + *target_scale, + ))), + + (dt, return_type) => exec_err!( + "AVG(DISTINCT) for ({} --> {}) not supported", + dt, + return_type + ), } } else { - match (&data_type, acc_args.return_field.data_type()) { + match (&data_type, acc_args.return_type()) { (Float64, Float64) => Ok(Box::::default()), ( Decimal128(sum_precision, sum_scale), @@ -161,22 +187,31 @@ impl AggregateUDFImpl for Avg { })) } - _ => exec_err!( - "AvgAccumulator for ({} --> {})", - &data_type, - acc_args.return_field.data_type() - ), + (dt, return_type) => { + exec_err!("AvgAccumulator for ({} --> {})", dt, return_type) + } } } } fn state_fields(&self, args: StateFieldsArgs) -> Result> { if args.is_distinct { - // Copied from datafusion_functions_aggregate::sum::Sum::state_fields + // Decimal accumulator actually uses a different precision during accumulation, + // see DecimalDistinctAvgAccumulator::with_decimal_params + let dt = match args.input_fields[0].data_type() { + DataType::Decimal128(_, scale) => { + DataType::Decimal128(DECIMAL128_MAX_PRECISION, *scale) + } + DataType::Decimal256(_, scale) => { + DataType::Decimal256(DECIMAL256_MAX_PRECISION, *scale) + } + _ => args.return_type().clone(), + }; + // Similar to datafusion_functions_aggregate::sum::Sum::state_fields // since the accumulator uses DistinctSumAccumulator internally. Ok(vec![Field::new_list( format_state_name(args.name, "avg distinct"), - Field::new_list_field(args.return_type().clone(), true), + Field::new_list_field(dt, true), false, ) .into()]) diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index caf8d637ec45..eed3721078c7 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -7322,6 +7322,38 @@ SELECT a, median(b), arrow_typeof(median(b)) FROM group_median_all_nulls GROUP B group0 NULL Int32 group1 NULL Int32 +statement ok +create table t_decimal (c decimal(10, 4)) as values (100.00), (125.00), (175.00), (200.00), (200.00), (300.00), (null), (null); + +# Test avg_distinct for Decimal128 +query RT +select avg(distinct c), arrow_typeof(avg(distinct c)) from t_decimal; +---- +180 Decimal128(14, 8) + +statement ok +drop table t_decimal; + +# Test avg_distinct for Decimal256 +statement ok +create table t_decimal256 (c decimal(50, 2)) as values + (100.00), + (125.00), + (175.00), + (200.00), + (200.00), + (300.00), + (null), + (null); + +query RT +select avg(distinct c), arrow_typeof(avg(distinct c)) from t_decimal256; +---- +180 Decimal256(54, 6) + +statement ok +drop table t_decimal256; + query I with test AS (SELECT i as c1, i + 1 as c2 FROM generate_series(1, 10) t(i)) select count(*) from test WHERE 1 = 1; @@ -7444,55 +7476,65 @@ FROM (VALUES ('a'), ('d'), ('c'), ('a')) t(a_varchar); # distinct average statement ok -create table distinct_avg (a int, b double) as values - (3, null), - (2, null), - (5, 100.5), - (5, 1.0), - (5, 44.112), - (null, 1.0), - (5, 100.5), - (1, 4.09), - (5, 100.5), - (5, 100.5), - (4, null), - (null, null) +create table distinct_avg (a int, b double, c decimal(10, 4), d decimal(50, 2)) as values + (3, null, 100.2562, 90251.21), + (2, null, 100.2562, null), + (5, 100.5, null, 10000000.11), + (5, 1.0, 100.2563, -1.0), + (5, 44.112, -132.12, null), + (null, 1.0, 100.2562, 90251.21), + (5, 100.5, -100.2562, -10000000.11), + (1, 4.09, 4222.124, 0.0), + (5, 100.5, null, 10000000.11), + (5, 100.5, 1.1, 1.0), + (4, null, 4222.124, null), + (null, null, null, null) ; # Need two columns to ensure single_distinct_to_group_by rule doesn't kick in, so we know our actual avg(distinct) code is being tested -query RTRTRR +query RTRTRTRTRRRR select avg(distinct a), arrow_typeof(avg(distinct a)), avg(distinct b), arrow_typeof(avg(distinct b)), + avg(distinct c), + arrow_typeof(avg(distinct c)), + avg(distinct d), + arrow_typeof(avg(distinct d)), avg(a), - avg(b) + avg(b), + avg(c), + avg(d) from distinct_avg; ---- -3 Float64 37.4255 Float64 4 56.52525 +3 Float64 37.4255 Float64 698.56005 Decimal128(14, 8) 15041.868333 Decimal256(54, 6) 4 56.52525 957.11074444 1272562.81625 -query RR rowsort +query RRRR rowsort select avg(distinct a), - avg(distinct b) + avg(distinct b), + avg(distinct c), + avg(distinct d) from distinct_avg group by b; ---- -1 4.09 -3 NULL -5 1 -5 100.5 -5 44.112 +1 4.09 4222.124 0 +3 NULL 2161.1901 90251.21 +5 1 100.25625 45125.105 +5 100.5 -49.5781 0.333333 +5 44.112 -132.12 NULL -query RR +query RRRR select avg(distinct a), - avg(distinct b) + avg(distinct b), + avg(distinct c), + avg(distinct d) from distinct_avg -where a is null and b is null; +where a is null and b is null and c is null and d is null; ---- -NULL NULL +NULL NULL NULL NULL statement ok drop table distinct_avg;