diff --git a/datafusion/functions-aggregate/src/variance.rs b/datafusion/functions-aggregate/src/variance.rs
index e6978c15d0bf7..9e35bf0a2bea7 100644
--- a/datafusion/functions-aggregate/src/variance.rs
+++ b/datafusion/functions-aggregate/src/variance.rs
@@ -18,20 +18,21 @@
 //! [`VarianceSample`]: variance sample aggregations.
 //! [`VariancePopulation`]: variance population aggregations.
 
-use arrow::datatypes::FieldRef;
+use arrow::datatypes::{FieldRef, Float64Type};
 use arrow::{
     array::{Array, ArrayRef, BooleanArray, Float64Array, UInt64Array},
     buffer::NullBuffer,
     compute::kernels::cast,
     datatypes::{DataType, Field},
 };
-use datafusion_common::{Result, ScalarValue, downcast_value, not_impl_err, plan_err};
+use datafusion_common::{Result, ScalarValue, downcast_value, plan_err};
 use datafusion_expr::{
     Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature,
     Volatility,
     function::{AccumulatorArgs, StateFieldsArgs},
     utils::format_state_name,
 };
+use datafusion_functions_aggregate_common::utils::GenericDistinctBuffer;
 use datafusion_functions_aggregate_common::{
     aggregate::groups_accumulator::accumulate::accumulate, stats::StatsType,
 };
@@ -110,19 +111,35 @@ impl AggregateUDFImpl for VarianceSample {
 
     fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
         let name = args.name;
-        Ok(vec![
-            Field::new(format_state_name(name, "count"), DataType::UInt64, true),
-            Field::new(format_state_name(name, "mean"), DataType::Float64, true),
-            Field::new(format_state_name(name, "m2"), DataType::Float64, true),
-        ]
-        .into_iter()
-        .map(Arc::new)
-        .collect())
+        match args.is_distinct {
+            false => Ok(vec![
+                Field::new(format_state_name(name, "count"), DataType::UInt64, true),
+                Field::new(format_state_name(name, "mean"), DataType::Float64, true),
+                Field::new(format_state_name(name, "m2"), DataType::Float64, true),
+            ]
+            .into_iter()
+            .map(Arc::new)
+            .collect()),
+            true => {
+                let field = Field::new_list_field(DataType::Float64, true);
+                let state_name = "distinct_var";
+                Ok(vec![
+                    Field::new(
+                        format_state_name(name, state_name),
+                        DataType::List(Arc::new(field)),
+                        true,
+                    )
+                    .into(),
+                ])
+            }
+        }
     }
 
     fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
         if acc_args.is_distinct {
-            return not_impl_err!("VAR(DISTINCT) aggregations are not available");
+            return Ok(Box::new(DistinctVarianceAccumulator::new(
+                StatsType::Sample,
+            )));
         }
 
         Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?))
@@ -206,20 +223,38 @@ impl AggregateUDFImpl for VariancePopulation {
     }
 
     fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
