Skip to content

Commit

Permalink
Move Median to functions-aggregate and Introduce Numeric signature (#…
Browse files Browse the repository at this point in the history
…10644)

* introduce median udaf

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* rm agg median

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* rm old median

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* introduce numeric signature

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* address comment

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* fix doc

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* add proto roundtrip

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

---------

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
  • Loading branch information
jayzhan211 committed May 26, 2024
1 parent e40f50b commit 26b44f4
Show file tree
Hide file tree
Showing 28 changed files with 258 additions and 133 deletions.
1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,11 @@ use datafusion_common::{
};
use datafusion_expr::lit;
use datafusion_expr::{
avg, count, max, median, min, stddev, utils::COUNT_STAR_EXPANSION,
avg, count, max, min, stddev, utils::COUNT_STAR_EXPANSION,
TableProviderFilterPushDown, UNNAMED_TABLE,
};
use datafusion_expr::{case, is_null, sum};
use datafusion_functions_aggregate::expr_fn::median;

use async_trait::async_trait;

Expand Down
9 changes: 1 addition & 8 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ pub enum AggregateFunction {
Max,
/// Average
Avg,
/// Median
Median,
/// Approximate distinct function
ApproxDistinct,
/// Aggregation into an array
Expand Down Expand Up @@ -114,7 +112,6 @@ impl AggregateFunction {
Min => "MIN",
Max => "MAX",
Avg => "AVG",
Median => "MEDIAN",
ApproxDistinct => "APPROX_DISTINCT",
ArrayAgg => "ARRAY_AGG",
FirstValue => "FIRST_VALUE",
Expand Down Expand Up @@ -168,7 +165,6 @@ impl FromStr for AggregateFunction {
"count" => AggregateFunction::Count,
"max" => AggregateFunction::Max,
"mean" => AggregateFunction::Avg,
"median" => AggregateFunction::Median,
"min" => AggregateFunction::Min,
"sum" => AggregateFunction::Sum,
"array_agg" => AggregateFunction::ArrayAgg,
Expand Down Expand Up @@ -275,9 +271,7 @@ impl AggregateFunction {
AggregateFunction::ApproxPercentileContWithWeight => {
Ok(coerced_data_types[0].clone())
}
AggregateFunction::ApproxMedian | AggregateFunction::Median => {
Ok(coerced_data_types[0].clone())
}
AggregateFunction::ApproxMedian => Ok(coerced_data_types[0].clone()),
AggregateFunction::Grouping => Ok(DataType::Int32),
AggregateFunction::FirstValue
| AggregateFunction::LastValue
Expand Down Expand Up @@ -335,7 +329,6 @@ impl AggregateFunction {
| AggregateFunction::VariancePop
| AggregateFunction::Stddev
| AggregateFunction::StddevPop
| AggregateFunction::Median
| AggregateFunction::ApproxMedian
| AggregateFunction::FirstValue
| AggregateFunction::LastValue => {
Expand Down
12 changes: 0 additions & 12 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -296,18 +296,6 @@ pub fn approx_distinct(expr: Expr) -> Expr {
))
}

/// Build an aggregate expression that calculates the median of `expr`.
///
/// `expr` is passed as the single argument of the built-in
/// `aggregate_function::AggregateFunction::Median` aggregate.
pub fn median(expr: Expr) -> Expr {
    let args = vec![expr];
    let median_call = AggregateFunction::new(
        aggregate_function::AggregateFunction::Median,
        args,
        false,
        None,
        None,
        None,
    );
    Expr::AggregateFunction(median_call)
}

/// Calculate an approximation of the median for `expr`.
pub fn approx_median(expr: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new(
Expand Down
3 changes: 3 additions & 0 deletions datafusion/expr/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ pub struct AccumulatorArgs<'a> {

/// The number of arguments the aggregate function takes.
pub args_num: usize,

/// The name of the expression
pub name: &'a str,
}

/// [`StateFieldsArgs`] contains information about the fields that an
Expand Down
14 changes: 14 additions & 0 deletions datafusion/expr/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ pub enum TypeSignature {
OneOf(Vec<TypeSignature>),
/// Specifies Signatures for array functions
ArraySignature(ArrayFunctionSignature),
/// Fixed number of arguments of numeric types.
/// See <https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html#method.is_numeric> for which types are considered numeric
Numeric(usize),
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -178,6 +181,9 @@ impl TypeSignature {
.collect::<Vec<String>>()
.join(", ")]
}
TypeSignature::Numeric(num) => {
vec![format!("Numeric({})", num)]
}
TypeSignature::Exact(types) => {
vec![Self::join_types(types, ", ")]
}
Expand Down Expand Up @@ -259,6 +265,14 @@ impl Signature {
volatility,
}
}

/// Creates a [`Signature`] whose type signature is
/// [`TypeSignature::Numeric`] with `num` arguments and the given `volatility`.
pub fn numeric(num: usize, volatility: Volatility) -> Self {
    let type_signature = TypeSignature::Numeric(num);
    Self {
        type_signature,
        volatility,
    }
}

/// An arbitrary number of arguments of any type.
pub fn variadic_any(volatility: Volatility) -> Self {
Self {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/expr/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ impl TreeNode for Expr {
Ok(Expr::AggregateFunction(AggregateFunction::new_udf(
fun,
new_args,
false,
distinct,
new_filter,
new_order_by,
null_treatment,
Expand Down
10 changes: 7 additions & 3 deletions datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,9 +283,9 @@ pub fn coerce_types(
}
Ok(input_types.to_vec())
}
AggregateFunction::Median
| AggregateFunction::FirstValue
| AggregateFunction::LastValue => Ok(input_types.to_vec()),
AggregateFunction::FirstValue | AggregateFunction::LastValue => {
Ok(input_types.to_vec())
}
AggregateFunction::NthValue => Ok(input_types.to_vec()),
AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]),
AggregateFunction::StringAgg => {
Expand Down Expand Up @@ -355,6 +355,10 @@ pub fn check_arg_count(
);
}
}
TypeSignature::UserDefined | TypeSignature::Numeric(_) => {
// User-defined signature is validated in `coerce_types`
// Numeric signature is validated in `get_valid_types`
}
_ => {
return internal_err!(
"Aggregate functions do not support this {signature:?}"
Expand Down
32 changes: 32 additions & 0 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,38 @@ fn get_valid_types(
.iter()
.map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect())
.collect(),
TypeSignature::Numeric(number) => {
if *number < 1 {
return plan_err!(
"The signature expected at least one argument but received {}",
current_types.len()
);
}
if *number != current_types.len() {
return plan_err!(
"The signature expected {} arguments but received {}",
number,
current_types.len()
);
}

let mut valid_type = current_types.first().unwrap().clone();
for t in current_types.iter().skip(1) {
if let Some(coerced_type) =
comparison_binary_numeric_coercion(&valid_type, t)
{
valid_type = coerced_type;
} else {
return plan_err!(
"{} and {} are not coercible to a common numeric type",
valid_type,
t
);
}
}

vec![vec![valid_type; *number]]
}
TypeSignature::Uniform(number, valid_types) => valid_types
.iter()
.map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect())
Expand Down
3 changes: 2 additions & 1 deletion datafusion/expr/src/udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,10 @@ impl AggregateUDF {
self.inner.create_groups_accumulator()
}

pub fn coerce_types(&self, _args: &[DataType]) -> Result<Vec<DataType>> {
pub fn coerce_types(&self, _arg_types: &[DataType]) -> Result<Vec<DataType>> {
not_impl_err!("coerce_types not implemented for {:?} yet", self.name())
}

/// Do the function rewrite
///
/// See [`AggregateUDFImpl::simplify`] for more details.
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions-aggregate/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ path = "src/lib.rs"

[dependencies]
arrow = { workspace = true }
arrow-schema = { workspace = true }
datafusion-common = { workspace = true }
datafusion-execution = { workspace = true }
datafusion-expr = { workspace = true }
Expand Down
3 changes: 3 additions & 0 deletions datafusion/functions-aggregate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ pub mod macros;

pub mod covariance;
pub mod first_last;
pub mod median;

use datafusion_common::Result;
use datafusion_execution::FunctionRegistry;
Expand All @@ -68,6 +69,7 @@ use std::sync::Arc;
pub mod expr_fn {
pub use super::covariance::covar_samp;
pub use super::first_last::first_value;
pub use super::median::median;
}

/// Registers all enabled packages with a [`FunctionRegistry`]
Expand All @@ -76,6 +78,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
first_last::first_value_udaf(),
covariance::covar_samp_udaf(),
covariance::covar_pop_udaf(),
median::median_udaf(),
];

functions.into_iter().try_for_each(|udf| {
Expand Down
Loading

0 comments on commit 26b44f4

Please sign in to comment.