Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 123 additions & 21 deletions datafusion/functions-aggregate/src/variance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,21 @@
//! [`VarianceSample`]: variance sample aggregations.
//! [`VariancePopulation`]: variance population aggregations.

use arrow::datatypes::FieldRef;
use arrow::datatypes::{FieldRef, Float64Type};
use arrow::{
array::{Array, ArrayRef, BooleanArray, Float64Array, UInt64Array},
buffer::NullBuffer,
compute::kernels::cast,
datatypes::{DataType, Field},
};
use datafusion_common::{Result, ScalarValue, downcast_value, not_impl_err, plan_err};
use datafusion_common::{Result, ScalarValue, downcast_value, plan_err};
use datafusion_expr::{
Accumulator, AggregateUDFImpl, Documentation, GroupsAccumulator, Signature,
Volatility,
function::{AccumulatorArgs, StateFieldsArgs},
utils::format_state_name,
};
use datafusion_functions_aggregate_common::utils::GenericDistinctBuffer;
use datafusion_functions_aggregate_common::{
aggregate::groups_accumulator::accumulate::accumulate, stats::StatsType,
};
Expand Down Expand Up @@ -110,19 +111,35 @@ impl AggregateUDFImpl for VarianceSample {

fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
let name = args.name;
Ok(vec![
Field::new(format_state_name(name, "count"), DataType::UInt64, true),
Field::new(format_state_name(name, "mean"), DataType::Float64, true),
Field::new(format_state_name(name, "m2"), DataType::Float64, true),
]
.into_iter()
.map(Arc::new)
.collect())
match args.is_distinct {
false => Ok(vec![
Field::new(format_state_name(name, "count"), DataType::UInt64, true),
Field::new(format_state_name(name, "mean"), DataType::Float64, true),
Field::new(format_state_name(name, "m2"), DataType::Float64, true),
]
.into_iter()
.map(Arc::new)
.collect()),
true => {
let field = Field::new_list_field(DataType::Float64, true);
let state_name = "distinct_var";
Ok(vec![
Field::new(
format_state_name(name, state_name),
DataType::List(Arc::new(field)),
true,
)
.into(),
])
}
}
}

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
if acc_args.is_distinct {
return not_impl_err!("VAR(DISTINCT) aggregations are not available");
return Ok(Box::new(DistinctVarianceAccumulator::new(
StatsType::Sample,
)));
}

Ok(Box::new(VarianceAccumulator::try_new(StatsType::Sample)?))
Expand Down Expand Up @@ -206,20 +223,38 @@ impl AggregateUDFImpl for VariancePopulation {
}

fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
let name = args.name;
Ok(vec![
Field::new(format_state_name(name, "count"), DataType::UInt64, true),
Field::new(format_state_name(name, "mean"), DataType::Float64, true),
Field::new(format_state_name(name, "m2"), DataType::Float64, true),
]
.into_iter()
.map(Arc::new)
.collect())
match args.is_distinct {
false => {
let name = args.name;
Ok(vec![
Field::new(format_state_name(name, "count"), DataType::UInt64, true),
Field::new(format_state_name(name, "mean"), DataType::Float64, true),
Field::new(format_state_name(name, "m2"), DataType::Float64, true),
]
.into_iter()
.map(Arc::new)
.collect())
}
true => {
let field = Field::new_list_field(DataType::Float64, true);
let state_name = "distinct_var";
Ok(vec![
Field::new(
format_state_name(args.name, state_name),
DataType::List(Arc::new(field)),
true,
)
.into(),
])
}
}
}

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
if acc_args.is_distinct {
return not_impl_err!("VAR_POP(DISTINCT) aggregations are not available");
return Ok(Box::new(DistinctVarianceAccumulator::new(
StatsType::Population,
)));
}

