Skip to content

Commit

Permalink
feat: approx_quantile aggregation
Browse files Browse the repository at this point in the history
Adds the ApproxQuantile physical expression, plumbing & test cases.

The function signature is:

	approx_quantile(column, quantile)

Where column can be any numeric type (that can be cast to a float64) and
quantile is a float64 literal between 0 and 1.
  • Loading branch information
domodwyer committed Jan 11, 2022
1 parent b72d21c commit d5fc006
Show file tree
Hide file tree
Showing 5 changed files with 537 additions and 23 deletions.
65 changes: 61 additions & 4 deletions datafusion/src/physical_plan/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
//! * Return type: a function `(arg_types) -> return_type`. E.g. for min, ([f32]) -> f32, ([f64]) -> f64.

use super::{
functions::{Signature, Volatility},
functions::{Signature, TypeSignature, Volatility},
Accumulator, AggregateExpr, PhysicalExpr,
};
use crate::error::{DataFusionError, Result};
Expand Down Expand Up @@ -74,6 +74,8 @@ pub enum AggregateFunction {
Stddev,
/// Standard Deviation (Population)
StddevPop,
/// Approximate quantile function
ApproxQuantile,
}

impl fmt::Display for AggregateFunction {
Expand All @@ -100,6 +102,7 @@ impl FromStr for AggregateFunction {
"stddev" => AggregateFunction::Stddev,
"stddev_samp" => AggregateFunction::Stddev,
"stddev_pop" => AggregateFunction::StddevPop,
"approx_quantile" => AggregateFunction::ApproxQuantile,
_ => {
return Err(DataFusionError::Plan(format!(
"There is no built-in function named {}",
Expand Down Expand Up @@ -142,6 +145,7 @@ pub fn return_type(
coerced_data_types[0].clone(),
true,
)))),
AggregateFunction::ApproxQuantile => Ok(DataType::Float64),
}
}

Expand Down Expand Up @@ -279,6 +283,19 @@ pub fn create_aggregate_expr(
"STDDEV_POP(DISTINCT) aggregations are not available".to_string(),
));
}
(AggregateFunction::ApproxQuantile, false) => {
Arc::new(expressions::ApproxQuantile::new(
// Pass in the desired quantile expr
coerced_phy_exprs,
name,
return_type,
)?)
}
(AggregateFunction::ApproxQuantile, true) => {
return Err(DataFusionError::NotImplemented(
"approx_quantile(DISTINCT) aggregations are not available".to_string(),
));
}
})
}

Expand Down Expand Up @@ -331,17 +348,28 @@ pub fn signature(fun: &AggregateFunction) -> Signature {
| AggregateFunction::StddevPop => {
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::ApproxQuantile => Signature::one_of(
// Accept any numeric value paired with a float64 quantile
NUMERICS
.iter()
.map(|t| TypeSignature::Exact(vec![t.clone(), DataType::Float64]))
.collect(),
Volatility::Immutable,
),
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::error::DataFusionError::NotImplemented;
use crate::error::Result;
use crate::physical_plan::distinct_expressions::DistinctCount;
use crate::physical_plan::expressions::{
ApproxDistinct, ArrayAgg, Avg, Count, Max, Min, Stddev, Sum, Variance,
ApproxDistinct, ApproxQuantile, ArrayAgg, Avg, Count, Max, Min, Stddev, Sum,
Variance,
};
use crate::{
error::{DataFusionError::NotImplemented, Result},
scalar::ScalarValue,
};

#[test]
Expand Down Expand Up @@ -458,6 +486,35 @@ mod tests {
Ok(())
}

#[test]
fn test_agg_approx_quantile_phy_expr() {
for data_type in NUMERICS {
let input_schema =
Schema::new(vec![Field::new("c1", data_type.clone(), true)]);
let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![
Arc::new(
expressions::Column::new_with_schema("c1", &input_schema).unwrap(),
),
Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(4.2)))),
];
let result_agg_phy_exprs = create_aggregate_expr(
&AggregateFunction::ApproxQuantile,
false,
&input_phy_exprs[..],
&input_schema,
"c1",
)
.expect("failed to create aggregate expr");

assert!(result_agg_phy_exprs.as_any().is::<ApproxQuantile>());
assert_eq!("c1", result_agg_phy_exprs.name());
assert_eq!(
Field::new("c1", DataType::Float64, false),
result_agg_phy_exprs.field().unwrap()
);
}
}

