Skip to content

Commit

Permalink
approx_quantile() aggregation function (#1539)
Browse files Browse the repository at this point in the history
* feat: implement TDigest for approx quantile

Adds a [TDigest] implementation providing approximate quantile
estimations of large inputs using a small amount of (bounded) memory.

A TDigest is most accurate near either "end" of the quantile range (that
is, 0.1, 0.9, 0.95, etc) due to the use of a scaling function that
increases resolution at the tails. The paper claims single digit part
per million errors for q ≤ 0.001 or q ≥ 0.999 using 100 centroids, and
in practice I have found accuracy to be more than acceptable for an
approximate function across the entire quantile range.

The implementation is a modified copy of
https://github.com/MnO2/t-digest, itself a Rust port of [Facebook's C++
implementation]. Both Facebook's implementation and MnO2's Rust port
are Apache 2.0 licensed.

[TDigest]: https://arxiv.org/abs/1902.04023
[Facebook's C++ implementation]: https://github.com/facebook/folly/blob/main/folly/stats/TDigest.h

* feat: approx_quantile aggregation

Adds the ApproxQuantile physical expression, plumbing & test cases.

The function signature is:

	approx_quantile(column, quantile)

Where column can be any numeric type (that can be cast to a float64) and
quantile is a float64 literal between 0 and 1.

* feat: approx_quantile dataframe function

Adds the approx_quantile() dataframe function, and exports it in the
prelude.

* refactor: bastilla approx_quantile support

Adds bastilla wire encoding for approx_quantile.

Adding support for this required modifying the AggregateExprNode proto
message to support propagating multiple LogicalExprNode aggregate
arguments - all the existing aggregations take a single argument, so
this wasn't needed before.

This commit adds "repeated" to the expr field, which I believe is
backwards compatible as described here:

	https://developers.google.com/protocol-buffers/docs/proto3#updating

Specifically, adding "repeated" to an existing message field:

	"For ... message fields, optional is compatible with repeated"

No existing tests needed fixing, and a new roundtrip test is included
that covers the change to allow multiple expr.

* refactor: use input type as return type

Casts the calculated quantile value to the same type as the input data.

* fixup! refactor: bastilla approx_quantile support

* refactor: rebase onto main

* refactor: validate quantile value

Ensures the quantile value is between 0 and 1, emitting a plan error if
not.

* refactor: rename to approx_percentile_cont

* refactor: clippy lints
  • Loading branch information
domodwyer authored Jan 31, 2022
1 parent 7bec762 commit cfb655d
Show file tree
Hide file tree
Showing 16 changed files with 1,485 additions and 47 deletions.
3 changes: 2 additions & 1 deletion ballista/rust/core/proto/ballista.proto
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,12 @@ enum AggregateFunction {
STDDEV=11;
STDDEV_POP=12;
CORRELATION=13;
APPROX_PERCENTILE_CONT = 14;
}

// Serialized form of a logical aggregate function expression.
message AggregateExprNode {
// The built-in aggregate function being applied (e.g. MIN, SUM,
// APPROX_PERCENTILE_CONT).
AggregateFunction aggr_function = 1;
// The argument expressions for the aggregate. This field is repeated because
// some aggregates (e.g. approx_percentile_cont) take more than one argument;
// per the proto3 update rules, changing a singular message field to repeated
// is a backwards-compatible change.
repeated LogicalExprNode expr = 2;
}

enum BuiltInWindowFunction {
Expand Down
6 changes: 5 additions & 1 deletion ballista/rust/core/src/serde/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1065,7 +1065,11 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {

Ok(Expr::AggregateFunction {
fun,
args: vec![parse_required_expr(&expr.expr)?],
args: expr
.expr
.iter()
.map(|e| e.try_into())
.collect::<Result<Vec<_>, _>>()?,
distinct: false, //TODO
})
}
Expand Down
21 changes: 16 additions & 5 deletions ballista/rust/core/src/serde/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,14 @@ mod roundtrip_tests {
use super::super::{super::error::Result, protobuf};
use crate::error::BallistaError;
use core::panic;
use datafusion::arrow::datatypes::UnionMode;
use datafusion::logical_plan::Repartition;
use datafusion::{
arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit},
arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit, UnionMode},
datasource::object_store::local::LocalFileSystem,
logical_plan::{
col, CreateExternalTable, Expr, LogicalPlan, LogicalPlanBuilder,
Partitioning, ToDFSchema,
Partitioning, Repartition, ToDFSchema,
},
physical_plan::functions::BuiltinScalarFunction::Sqrt,
physical_plan::{aggregates, functions::BuiltinScalarFunction::Sqrt},
prelude::*,
scalar::ScalarValue,
sql::parser::FileType,
Expand Down Expand Up @@ -1001,4 +999,17 @@ mod roundtrip_tests {

Ok(())
}

#[test]
fn roundtrip_approx_percentile_cont() -> Result<()> {
let test_expr = Expr::AggregateFunction {
fun: aggregates::AggregateFunction::ApproxPercentileCont,
args: vec![col("bananas"), lit(0.42)],
distinct: false,
};

roundtrip_test!(test_expr, protobuf::LogicalExprNode, Expr);

Ok(())
}
}
14 changes: 10 additions & 4 deletions ballista/rust/core/src/serde/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1074,6 +1074,9 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
AggregateFunction::ApproxDistinct => {
protobuf::AggregateFunction::ApproxDistinct
}
AggregateFunction::ApproxPercentileCont => {
protobuf::AggregateFunction::ApproxPercentileCont
}
AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg,
AggregateFunction::Min => protobuf::AggregateFunction::Min,
AggregateFunction::Max => protobuf::AggregateFunction::Max,
Expand All @@ -1099,11 +1102,13 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
}
};