Ok(Box::new(VarianceAccumulator::try_new(
Expand Down Expand Up @@ -581,6 +616,73 @@ impl GroupsAccumulator for VarianceGroupsAccumulator {
}
}

/// Accumulator backing the `DISTINCT` form of the variance aggregates.
///
/// Input values are kept in a [`GenericDistinctBuffer`] and the variance is
/// only computed from the buffered values when `evaluate` is called.
#[derive(Debug)]
pub struct DistinctVarianceAccumulator {
    /// Buffer of the input values seen so far, stored as `Float64`.
    distinct_values: GenericDistinctBuffer<Float64Type>,
    /// Selects whether `evaluate` produces the sample or population variance.
    stat_type: StatsType,
}

impl DistinctVarianceAccumulator {
    /// Builds an accumulator with a fresh `Float64` distinct buffer that will
    /// report the variance flavour selected by `stat_type`.
    pub fn new(stat_type: StatsType) -> Self {
        // The buffer's element type is inferred from the struct field.
        let distinct_values = GenericDistinctBuffer::new(DataType::Float64);
        Self {
            distinct_values,
            stat_type,
        }
    }
}

impl Accumulator for DistinctVarianceAccumulator {
    /// Casts the first input column to `Float64` and forwards it to the
    /// distinct buffer.
    ///
    /// # Errors
    /// Returns an error if the cast to `Float64` fails or if the buffer
    /// rejects the batch.
    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        let cast_values = cast(&values[0], &DataType::Float64)?;
        // A one-element array coerces to the `&[ArrayRef]` the buffer expects,
        // so no temporary heap-allocated `Vec` is needed.
        self.distinct_values.update_batch(&[cast_values])
    }

    /// Computes the variance of the buffered values.
    ///
    /// Returns `Float64(None)` when there are no values, or when there is a
    /// single value and the sample variance is requested. A single value with
    /// the population variance yields `Some(0.0)`. Otherwise the result is the
    /// sum of squared deviations from the mean divided by `n - 1` (sample) or
    /// `n` (population).
    fn evaluate(&mut self) -> Result<ScalarValue> {
        let values: Vec<f64> =
            self.distinct_values.values.iter().map(|v| v.0).collect();
        let n = values.len();

        // Divisor applied to the sum of squared deviations.
        let divisor = match self.stat_type {
            StatsType::Sample => n.saturating_sub(1),
            StatsType::Population => n,
        };

        // No values at all, or a single value for the sample variance:
        // the result is undefined, so report NULL without dividing by zero.
        if divisor == 0 {
            return Ok(ScalarValue::Float64(None));
        }

        // Only reachable for the population variance. Answer directly rather
        // than computing `x - mean`, which would be NaN for a non-finite `x`.
        if n == 1 {
            return Ok(ScalarValue::Float64(Some(0.0)));
        }

        // Two-pass computation: the mean first, then the squared deviations.
        let mean = values.iter().sum::<f64>() / n as f64;
        let m2 = values
            .iter()
            .map(|x| (x - mean) * (x - mean))
            .sum::<f64>();

        Ok(ScalarValue::Float64(Some(m2 / divisor as f64)))
    }

    /// Reports the size of this accumulator plus the size reported by the
    /// distinct buffer.
    fn size(&self) -> usize {
        size_of_val(self) + self.distinct_values.size()
    }

    /// Returns the intermediate state produced by the distinct buffer.
    fn state(&mut self) -> Result<Vec<ScalarValue>> {
        self.distinct_values.state()
    }

    /// Merges intermediate states from other accumulators by forwarding them
    /// to the distinct buffer.
    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        self.distinct_values.merge_batch(states)
    }
}

#[cfg(test)]
mod tests {
use datafusion_expr::EmitTo;
Expand Down
8 changes: 6 additions & 2 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -700,17 +700,21 @@ SELECT var(distinct c2) FROM aggregate_test_100
----
2.5

statement error DataFusion error: This feature is not implemented: VAR\(DISTINCT\) aggregations are not available
query RR
SELECT var(c2), var(distinct c2) FROM aggregate_test_100
----
1.886363636364 2.5

# csv_query_distinct_variance_population
query R
SELECT var_pop(distinct c2) FROM aggregate_test_100
----
2

statement error DataFusion error: This feature is not implemented: VAR_POP\(DISTINCT\) aggregations are not available
query RR
SELECT var_pop(c2), var_pop(distinct c2) FROM aggregate_test_100
----
1.8675 2

# csv_query_variance_5
query R
Expand Down