Skip to content

Commit

Permalink
refactor: use input type as return type
Browse files: browse the repository at this point in the history
Casts the calculated quantile value to the same type as the input data.
  • Loading branch information
domodwyer committed Jan 11, 2022
1 parent ce4e1f3 commit 066fcb8
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 11 deletions.
4 changes: 2 additions & 2 deletions datafusion/src/physical_plan/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ pub fn return_type(
coerced_data_types[0].clone(),
true,
)))),
AggregateFunction::ApproxQuantile => Ok(DataType::Float64),
AggregateFunction::ApproxQuantile => Ok(coerced_data_types[0].clone()),
}
}

Expand Down Expand Up @@ -455,7 +455,7 @@ mod tests {
assert!(result_agg_phy_exprs.as_any().is::<ApproxQuantile>());
assert_eq!("c1", result_agg_phy_exprs.name());
assert_eq!(
Field::new("c1", DataType::Float64, false),
Field::new("c1", data_type.clone(), false),
result_agg_phy_exprs.field().unwrap()
);
}
Expand Down
34 changes: 26 additions & 8 deletions datafusion/src/physical_plan/expressions/approx_quantile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ impl AggregateExpr for ApproxQuantile {
}

fn field(&self) -> Result<Field> {
Ok(Field::new(&self.name, DataType::Float64, false))
Ok(Field::new(&self.name, self.input_data_type.clone(), false))
}

/// See [`TDigest::to_scalar_state()`] for a description of the serialised
Expand Down Expand Up @@ -151,7 +151,9 @@ impl AggregateExpr for ApproxQuantile {

fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
let accumulator: Box<dyn Accumulator> = match &self.input_data_type {
DataType::UInt8
t
@
(DataType::UInt8
| DataType::UInt16
| DataType::UInt32
| DataType::UInt64
Expand All @@ -160,8 +162,8 @@ impl AggregateExpr for ApproxQuantile {
| DataType::Int32
| DataType::Int64
| DataType::Float32
| DataType::Float64 => {
Box::new(ApproxQuantileAccumulator::new(self.quantile))
| DataType::Float64) => {
Box::new(ApproxQuantileAccumulator::new(self.quantile, t.clone()))
}
other => {
return Err(DataFusionError::NotImplemented(format!(
Expand All @@ -182,13 +184,15 @@ impl AggregateExpr for ApproxQuantile {
/// State held by the accumulator that `ApproxQuantile::create_accumulator`
/// builds for the approximate-quantile aggregate.
pub struct ApproxQuantileAccumulator {
/// `TDigest` instance queried with `estimate_quantile` when the accumulator
/// is evaluated; `new` initialises it with `TDigest::new(100)`.
digest: TDigest,
/// Quantile handed to `TDigest::estimate_quantile` during evaluation.
quantile: f64,
/// Data type of the value produced by `evaluate`: the `f64` estimate is cast
/// with `as` into the `ScalarValue` variant matching this type.
/// `create_accumulator` sets it to a clone of the input data type.
return_type: DataType,
}

impl ApproxQuantileAccumulator {
pub fn new(quantile: f64) -> Self {
pub fn new(quantile: f64, return_type: DataType) -> Self {
Self {
digest: TDigest::new(100),
quantile,
return_type,
}
}
}
Expand Down Expand Up @@ -283,8 +287,22 @@ impl Accumulator for ApproxQuantileAccumulator {
}

fn evaluate(&self) -> Result<ScalarValue> {
Ok(ScalarValue::Float64(Some(
self.digest.estimate_quantile(self.quantile),
)))
let q = self.digest.estimate_quantile(self.quantile);

// These acceptable return types MUST match the validation in
// ApproxQuantile::create_accumulator.
Ok(match &self.return_type {
DataType::Int8 => ScalarValue::Int8(Some(q as i8)),
DataType::Int16 => ScalarValue::Int16(Some(q as i16)),
DataType::Int32 => ScalarValue::Int32(Some(q as i32)),
DataType::Int64 => ScalarValue::Int64(Some(q as i64)),
DataType::UInt8 => ScalarValue::UInt8(Some(q as u8)),
DataType::UInt16 => ScalarValue::UInt16(Some(q as u16)),
DataType::UInt32 => ScalarValue::UInt32(Some(q as u32)),
DataType::UInt64 => ScalarValue::UInt64(Some(q as u64)),
DataType::Float32 => ScalarValue::Float32(Some(q as f32)),
DataType::Float64 => ScalarValue::Float64(Some(q as f64)),
v => unreachable!("unexpected return type {:?}", v),
})
}
}
2 changes: 1 addition & 1 deletion datafusion/tests/sql/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ async fn csv_query_approx_quantile() -> Result<()> {
// within 5% of the $actual quantile value.
macro_rules! quantile_test {
($ctx:ident, column=$column:literal, quantile=$quantile:literal, actual=$actual:literal) => {
let sql = format!("SELECT (ABS(1 - approx_quantile({}, {}) / {}) < 0.05) AS q FROM aggregate_test_100", $column, $quantile, $actual);
let sql = format!("SELECT (ABS(1 - CAST(approx_quantile({}, {}) AS DOUBLE) / {}) < 0.05) AS q FROM aggregate_test_100", $column, $quantile, $actual);
let actual = execute_to_batches(&mut ctx, &sql).await;
//
// "+------+",
Expand Down

0 comments on commit 066fcb8

Please sign in to comment.