Skip to content

Commit

Permalink
Support number of centroids in approx_percentile_cont (#3146)
Browse files Browse the repository at this point in the history
* Support number of histogram bins in approx_percentile_cont

* add args check and UT

* add doc
  • Loading branch information
Ted-Jiang authored Aug 17, 2022
1 parent 2aa0a98 commit 929eb6d
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 40 deletions.
48 changes: 48 additions & 0 deletions datafusion/core/tests/sql/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,54 @@ async fn csv_query_approx_percentile_cont_with_weight() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn csv_query_approx_percentile_cont_with_histogram_bins() -> Result<()> {
let ctx = SessionContext::new();
register_aggregate_csv(&ctx).await?;

// compare approx_percentile_cont and approx_percentile_cont_with_weight
let sql = "SELECT c1, approx_percentile_cont(c3, 0.95, 200) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1";
let actual = execute_to_batches(&ctx, sql).await;
let expected = vec![
"+----+--------+",
"| c1 | c3_p95 |",
"+----+--------+",
"| a | 73 |",
"| b | 68 |",
"| c | 122 |",
"| d | 124 |",
"| e | 115 |",
"+----+--------+",
];
assert_batches_eq!(expected, &actual);

let results = plan_and_collect(
&ctx,
"SELECT c1, approx_percentile_cont(c3, 0.95, -1000) AS c3_p95 FROM aggregate_test_100 GROUP BY 1 ORDER BY 1",
)
.await
.unwrap_err();
assert_eq!(results.to_string(), "This feature is not implemented: Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data type Int64).");

let results = plan_and_collect(
&ctx,
"SELECT approx_percentile_cont(c3, 0.95, c1) FROM aggregate_test_100",
)
.await
.unwrap_err();
assert_eq!(results.to_string(), "Error during planning: The percentile sample points count for ApproxPercentileCont must be integer, not Utf8.");

let results = plan_and_collect(
&ctx,
"SELECT approx_percentile_cont(c3, 0.95, 111.1) FROM aggregate_test_100",
)
.await
.unwrap_err();
assert_eq!(results.to_string(), "Error during planning: The percentile sample points count for ApproxPercentileCont must be integer, not Float64.");

Ok(())
}

#[tokio::test]
async fn csv_query_sum_crossjoin() {
let ctx = SessionContext::new();
Expand Down
40 changes: 33 additions & 7 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,12 @@ pub fn coerce_types(
agg_fun, input_types[1]
)));
}
if input_types.len() == 3 && !is_integer_arg_type(&input_types[2]) {
return Err(DataFusionError::Plan(format!(
"The percentile sample points count for {:?} must be integer, not {:?}.",
agg_fun, input_types[2]
)));
}
Ok(input_types.to_vec())
}
AggregateFunction::ApproxPercentileContWithWeight => {
Expand Down Expand Up @@ -382,14 +388,20 @@ pub fn signature(fun: &AggregateFunction) -> Signature {
AggregateFunction::Correlation => {
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::ApproxPercentileCont => Signature::one_of(
AggregateFunction::ApproxPercentileCont => {
// Accept any numeric value paired with a float64 percentile
NUMERICS
.iter()
.map(|t| TypeSignature::Exact(vec![t.clone(), DataType::Float64]))
.collect(),
Volatility::Immutable,
),
let with_tdigest_size = NUMERICS.iter().map(|t| {
TypeSignature::Exact(vec![t.clone(), DataType::Float64, t.clone()])
});
Signature::one_of(
NUMERICS
.iter()
.map(|t| TypeSignature::Exact(vec![t.clone(), DataType::Float64]))
.chain(with_tdigest_size)
.collect(),
Volatility::Immutable,
)
}
AggregateFunction::ApproxPercentileContWithWeight => Signature::one_of(
// Accept any numeric value paired with a float64 percentile
NUMERICS
Expand Down Expand Up @@ -702,6 +714,20 @@ pub fn is_correlation_support_arg_type(arg_type: &DataType) -> bool {
)
}

/// Return `true` if `arg_type` is one of the signed or unsigned integer
/// [`DataType`]s (8, 16, 32 or 64 bits wide), and `false` for every other type.
pub fn is_integer_arg_type(arg_type: &DataType) -> bool {
    matches!(
        arg_type,
        // signed integers
        DataType::Int8
            | DataType::Int16
            | DataType::Int32
            | DataType::Int64
            // unsigned integers
            | DataType::UInt8
            | DataType::UInt16
            | DataType::UInt32
            | DataType::UInt64
    )
}

/// Return `true` if `arg_type` is of a [`DataType`] that the
/// [`AggregateFunction::ApproxPercentileCont`] aggregation can operate on.
pub fn is_approx_percentile_cont_supported_arg_type(arg_type: &DataType) -> bool {
Expand Down
126 changes: 99 additions & 27 deletions datafusion/physical-expr/src/aggregate/approx_percentile_cont.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ pub struct ApproxPercentileCont {
input_data_type: DataType,
expr: Vec<Arc<dyn PhysicalExpr>>,
percentile: f64,
tdigest_max_size: Option<usize>,
}

impl ApproxPercentileCont {
Expand All @@ -52,39 +53,35 @@ impl ApproxPercentileCont {
// Arguments should be [ColumnExpr, DesiredPercentileLiteral]
debug_assert_eq!(expr.len(), 2);

// Extract the desired percentile literal
let lit = expr[1]
.as_any()
.downcast_ref::<Literal>()
.ok_or_else(|| {
DataFusionError::Internal(
"desired percentile argument must be float literal".to_string(),
)
})?
.value();
let percentile = match lit {
ScalarValue::Float32(Some(q)) => *q as f64,
ScalarValue::Float64(Some(q)) => *q as f64,
got => return Err(DataFusionError::NotImplemented(format!(
"Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})",
got
)))
};
let percentile = validate_input_percentile_expr(&expr[1])?;

// Ensure the percentile is between 0 and 1.
if !(0.0..=1.0).contains(&percentile) {
return Err(DataFusionError::Plan(format!(
"Percentile value must be between 0.0 and 1.0 inclusive, {} is invalid",
percentile
)));
}
Ok(Self {
name: name.into(),
input_data_type,
// The physical expr to evaluate during accumulation
expr,
percentile,
tdigest_max_size: None,
})
}

/// Create a new [`ApproxPercentileCont`] aggregate function that carries an
/// explicit t-digest max size.
///
/// `expr` is expected to be `[ColumnExpr, DesiredPercentileLiteral, TDigestMaxSize]`.
/// The percentile literal is validated first, then the max size literal; the
/// first validation failure is returned as the error.
pub fn new_with_max_size(
    expr: Vec<Arc<dyn PhysicalExpr>>,
    name: impl Into<String>,
    input_data_type: DataType,
) -> Result<Self> {
    debug_assert_eq!(expr.len(), 3);

    let percentile = validate_input_percentile_expr(&expr[1])?;
    let tdigest_max_size = Some(validate_input_max_size_expr(&expr[2])?);

    Ok(Self {
        name: name.into(),
        input_data_type,
        // All argument expressions are kept; they are evaluated during accumulation
        expr,
        percentile,
        tdigest_max_size,
    })
}

Expand All @@ -100,7 +97,13 @@ impl ApproxPercentileCont {
| DataType::Int64
| DataType::Float32
| DataType::Float64) => {
ApproxPercentileAccumulator::new(self.percentile, t.clone())
if let Some(max_size) = self.tdigest_max_size {
ApproxPercentileAccumulator::new_with_max_size(self.percentile, t.clone(), max_size)

}else{
ApproxPercentileAccumulator::new(self.percentile, t.clone())

}
}
other => {
return Err(DataFusionError::NotImplemented(format!(
Expand All @@ -113,6 +116,64 @@ impl ApproxPercentileCont {
}
}

/// Extract the percentile argument of `APPROX_PERCENTILE_CONT` from `expr`.
///
/// # Errors
///
/// * `DataFusionError::Internal` if `expr` is not a [`Literal`].
/// * `DataFusionError::NotImplemented` if the literal is not a non-null
///   `Float32` or `Float64` value.
/// * `DataFusionError::Plan` if the value lies outside `0.0..=1.0`.
fn validate_input_percentile_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<f64> {
    let literal = expr.as_any().downcast_ref::<Literal>().ok_or_else(|| {
        DataFusionError::Internal(
            "desired percentile argument must be float literal".to_string(),
        )
    })?;

    // Widen to f64 regardless of the literal's float width.
    let percentile = match literal.value() {
        ScalarValue::Float32(Some(q)) => f64::from(*q),
        ScalarValue::Float64(Some(q)) => *q,
        other => {
            return Err(DataFusionError::NotImplemented(format!(
                "Percentile value for 'APPROX_PERCENTILE_CONT' must be Float32 or Float64 literal (got data type {})",
                other.get_datatype()
            )))
        }
    };

    // Only values within [0, 1] are meaningful percentiles.
    if (0.0..=1.0).contains(&percentile) {
        Ok(percentile)
    } else {
        Err(DataFusionError::Plan(format!(
            "Percentile value must be between 0.0 and 1.0 inclusive, {} is invalid",
            percentile
        )))
    }
}

/// Extract the t-digest max size argument of `APPROX_PERCENTILE_CONT` from `expr`.
///
/// The argument must be a non-null integer literal that is strictly greater
/// than zero; signed and unsigned literals are treated alike.
///
/// # Errors
///
/// * `DataFusionError::Internal` if `expr` is not a [`Literal`].
/// * `DataFusionError::NotImplemented` if the literal is null, not an integer
///   type, or not strictly positive.
fn validate_input_max_size_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<usize> {
    // Extract the desired t-digest max size literal
    let lit = expr
        .as_any()
        .downcast_ref::<Literal>()
        .ok_or_else(|| {
            DataFusionError::Internal(
                "desired tdigest max_size argument must be integer literal".to_string(),
            )
        })?
        .value();
    // Zero is rejected for unsigned literals as well, matching the "> 0"
    // requirement stated in the error message below.
    let max_size = match lit {
        ScalarValue::UInt8(Some(q)) if *q > 0 => *q as usize,
        ScalarValue::UInt16(Some(q)) if *q > 0 => *q as usize,
        ScalarValue::UInt32(Some(q)) if *q > 0 => *q as usize,
        ScalarValue::UInt64(Some(q)) if *q > 0 => *q as usize,
        ScalarValue::Int8(Some(q)) if *q > 0 => *q as usize,
        ScalarValue::Int16(Some(q)) if *q > 0 => *q as usize,
        ScalarValue::Int32(Some(q)) if *q > 0 => *q as usize,
        ScalarValue::Int64(Some(q)) if *q > 0 => *q as usize,
        got => return Err(DataFusionError::NotImplemented(format!(
            "Tdigest max_size value for 'APPROX_PERCENTILE_CONT' must be UInt > 0 literal (got data type {}).",
            got.get_datatype()
        )))
    };
    Ok(max_size)
}

impl AggregateExpr for ApproxPercentileCont {
fn as_any(&self) -> &dyn Any {
self
Expand Down Expand Up @@ -190,6 +251,18 @@ impl ApproxPercentileAccumulator {
}
}

/// Build an accumulator for `percentile` whose digest is created with
/// `TDigest::new(max_size)`; `return_type` is stored unchanged.
pub fn new_with_max_size(
    percentile: f64,
    return_type: DataType,
    max_size: usize,
) -> Self {
    let digest = TDigest::new(max_size);
    Self {
        digest,
        percentile,
        return_type,
    }
}

/// Replace this accumulator's digest with the result of
/// `TDigest::merge_digests(digests)`.
pub(crate) fn merge_digests(&mut self, digests: &[TDigest]) {
    let merged = TDigest::merge_digests(digests);
    self.digest = merged;
}
Expand Down Expand Up @@ -285,7 +358,6 @@ impl ApproxPercentileAccumulator {
}
}
}

impl Accumulator for ApproxPercentileAccumulator {
fn state(&self) -> Result<Vec<AggregateState>> {
Ok(self
Expand Down
21 changes: 15 additions & 6 deletions datafusion/physical-expr/src/aggregate/build_in.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,12 +216,21 @@ pub fn create_aggregate_expr(
));
}
(AggregateFunction::ApproxPercentileCont, false) => {
Arc::new(expressions::ApproxPercentileCont::new(
// Pass in the desired percentile expr
coerced_phy_exprs,
name,
return_type,
)?)
if coerced_phy_exprs.len() == 2 {
Arc::new(expressions::ApproxPercentileCont::new(
// Pass in the desired percentile expr
coerced_phy_exprs,
name,
return_type,
)?)
} else {
Arc::new(expressions::ApproxPercentileCont::new_with_max_size(
// Pass in the desired percentile expr
coerced_phy_exprs,
name,
return_type,
)?)
}
}
(AggregateFunction::ApproxPercentileCont, true) => {
return Err(DataFusionError::NotImplemented(
Expand Down
6 changes: 6 additions & 0 deletions docs/source/user-guide/sql/aggregate_functions.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ Aggregate functions operate on a set of values to compute a single result. Pleas

It supports raw data as input and builds TDigest sketches during query time, and is approximately equal to `approx_percentile_cont_with_weight(x, 1, p)`.

`approx_percentile_cont(x, p, n) -> x` returns the approximate percentile (TDigest) of input values, where `p` is a float64 between 0 and 1 (inclusive),

and `n` (default 100) is the number of centroids in Tdigest which means that if there are `n` or fewer unique values in `x`, you can expect an exact result.

A higher value of `n` results in a more accurate approximation at the cost of higher memory usage.

### approx_percentile_cont_with_weight

`approx_percentile_cont_with_weight(x, w, p) -> x` returns the approximate percentile (TDigest) of input values with weight, where `w` is weight column expression and `p` is a float64 between 0 and 1 (inclusive).
Expand Down

0 comments on commit 929eb6d

Please sign in to comment.