-        let name = args.name;
-        Ok(vec![
-            Field::new(format_state_name(name, "count"), DataType::UInt64, true),
-            Field::new(format_state_name(name, "mean"), DataType::Float64, true),
-            Field::new(format_state_name(name, "m2"), DataType::Float64, true),
-        ]
-        .into_iter()
-        .map(Arc::new)
-        .collect())
+        match args.is_distinct {
+            false => {
+                let name = args.name;
+                Ok(vec![
+                    Field::new(format_state_name(name, "count"), DataType::UInt64, true),
+                    Field::new(format_state_name(name, "mean"), DataType::Float64, true),
+                    Field::new(format_state_name(name, "m2"), DataType::Float64, true),
+                ]
+                .into_iter()
+                .map(Arc::new)
+                .collect())
+            }
+            true => {
+                let field = Field::new_list_field(DataType::Float64, true);
+                let state_name = "distinct_var";
+                Ok(vec![
+                    Field::new(
+                        format_state_name(args.name, state_name),
+                        DataType::List(Arc::new(field)),
+                        true,
+                    )
+                    .into(),
+                ])
+            }
+        }
     }
 
     fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
         if acc_args.is_distinct {
-            return not_impl_err!("VAR_POP(DISTINCT) aggregations are not available");
+            return Ok(Box::new(DistinctVarianceAccumulator::new(
+                StatsType::Population,
+            )));
         }
 
         Ok(Box::new(VarianceAccumulator::try_new(
@@ -581,6 +616,76 @@ impl GroupsAccumulator for VarianceGroupsAccumulator {
     }
 }
 
+/// Accumulator backing `VAR(DISTINCT ...)` and `VAR_POP(DISTINCT ...)`.
+///
+/// Input is cast to `Float64` and handed to a [`GenericDistinctBuffer`];
+/// `evaluate` computes the variance over the values that buffer holds.
+#[derive(Debug)]
+pub struct DistinctVarianceAccumulator {
+    /// Buffer holding the `Float64` values fed to this accumulator.
+    distinct_values: GenericDistinctBuffer<Float64Type>,
+    /// Selects the sample (`n - 1`) or population (`n`) divisor.
+    stat_type: StatsType,
+}
+
+impl DistinctVarianceAccumulator {
+    /// Creates an empty accumulator for the given variance flavor.
+    pub fn new(stat_type: StatsType) -> Self {
+        Self {
+            distinct_values: GenericDistinctBuffer::<Float64Type>::new(DataType::Float64),
+            stat_type,
+        }
+    }
+}
+
+impl Accumulator for DistinctVarianceAccumulator {
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        let cast_values = cast(&values[0], &DataType::Float64)?;
+        self.distinct_values.update_batch(&[cast_values])
+    }
+
+    fn evaluate(&mut self) -> Result<ScalarValue> {
+        let values = self
+            .distinct_values
+            .values
+            .iter()
+            .map(|v| v.0)
+            .collect::<Vec<_>>();
+        let n = values.len();
+
+        // Fewer than two values has a fixed answer; returning early also
+        // avoids the `0.0 / 0.0` mean that an empty set would produce.
+        if n < 2 {
+            let variance = match self.stat_type {
+                StatsType::Population if n == 1 => Some(0.0),
+                StatsType::Population | StatsType::Sample => None,
+            };
+            return Ok(ScalarValue::Float64(variance));
+        }
+
+        let divisor = match self.stat_type {
+            StatsType::Sample => n - 1,
+            StatsType::Population => n,
+        };
+        let mean = values.iter().sum::<f64>() / n as f64;
+        let m2 = values.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>();
+
+        Ok(ScalarValue::Float64(Some(m2 / divisor as f64)))
+    }
+
+    fn size(&self) -> usize {
+        size_of_val(self) + self.distinct_values.size()
+    }
+
+    fn state(&mut self) -> Result<Vec<ScalarValue>> {
+        self.distinct_values.state()
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        self.distinct_values.merge_batch(states)
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use datafusion_expr::EmitTo;
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index f6ce68917e03b..df980ab863362 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -700,8 +700,10 @@ SELECT var(distinct c2) FROM aggregate_test_100
 ----
 2.5
 
-statement error DataFusion error: This feature is not implemented: VAR\(DISTINCT\) aggregations are not available
+query RR
 SELECT var(c2), var(distinct c2) FROM aggregate_test_100
+----
+1.886363636364 2.5
 
 # csv_query_distinct_variance_population
 query R
@@ -709,8 +711,10 @@ SELECT var_pop(distinct c2) FROM aggregate_test_100
 ----
 2
 
-statement error DataFusion error: This feature is not implemented: VAR_POP\(DISTINCT\) aggregations are not available
+query RR
 SELECT var_pop(c2), var_pop(distinct c2) FROM aggregate_test_100
+----
+1.8675 2
 
 # csv_query_variance_5
 query R