let arg = &args[0];
let aggregate_expr = Box::new(protobuf::AggregateExprNode {
let aggregate_expr = protobuf::AggregateExprNode {
aggr_function: aggr_function.into(),
expr: Some(Box::new(arg.try_into()?)),
});
expr: args
.iter()
.map(|v| v.try_into())
.collect::<Result<Vec<_>, _>>()?,
};
Ok(protobuf::LogicalExprNode {
expr_type: Some(ExprType::AggregateExpr(aggregate_expr)),
})
Expand Down Expand Up @@ -1334,6 +1339,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction {
AggregateFunction::Stddev => Self::Stddev,
AggregateFunction::StddevPop => Self::StddevPop,
AggregateFunction::Correlation => Self::Correlation,
AggregateFunction::ApproxPercentileCont => Self::ApproxPercentileCont,
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions ballista/rust/core/src/serde/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ impl From<protobuf::AggregateFunction> for AggregateFunction {
protobuf::AggregateFunction::Stddev => AggregateFunction::Stddev,
protobuf::AggregateFunction::StddevPop => AggregateFunction::StddevPop,
protobuf::AggregateFunction::Correlation => AggregateFunction::Correlation,
protobuf::AggregateFunction::ApproxPercentileCont => {
AggregateFunction::ApproxPercentileCont
}
}
}
}
Expand Down
9 changes: 9 additions & 0 deletions datafusion/src/logical_plan/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1647,6 +1647,15 @@ pub fn approx_distinct(expr: Expr) -> Expr {
}
}

/// Calculate an approximation of the specified `percentile` for `expr`.
///
/// The returned expression evaluates the `approx_percentile_cont` aggregate
/// function over `expr`; `percentile` should be a float literal between
/// 0 and 1 (values outside that range are rejected at plan time).
pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr {
    let args = vec![expr, percentile];
    Expr::AggregateFunction {
        fun: aggregates::AggregateFunction::ApproxPercentileCont,
        args,
        distinct: false,
    }
}

