Skip to content

Commit

Permalink
Move Median to functions-aggregate and Introduce Numeric signature (a…
Browse files Browse the repository at this point in the history
…pache#10644)

* introduce median udaf

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* rm agg median

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* rm old median

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* introduce numeric signature

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* address comment

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* fix doc

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

* add proto roundtrip

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>

---------

Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
  • Loading branch information
jayzhan211 committed May 26, 2024
1 parent 30e9655 commit d78f63b
Show file tree
Hide file tree
Showing 24 changed files with 202 additions and 133 deletions.
1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,12 @@ use datafusion_common::{
};
use datafusion_expr::lit;
use datafusion_expr::{
avg, count, max, median, min, stddev, utils::COUNT_STAR_EXPANSION,
avg, count, max, min, stddev, utils::COUNT_STAR_EXPANSION,
TableProviderFilterPushDown, UNNAMED_TABLE,
};
use datafusion_expr::{case, is_null};
use datafusion_functions_aggregate::expr_fn::sum;
use datafusion_functions_aggregate::expr_fn::median;

use async_trait::async_trait;

Expand Down
9 changes: 1 addition & 8 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ pub enum AggregateFunction {
Max,
/// Average
Avg,
/// Median
Median,
/// Approximate distinct function
ApproxDistinct,
/// Aggregation into an array
Expand Down Expand Up @@ -114,7 +112,6 @@ impl AggregateFunction {
Min => "MIN",
Max => "MAX",
Avg => "AVG",
Median => "MEDIAN",
ApproxDistinct => "APPROX_DISTINCT",
ArrayAgg => "ARRAY_AGG",
FirstValue => "FIRST_VALUE",
Expand Down Expand Up @@ -168,7 +165,6 @@ impl FromStr for AggregateFunction {
"count" => AggregateFunction::Count,
"max" => AggregateFunction::Max,
"mean" => AggregateFunction::Avg,
"median" => AggregateFunction::Median,
"min" => AggregateFunction::Min,
"sum" => AggregateFunction::Sum,
"array_agg" => AggregateFunction::ArrayAgg,
Expand Down Expand Up @@ -275,9 +271,7 @@ impl AggregateFunction {
AggregateFunction::ApproxPercentileContWithWeight => {
Ok(coerced_data_types[0].clone())
}
AggregateFunction::ApproxMedian | AggregateFunction::Median => {
Ok(coerced_data_types[0].clone())
}
AggregateFunction::ApproxMedian => Ok(coerced_data_types[0].clone()),
AggregateFunction::Grouping => Ok(DataType::Int32),
AggregateFunction::FirstValue
| AggregateFunction::LastValue
Expand Down Expand Up @@ -335,7 +329,6 @@ impl AggregateFunction {
| AggregateFunction::VariancePop
| AggregateFunction::Stddev
| AggregateFunction::StddevPop
| AggregateFunction::Median
| AggregateFunction::ApproxMedian
| AggregateFunction::FirstValue
| AggregateFunction::LastValue => {
Expand Down
12 changes: 0 additions & 12 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -298,18 +298,6 @@ pub fn approx_distinct(expr: Expr) -> Expr {
))
}

/// Calculate the median for `expr`.
pub fn median(expr: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new(
aggregate_function::AggregateFunction::Median,
vec![expr],
false,
None,
None,
None,
))
}

/// Calculate an approximation of the median for `expr`.
pub fn approx_median(expr: Expr) -> Expr {
Expr::AggregateFunction(AggregateFunction::new(
Expand Down
3 changes: 3 additions & 0 deletions datafusion/expr/src/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ pub struct AccumulatorArgs<'a> {

/// The number of arguments the aggregate function takes.
pub args_num: usize,

/// The name of the expression
pub name: &'a str,
}

/// [`StateFieldsArgs`] contains information about the fields that an
Expand Down
14 changes: 14 additions & 0 deletions datafusion/expr/src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ pub enum TypeSignature {
OneOf(Vec<TypeSignature>),
/// Specifies Signatures for array functions
ArraySignature(ArrayFunctionSignature),
/// Fixed number of arguments of numeric types.
/// See <https://docs.rs/arrow/latest/arrow/datatypes/enum.DataType.html#method.is_numeric> to know which type is considered numeric
Numeric(usize),
}

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -178,6 +181,9 @@ impl TypeSignature {
.collect::<Vec<String>>()
.join(", ")]
}
TypeSignature::Numeric(num) => {
vec![format!("Numeric({})", num)]
}
TypeSignature::Exact(types) => {
vec![Self::join_types(types, ", ")]
}
Expand Down Expand Up @@ -259,6 +265,14 @@ impl Signature {
volatility,
}
}

pub fn numeric(num: usize, volatility: Volatility) -> Self {
Self {
type_signature: TypeSignature::Numeric(num),
volatility,
}
}

/// An arbitrary number of arguments of any type.
pub fn variadic_any(volatility: Volatility) -> Self {
Self {
Expand Down
11 changes: 6 additions & 5 deletions datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,9 +283,9 @@ pub fn coerce_types(
}
Ok(input_types.to_vec())
}
AggregateFunction::Median
| AggregateFunction::FirstValue
| AggregateFunction::LastValue => Ok(input_types.to_vec()),
AggregateFunction::FirstValue | AggregateFunction::LastValue => {
Ok(input_types.to_vec())
}
AggregateFunction::NthValue => Ok(input_types.to_vec()),
AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]),
AggregateFunction::StringAgg => {
Expand Down Expand Up @@ -355,8 +355,9 @@ pub fn check_arg_count(
);
}
}
TypeSignature::UserDefined => {
// User-defined functions are not validated here
TypeSignature::UserDefined | TypeSignature::Numeric(_) => {
// User-defined signature is validated in `coerce_types`
// Numreic signature is validated in `get_valid_types`
}
_ => {
return internal_err!(
Expand Down
32 changes: 32 additions & 0 deletions datafusion/expr/src/type_coercion/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,38 @@ fn get_valid_types(
.iter()
.map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect())
.collect(),
TypeSignature::Numeric(number) => {
if *number < 1 {
return plan_err!(
"The signature expected at least one argument but received {}",
current_types.len()
);
}
if *number != current_types.len() {
return plan_err!(
"The signature expected {} arguments but received {}",
number,
current_types.len()
);
}

let mut valid_type = current_types.first().unwrap().clone();
for t in current_types.iter().skip(1) {
if let Some(coerced_type) =
comparison_binary_numeric_coercion(&valid_type, t)
{
valid_type = coerced_type;
} else {
return plan_err!(
"{} and {} are not coercible to a common numeric type",
valid_type,
t
);
}
}

vec![vec![valid_type; *number]]
}
TypeSignature::Uniform(number, valid_types) => valid_types
.iter()
.map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect())
Expand Down
1 change: 1 addition & 0 deletions datafusion/functions-aggregate/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ path = "src/lib.rs"
[dependencies]
ahash = { workspace = true }
arrow = { workspace = true }
arrow-schema = { workspace = true }
datafusion-common = { workspace = true }
datafusion-execution = { workspace = true }
datafusion-expr = { workspace = true }
Expand Down
3 changes: 3 additions & 0 deletions datafusion/functions-aggregate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ pub mod macros;
pub mod covariance;
pub mod first_last;
pub mod sum;
pub mod median;

use datafusion_common::Result;
use datafusion_execution::FunctionRegistry;
Expand All @@ -70,6 +71,7 @@ pub mod expr_fn {
pub use super::covariance::covar_samp;
pub use super::first_last::first_value;
pub use super::sum::sum;
pub use super::median::median;
}

/// Registers all enabled packages with a [`FunctionRegistry`]
Expand All @@ -79,6 +81,7 @@ pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> {
covariance::covar_samp_udaf(),
sum::sum_udaf(),
covariance::covar_pop_udaf(),
median::median_udaf(),
];

functions.into_iter().try_for_each(|udf| {
Expand Down
Loading

0 comments on commit d78f63b

Please sign in to comment.