#[test]
fn test_min_max_expr() -> Result<()> {
let funcs = vec![AggregateFunction::Min, AggregateFunction::Max];
Expand Down
114 changes: 95 additions & 19 deletions datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

//! Support the coercion rule for aggregate function.

use crate::arrow::datatypes::Schema;
use crate::error::{DataFusionError, Result};
use crate::physical_plan::aggregates::AggregateFunction;
use crate::physical_plan::expressions::{
Expand All @@ -26,6 +25,10 @@ use crate::physical_plan::expressions::{
};
use crate::physical_plan::functions::{Signature, TypeSignature};
use crate::physical_plan::PhysicalExpr;
use crate::{
arrow::datatypes::Schema,
physical_plan::expressions::is_approx_quantile_supported_arg_type,
};
use arrow::datatypes::DataType;
use std::ops::Deref;
use std::sync::Arc;
Expand All @@ -37,24 +40,9 @@ pub(crate) fn coerce_types(
input_types: &[DataType],
signature: &Signature,
) -> Result<Vec<DataType>> {
match signature.type_signature {
TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => {
if input_types.len() != agg_count {
return Err(DataFusionError::Plan(format!(
"The function {:?} expects {:?} arguments, but {:?} were provided",
agg_fun,
agg_count,
input_types.len()
)));
}
}
_ => {
return Err(DataFusionError::Internal(format!(
"Aggregate functions do not support this {:?}",
signature
)));
}
};
// Validate input_types matches (at least one of) the func signature.
check_arg_count(agg_fun, input_types, &signature.type_signature)?;

match agg_fun {
AggregateFunction::Count | AggregateFunction::ApproxDistinct => {
Ok(input_types.to_vec())
Expand Down Expand Up @@ -123,7 +111,75 @@ pub(crate) fn coerce_types(
}
Ok(input_types.to_vec())
}
AggregateFunction::ApproxQuantile => {
if !is_approx_quantile_supported_arg_type(&input_types[0]) {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun, input_types[0]
)));
}
if !matches!(input_types[1], DataType::Float64) {
return Err(DataFusionError::Plan(format!(
"The quantile argument for {:?} must be Float64, not {:?}.",
agg_fun, input_types[1]
)));
}
Ok(input_types.to_vec())
}
}
}

/// Validate the length of `input_types` matches the `signature` for `agg_fun`.
///
/// This method DOES NOT validate the argument types - only that (at least one,
/// in the case of [`TypeSignature::OneOf`]) signature matches the desired
/// number of input types.
fn check_arg_count(
agg_fun: &AggregateFunction,
input_types: &[DataType],
signature: &TypeSignature,
) -> Result<()> {
match signature {
TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => {
if input_types.len() != *agg_count {
return Err(DataFusionError::Plan(format!(
"The function {:?} expects {:?} arguments, but {:?} were provided",
agg_fun,
agg_count,
input_types.len()
)));
}
}
TypeSignature::Exact(types) => {
if types.len() != input_types.len() {
return Err(DataFusionError::Plan(format!(
"The function {:?} expects {:?} arguments, but {:?} were provided",
agg_fun,
types.len(),
input_types.len()
)));
}
}
TypeSignature::OneOf(variants) => {
let ok = variants
.iter()
.any(|v| check_arg_count(agg_fun, input_types, v).is_ok());
if !ok {
return Err(DataFusionError::Plan(format!(
"The function {:?} does not accept {:?} function arguments.",
agg_fun,
input_types.len()
)));
}
}
_ => {
return Err(DataFusionError::Internal(format!(
"Aggregate functions do not support this {:?}",
signature
)));
}
}
Ok(())
}

fn get_min_max_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> {
Expand Down Expand Up @@ -239,5 +295,25 @@ mod tests {
assert_eq!(*input_type, result.unwrap());
}
}

// ApproxQuantile input types
let input_types = vec![
vec![DataType::Int8, DataType::Float64],
vec![DataType::Int16, DataType::Float64],
vec![DataType::Int32, DataType::Float64],
vec![DataType::Int64, DataType::Float64],
vec![DataType::UInt8, DataType::Float64],
vec![DataType::UInt16, DataType::Float64],
vec![DataType::UInt32, DataType::Float64],
vec![DataType::UInt64, DataType::Float64],
vec![DataType::Float32, DataType::Float64],
vec![DataType::Float64, DataType::Float64],
];
for input_type in &input_types {
let signature = aggregates::signature(&AggregateFunction::ApproxQuantile);
let result =
coerce_types(&AggregateFunction::ApproxQuantile, input_type, &signature);
assert_eq!(*input_type, result.unwrap());
}
}
}
Loading

0 comments on commit d5fc006

Please sign in to comment.