// TODO(kszucs): this seems buggy, unary_scalar_expr! is used for many
// varying arity functions
/// Create an convenience function representing a unary scalar function
Expand Down
14 changes: 7 additions & 7 deletions datafusion/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,13 @@ pub use builder::{
pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema};
pub use display::display_schema;
pub use expr::{
abs, acos, and, approx_distinct, array, ascii, asin, atan, avg, binary_expr,
bit_length, btrim, case, ceil, character_length, chr, col, columnize_expr,
combine_filters, concat, concat_ws, cos, count, count_distinct, create_udaf,
create_udf, date_part, date_trunc, digest, exp, exprlist_to_fields, floor, in_list,
initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2, lower, lpad, ltrim,
max, md5, min, normalize_col, normalize_cols, now, octet_length, or, random,
regexp_match, regexp_replace, repeat, replace, replace_col, reverse,
abs, acos, and, approx_distinct, approx_percentile_cont, array, ascii, asin, atan,
avg, binary_expr, bit_length, btrim, case, ceil, character_length, chr, col,
columnize_expr, combine_filters, concat, concat_ws, cos, count, count_distinct,
create_udaf, create_udf, date_part, date_trunc, digest, exp, exprlist_to_fields,
floor, in_list, initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2,
lower, lpad, ltrim, max, md5, min, normalize_col, normalize_cols, now, octet_length,
or, random, regexp_match, regexp_replace, repeat, replace, replace_col, reverse,
rewrite_sort_cols_by_aggs, right, round, rpad, rtrim, sha224, sha256, sha384, sha512,
signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, to_hex,
translate, trim, trunc, unalias, unnormalize_col, unnormalize_cols, upper, when,
Expand Down
87 changes: 83 additions & 4 deletions datafusion/src/physical_plan/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
//! * Return type: a function `(arg_types) -> return_type`. E.g. for min, ([f32]) -> f32, ([f64]) -> f64.
use super::{
functions::{Signature, Volatility},
functions::{Signature, TypeSignature, Volatility},
Accumulator, AggregateExpr, PhysicalExpr,
};
use crate::error::{DataFusionError, Result};
Expand Down Expand Up @@ -80,6 +80,8 @@ pub enum AggregateFunction {
CovariancePop,
/// Correlation
Correlation,
/// Approximate continuous percentile function
ApproxPercentileCont,
}

impl fmt::Display for AggregateFunction {
Expand Down Expand Up @@ -110,6 +112,7 @@ impl FromStr for AggregateFunction {
"covar_samp" => AggregateFunction::Covariance,
"covar_pop" => AggregateFunction::CovariancePop,
"corr" => AggregateFunction::Correlation,
"approx_percentile_cont" => AggregateFunction::ApproxPercentileCont,
_ => {
return Err(DataFusionError::Plan(format!(
"There is no built-in function named {}",
Expand Down Expand Up @@ -157,6 +160,7 @@ pub fn return_type(
coerced_data_types[0].clone(),
true,
)))),
AggregateFunction::ApproxPercentileCont => Ok(coerced_data_types[0].clone()),
}
}

Expand Down Expand Up @@ -331,6 +335,20 @@ pub fn create_aggregate_expr(
"CORR(DISTINCT) aggregations are not available".to_string(),
));
}
(AggregateFunction::ApproxPercentileCont, false) => {
Arc::new(expressions::ApproxPercentileCont::new(
// Pass in the desired percentile expr
coerced_phy_exprs,
name,
return_type,
)?)
}
(AggregateFunction::ApproxPercentileCont, true) => {
return Err(DataFusionError::NotImplemented(
"approx_percentile_cont(DISTINCT) aggregations are not available"
.to_string(),
));
}
})
}

Expand Down Expand Up @@ -389,17 +407,25 @@ pub(super) fn signature(fun: &AggregateFunction) -> Signature {
AggregateFunction::Correlation => {
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::ApproxPercentileCont => Signature::one_of(
// Accept any numeric value paired with a float64 percentile
NUMERICS
.iter()
.map(|t| TypeSignature::Exact(vec![t.clone(), DataType::Float64]))
.collect(),
Volatility::Immutable,
),
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::error::Result;
use crate::physical_plan::expressions::{
ApproxDistinct, ArrayAgg, Avg, Correlation, Count, Covariance, DistinctArrayAgg,
DistinctCount, Max, Min, Stddev, Sum, Variance,
ApproxDistinct, ApproxPercentileCont, ArrayAgg, Avg, Correlation, Count,
Covariance, DistinctArrayAgg, DistinctCount, Max, Min, Stddev, Sum, Variance,
};
use crate::{error::Result, scalar::ScalarValue};

#[test]
fn test_count_arragg_approx_expr() -> Result<()> {
Expand Down Expand Up @@ -513,6 +539,59 @@ mod tests {
Ok(())
}

#[test]
fn test_agg_approx_percentile_phy_expr() {
for data_type in NUMERICS {
let input_schema =
Schema::new(vec![Field::new("c1", data_type.clone(), true)]);
let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![
Arc::new(
expressions::Column::new_with_schema("c1", &input_schema).unwrap(),
),
Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.2)))),
];
let result_agg_phy_exprs = create_aggregate_expr(
&AggregateFunction::ApproxPercentileCont,
false,
&input_phy_exprs[..],
&input_schema,
"c1",
)
.expect("failed to create aggregate expr");

assert!(result_agg_phy_exprs.as_any().is::<ApproxPercentileCont>());
assert_eq!("c1", result_agg_phy_exprs.name());
assert_eq!(
Field::new("c1", data_type.clone(), false),
result_agg_phy_exprs.field().unwrap()
);
}
}

#[test]
fn test_agg_approx_percentile_invalid_phy_expr() {
for data_type in NUMERICS {
let input_schema =
Schema::new(vec![Field::new("c1", data_type.clone(), true)]);
let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![
Arc::new(
expressions::Column::new_with_schema("c1", &input_schema).unwrap(),
),
Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(4.2)))),
];
let err = create_aggregate_expr(
&AggregateFunction::ApproxPercentileCont,
false,
&input_phy_exprs[..],
&input_schema,
"c1",
)
.expect_err("should fail due to invalid percentile");

assert!(matches!(err, DataFusionError::Plan(_)));
}
}

#[test]
fn test_min_max_expr() -> Result<()> {
let funcs = vec![AggregateFunction::Min, AggregateFunction::Max];
Expand Down
Loading

0 comments on commit cfb655d

Please sign in to comment.