From 6a8993ffa777b1b258d2d8df22ddc48251b528c3 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 16 Feb 2025 15:52:57 +0800 Subject: [PATCH 1/9] udaf schema_name --- datafusion/core/src/physical_planner.rs | 16 ++- datafusion/expr/src/expr.rs | 107 +++++++++--------- datafusion/expr/src/expr_fn.rs | 8 +- datafusion/expr/src/expr_schema.rs | 9 +- datafusion/expr/src/tree_node.rs | 20 ++-- datafusion/expr/src/udaf.rs | 50 +++++++- datafusion/functions-nested/src/planner.rs | 45 ++++---- .../src/analyzer/count_wildcard_rule.rs | 7 +- .../src/analyzer/resolve_grouping_function.rs | 3 +- .../optimizer/src/analyzer/type_coercion.rs | 17 +-- .../src/single_distinct_to_groupby.rs | 18 +-- datafusion/sql/src/unparser/expr.rs | 15 ++- 12 files changed, 193 insertions(+), 122 deletions(-) diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index 2303574e88af..bce1aab16e5e 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -70,7 +70,8 @@ use datafusion_common::{ }; use datafusion_expr::dml::{CopyTo, InsertOp}; use datafusion_expr::expr::{ - physical_name, AggregateFunction, Alias, GroupingSet, WindowFunction, + physical_name, AggregateFunction, AggregateFunctionParams, Alias, GroupingSet, + WindowFunction, }; use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; @@ -1579,11 +1580,14 @@ pub fn create_aggregate_expr_with_name_and_maybe_filter( match e { Expr::AggregateFunction(AggregateFunction { func, - distinct, - args, - filter, - order_by, - null_treatment, + params: + AggregateFunctionParams { + args, + distinct, + filter, + order_by, + null_treatment, + }, }) => { let name = if let Some(name) = name { name diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 305519a1f4b4..f1cc365dfc88 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -696,7 +696,11 @@ impl<'a> TreeNodeContainer<'a, Expr> for Sort { pub struct AggregateFunction { /// Name of the function pub func: Arc, - /// List of expressions to feed to the functions as arguments + pub params: AggregateFunctionParams, +} + +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +pub struct AggregateFunctionParams { pub args: Vec, /// Whether this is a DISTINCT aggregation or not pub distinct: bool, @@ -719,11 +723,13 @@ impl AggregateFunction { ) -> Self { Self { func, - args, - distinct, - filter, - order_by, - null_treatment, + params: AggregateFunctionParams { + args, + distinct, + filter, + order_by, + null_treatment, + }, } } } @@ -1864,19 +1870,25 @@ impl NormalizeEq for Expr { ( Expr::AggregateFunction(AggregateFunction { func: self_func, - args: self_args, - distinct: self_distinct, - filter: self_filter, - order_by: self_order_by, - null_treatment: self_null_treatment, + params: + AggregateFunctionParams { + args: self_args, + distinct: self_distinct, + filter: self_filter, + order_by: self_order_by, + null_treatment: self_null_treatment, + }, }), Expr::AggregateFunction(AggregateFunction { func: other_func, - args: other_args, - distinct: other_distinct, - filter: other_filter, - order_by: other_order_by, - null_treatment: other_null_treatment, + params: + AggregateFunctionParams { + args: other_args, + distinct: other_distinct, + filter: other_filter, + order_by: other_order_by, + null_treatment: other_null_treatment, + }, }), ) => { self_func.name() == other_func.name() @@ -2154,11 +2166,14 @@ impl HashNode for Expr { } Expr::AggregateFunction(AggregateFunction { func, - args: _args, - distinct, - filter: _filter, - order_by: _order_by, - null_treatment, + params: + AggregateFunctionParams { + args: _args, + distinct, + filter: _, + order_by: _, + null_treatment, + }, }) => { func.hash(state); distinct.hash(state); @@ -2264,35 +2279,15 @@ impl Display for SchemaDisplay<'_> { | Expr::Placeholder(_) | Expr::Wildcard { .. } => write!(f, "{}", self.0), - Expr::AggregateFunction(AggregateFunction { - func, - args, - distinct, - filter, - order_by, - null_treatment, - }) => { - write!( - f, - "{}({}{})", - func.name(), - if *distinct { "DISTINCT " } else { "" }, - schema_name_from_exprs_comma_separated_without_space(args)? - )?; - - if let Some(null_treatment) = null_treatment { - write!(f, " {}", null_treatment)?; + Expr::AggregateFunction(AggregateFunction { func, params }) => { + match func.schema_name(params) { + Ok(name) => { + write!(f, "{name}") + } + Err(e) => { + write!(f, "got error from schema_name {}", e) + } } - - if let Some(filter) = filter { - write!(f, " FILTER (WHERE {filter})")?; - }; - - if let Some(order_by) = order_by { - write!(f, " ORDER BY [{}]", schema_name_from_sorts(order_by)?)?; - }; - - Ok(()) } // Expr is not shown since it is aliased Expr::Alias(Alias { @@ -2655,12 +2650,14 @@ impl Display for Expr { } Expr::AggregateFunction(AggregateFunction { func, - distinct, - ref args, - filter, - order_by, - null_treatment, - .. + params: + AggregateFunctionParams { + ref args, + distinct, + filter, + order_by, + null_treatment, + }, }) => { fmt_function(f, func.name(), *distinct, args, true)?; if let Some(nt) = null_treatment { diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index a2de5e7b259f..91d2b379af60 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -830,10 +830,10 @@ impl ExprFuncBuilder { let fun_expr = match fun { ExprFuncKind::Aggregate(mut udaf) => { - udaf.order_by = order_by; - udaf.filter = filter.map(Box::new); - udaf.distinct = distinct; - udaf.null_treatment = null_treatment; + udaf.params.order_by = order_by; + udaf.params.filter = filter.map(Box::new); + udaf.params.distinct = distinct; + udaf.params.null_treatment = null_treatment; Expr::AggregateFunction(udaf) } ExprFuncKind::Window(mut udwf) => { diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 49791427131f..becb7c14397d 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -17,8 +17,8 @@ use super::{Between, Expr, Like}; use crate::expr::{ - AggregateFunction, Alias, BinaryExpr, Cast, InList, InSubquery, Placeholder, - ScalarFunction, TryCast, Unnest, WindowFunction, + AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, InList, + InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, }; use crate::type_coercion::functions::{ data_types_with_aggregate_udf, data_types_with_scalar_udf, data_types_with_window_udf, @@ -153,7 +153,10 @@ impl ExprSchemable for Expr { Expr::WindowFunction(window_function) => self .data_type_and_nullable_with_window_function(schema, window_function) .map(|(return_type, _)| return_type), - Expr::AggregateFunction(AggregateFunction { func, args, .. }) => { + Expr::AggregateFunction(AggregateFunction { + func, + params: AggregateFunctionParams { args, .. }, + }) => { let data_types = args .iter() .map(|e| e.get_type(schema)) diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index eacace5ed046..7801d564135e 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -18,8 +18,9 @@ //! Tree node implementation for Logical Expressions use crate::expr::{ - AggregateFunction, Alias, Between, BinaryExpr, Case, Cast, GroupingSet, InList, - InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, + AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Cast, + GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest, + WindowFunction, }; use crate::{Expr, ExprFunctionExt}; @@ -87,7 +88,7 @@ impl TreeNode for Expr { }) => (expr, low, high).apply_ref_elements(f), Expr::Case(Case { expr, when_then_expr, else_expr }) => (expr, when_then_expr, else_expr).apply_ref_elements(f), - Expr::AggregateFunction(AggregateFunction { args, filter, order_by, .. }) => + Expr::AggregateFunction(AggregateFunction { params: AggregateFunctionParams { args, filter, order_by, ..}, .. }) => (args, filter, order_by).apply_ref_elements(f), Expr::WindowFunction(WindowFunction { args, @@ -241,12 +242,15 @@ impl TreeNode for Expr { }, ), Expr::AggregateFunction(AggregateFunction { - args, func, - distinct, - filter, - order_by, - null_treatment, + params: + AggregateFunctionParams { + args, + distinct, + filter, + order_by, + null_treatment, + }, }) => (args, filter, order_by).map_elements(f)?.map_data( |(new_args, new_filter, new_order_by)| { Ok(Expr::AggregateFunction(AggregateFunction::new_udf( diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 7ffc6623ea92..32fc566d1c14 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -19,7 +19,7 @@ use std::any::Any; use std::cmp::Ordering; -use std::fmt::{self, Debug, Formatter}; +use std::fmt::{self, Debug, Formatter, Write}; use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::Arc; use std::vec; @@ -29,7 +29,10 @@ use arrow::datatypes::{DataType, Field}; use datafusion_common::{exec_err, not_impl_err, Result, ScalarValue, Statistics}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; -use crate::expr::AggregateFunction; +use crate::expr::{ + schema_name_from_exprs_comma_separated_without_space, schema_name_from_sorts, + AggregateFunction, AggregateFunctionParams, +}; use crate::function::{ AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs, }; @@ -165,6 +168,10 @@ impl AggregateUDF { self.inner.name() } + pub fn schema_name(&self, params: &AggregateFunctionParams) -> Result { + self.inner.schema_name(params) + } + pub fn is_nullable(&self) -> bool { self.inner.is_nullable() } @@ -382,6 +389,45 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// Returns this function's name fn name(&self) -> &str; + /// Returns the name of the column this expression would create + /// + /// See [`Expr::schema_name`] for details + fn schema_name(&self, params: &AggregateFunctionParams) -> Result { + let AggregateFunctionParams { + args, + distinct, + filter, + order_by, + null_treatment, + } = params; + + let mut schema_name = String::new(); + + schema_name.write_fmt(format_args!( + "{}({}{})", + self.name(), + if *distinct { "DISTINCT " } else { "" }, + schema_name_from_exprs_comma_separated_without_space(args)? + ))?; + + if let Some(null_treatment) = null_treatment { + schema_name.write_fmt(format_args!(" {}", null_treatment))?; + } + + if let Some(filter) = filter { + schema_name.write_fmt(format_args!(" FILTER (WHERE {filter})"))?; + }; + + if let Some(order_by) = order_by { + schema_name.write_fmt(format_args!( + " ORDER BY [{}]", + schema_name_from_sorts(order_by)? + ))?; + }; + + Ok(schema_name) + } + /// Returns the function's [`Signature`] for information about what input /// types are accepted and the function's Volatility. fn signature(&self) -> &Signature; diff --git a/datafusion/functions-nested/src/planner.rs b/datafusion/functions-nested/src/planner.rs index 5ca51ac20f1e..d55176a42c9a 100644 --- a/datafusion/functions-nested/src/planner.rs +++ b/datafusion/functions-nested/src/planner.rs @@ -17,8 +17,11 @@ //! SQL planning extensions like [`NestedFunctionPlanner`] and [`FieldAccessPlanner`] +use std::sync::Arc; + use datafusion_common::{plan_err, utils::list_ndims, DFSchema, Result}; -use datafusion_expr::expr::ScalarFunction; +use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams, ScalarFunction}; +use datafusion_expr::AggregateUDF; use datafusion_expr::{ planner::{ExprPlanner, PlannerResult, RawBinaryExpr, RawFieldAccessExpr}, sqlparser, Expr, ExprSchemable, GetFieldAccess, @@ -150,22 +153,26 @@ impl ExprPlanner for FieldAccessPlanner { GetFieldAccess::ListIndex { key: index } => { match expr { // Special case for array_agg(expr)[index] to NTH_VALUE(expr, index) - Expr::AggregateFunction(agg_func) if is_array_agg(&agg_func) => { - Ok(PlannerResult::Planned(Expr::AggregateFunction( - datafusion_expr::expr::AggregateFunction::new_udf( - nth_value_udaf(), - agg_func - .args - .into_iter() - .chain(std::iter::once(*index)) - .collect(), - agg_func.distinct, - agg_func.filter, - agg_func.order_by, - agg_func.null_treatment, - ), - ))) - } + Expr::AggregateFunction(AggregateFunction { + func, + params: + AggregateFunctionParams { + args, + distinct, + filter, + order_by, + null_treatment, + }, + }) if is_array_agg(&func) => Ok(PlannerResult::Planned( + Expr::AggregateFunction(AggregateFunction::new_udf( + nth_value_udaf(), + args.into_iter().chain(std::iter::once(*index)).collect(), + distinct, + filter, + order_by, + null_treatment, + )), + )), _ => Ok(PlannerResult::Planned(array_element(expr, *index))), } } @@ -184,6 +191,6 @@ impl ExprPlanner for FieldAccessPlanner { } } -fn is_array_agg(agg_func: &datafusion_expr::expr::AggregateFunction) -> bool { - agg_func.func.name() == "array_agg" +fn is_array_agg(func: &Arc) -> bool { + func.name() == "array_agg" } diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 95b6f9dc764f..7e73474cf6f5 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -21,7 +21,7 @@ use crate::utils::NamePreserver; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::Result; -use datafusion_expr::expr::{AggregateFunction, WindowFunction}; +use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams, WindowFunction}; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::{lit, Expr, LogicalPlan, WindowFunctionDefinition}; @@ -55,8 +55,7 @@ fn is_count_star_aggregate(aggregate_function: &AggregateFunction) -> bool { matches!(aggregate_function, AggregateFunction { func, - args, - .. + params: AggregateFunctionParams { args, .. }, } if func.name() == "count" && (args.len() == 1 && is_wildcard(&args[0]) || args.is_empty())) } @@ -81,7 +80,7 @@ fn analyze_internal(plan: LogicalPlan) -> Result> { Expr::AggregateFunction(mut aggregate_function) if is_count_star_aggregate(&aggregate_function) => { - aggregate_function.args = vec![lit(COUNT_STAR_EXPANSION)]; + aggregate_function.params.args = vec![lit(COUNT_STAR_EXPANSION)]; Ok(Transformed::yes(Expr::AggregateFunction( aggregate_function, ))) diff --git a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs index 16ebb8cd3972..f8a818563609 100644 --- a/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs +++ b/datafusion/optimizer/src/analyzer/resolve_grouping_function.rs @@ -163,6 +163,7 @@ fn validate_args( group_by_expr: &HashMap<&Expr, usize>, ) -> Result<()> { let expr_not_in_group_by = function + .params .args .iter() .find(|expr| !group_by_expr.contains_key(expr)); @@ -183,7 +184,7 @@ fn grouping_function_on_id( is_grouping_set: bool, ) -> Result { validate_args(function, group_by_expr)?; - let args = &function.args; + let args = &function.params.args; // Postgres allows grouping function for group by without grouping sets, the result is then // always 0 diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 85fc9b31bcdd..fd20c9fa5409 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -33,8 +33,8 @@ use datafusion_common::{ DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, TableReference, }; use datafusion_expr::expr::{ - self, Alias, Between, BinaryExpr, Case, Exists, InList, InSubquery, Like, - ScalarFunction, Sort, WindowFunction, + self, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Exists, InList, + InSubquery, Like, ScalarFunction, Sort, WindowFunction, }; use datafusion_expr::expr_rewriter::coerce_plan_expr_for_schema; use datafusion_expr::expr_schema::cast_subquery; @@ -506,11 +506,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { } Expr::AggregateFunction(expr::AggregateFunction { func, - args, - distinct, - filter, - order_by, - null_treatment, + params: + AggregateFunctionParams { + args, + distinct, + filter, + order_by, + null_treatment, + }, }) => { let new_expr = coerce_arguments_for_signature_with_aggregate_udf( args, diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index c8f3a4bc7859..191377fc2759 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -26,6 +26,7 @@ use datafusion_common::{ internal_err, tree_node::Transformed, DataFusionError, HashSet, Result, }; use datafusion_expr::builder::project; +use datafusion_expr::expr::AggregateFunctionParams; use datafusion_expr::{ col, expr::AggregateFunction, @@ -68,11 +69,14 @@ fn is_single_distinct_agg(aggr_expr: &[Expr]) -> Result { for expr in aggr_expr { if let Expr::AggregateFunction(AggregateFunction { func, - distinct, - args, - filter, - order_by, - null_treatment: _, + params: + AggregateFunctionParams { + distinct, + args, + filter, + order_by, + null_treatment: _, + }, }) = expr { if filter.is_some() || order_by.is_some() { @@ -179,9 +183,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { .map(|aggr_expr| match aggr_expr { Expr::AggregateFunction(AggregateFunction { func, - mut args, - distinct, - .. + params: AggregateFunctionParams { mut args, distinct, .. } }) => { if distinct { if args.len() != 1 { diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 72618c2b6ab4..6491671d84a5 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use datafusion_expr::expr::Unnest; +use datafusion_expr::expr::{AggregateFunctionParams, Unnest}; use sqlparser::ast::Value::SingleQuotedString; use sqlparser::ast::{ self, Array, BinaryOperator, Expr as AstExpr, Function, Ident, Interval, ObjectName, @@ -284,9 +284,15 @@ impl Unparser<'_> { }), Expr::AggregateFunction(agg) => { let func_name = agg.func.name(); + let AggregateFunctionParams { + distinct, + args, + filter, + .. + } = &agg.params; - let args = self.function_args_to_sql(&agg.args)?; - let filter = match &agg.filter { + let args = self.function_args_to_sql(args)?; + let filter = match filter { Some(filter) => Some(Box::new(self.expr_to_sql_inner(filter)?)), None => None, }; @@ -297,8 +303,7 @@ impl Unparser<'_> { span: Span::empty(), }]), args: ast::FunctionArguments::List(ast::FunctionArgumentList { - duplicate_treatment: agg - .distinct + duplicate_treatment: distinct .then_some(ast::DuplicateTreatment::Distinct), args, clauses: vec![], From a233dc7bc0fc73037bc729acb85c5614cccdbeee Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 16 Feb 2025 15:54:11 +0800 Subject: [PATCH 2/9] doc --- datafusion/expr/src/udaf.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 32fc566d1c14..39ad8de526ce 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -168,6 +168,7 @@ impl AggregateUDF { self.inner.name() } + /// See [`AggregateUDFImpl::schema_name`] for more details. pub fn schema_name(&self, params: &AggregateFunctionParams) -> Result { self.inner.schema_name(params) } From ae2af2d856bb011d2290dec379640dcc6ef6bab2 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 16 Feb 2025 16:19:00 +0800 Subject: [PATCH 3/9] fix proto --- datafusion/proto/src/logical_plan/to_proto.rs | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 6d1d4f30610c..d4cadcbea6a7 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -22,8 +22,7 @@ use datafusion_common::{TableReference, UnnestOptions}; use datafusion_expr::dml::InsertOp; use datafusion_expr::expr::{ - self, Alias, Between, BinaryExpr, Cast, GroupingSet, InList, Like, Placeholder, - ScalarFunction, Unnest, + self, AggregateFunctionParams, Alias, Between, BinaryExpr, Cast, GroupingSet, InList, Like, Placeholder, ScalarFunction, Unnest }; use datafusion_expr::WriteOp; use datafusion_expr::{ @@ -348,11 +347,13 @@ pub fn serialize_expr( } Expr::AggregateFunction(expr::AggregateFunction { ref func, - ref args, - ref distinct, - ref filter, - ref order_by, - null_treatment: _, + params: AggregateFunctionParams { + ref args, + ref distinct, + ref filter, + ref order_by, + null_treatment: _, + }, }) => { let mut buf = Vec::new(); let _ = codec.try_encode_udaf(func, &mut buf); From 01a6bcb8bb9a65014549f9d809dc0f2abb3eeec2 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 16 Feb 2025 16:19:10 +0800 Subject: [PATCH 4/9] fmt --- datafusion/proto/src/logical_plan/to_proto.rs | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index d4cadcbea6a7..228437271694 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -22,7 +22,8 @@ use datafusion_common::{TableReference, UnnestOptions}; use datafusion_expr::dml::InsertOp; use datafusion_expr::expr::{ - self, AggregateFunctionParams, Alias, Between, BinaryExpr, Cast, GroupingSet, InList, Like, Placeholder, ScalarFunction, Unnest + self, AggregateFunctionParams, Alias, Between, BinaryExpr, Cast, GroupingSet, InList, + Like, Placeholder, ScalarFunction, Unnest, }; use datafusion_expr::WriteOp; use datafusion_expr::{ @@ -347,13 +348,14 @@ pub fn serialize_expr( } Expr::AggregateFunction(expr::AggregateFunction { ref func, - params: AggregateFunctionParams { - ref args, - ref distinct, - ref filter, - ref order_by, - null_treatment: _, - }, + params: + AggregateFunctionParams { + ref args, + ref distinct, + ref filter, + ref order_by, + null_treatment: _, + }, }) => { let mut buf = Vec::new(); let _ = codec.try_encode_udaf(func, &mut buf); From 46943e18bad587a9229f13e1902706b74694e39b Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 16 Feb 2025 16:40:39 +0800 Subject: [PATCH 5/9] fix --- datafusion-examples/examples/advanced_udaf.rs | 10 +++++----- datafusion/core/tests/execution/logical_plan.rs | 4 +++- datafusion/substrait/src/logical_plan/producer.rs | 14 ++++++++------ 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/datafusion-examples/examples/advanced_udaf.rs b/datafusion-examples/examples/advanced_udaf.rs index fd65c3352bbc..9cda726db719 100644 --- a/datafusion-examples/examples/advanced_udaf.rs +++ b/datafusion-examples/examples/advanced_udaf.rs @@ -423,11 +423,11 @@ impl AggregateUDFImpl for SimplifiedGeoMeanUdaf { // In real-world scenarios, you might create UDFs from built-in expressions. Ok(Expr::AggregateFunction(AggregateFunction::new_udf( Arc::new(AggregateUDF::from(GeoMeanUdaf::new())), - aggregate_function.args, - aggregate_function.distinct, - aggregate_function.filter, - aggregate_function.order_by, - aggregate_function.null_treatment, + aggregate_function.params.args, + aggregate_function.params.distinct, + aggregate_function.params.filter, + aggregate_function.params.order_by, + aggregate_function.params.null_treatment, ))) }; Some(Box::new(simplify)) diff --git a/datafusion/core/tests/execution/logical_plan.rs b/datafusion/core/tests/execution/logical_plan.rs index a521902389dc..95d797ee7ca7 100644 --- a/datafusion/core/tests/execution/logical_plan.rs +++ b/datafusion/core/tests/execution/logical_plan.rs @@ -20,7 +20,7 @@ use arrow::datatypes::{DataType, Field}; use datafusion::execution::session_state::SessionStateBuilder; use datafusion_common::{Column, DFSchema, Result, ScalarValue, Spans}; use datafusion_execution::TaskContext; -use datafusion_expr::expr::AggregateFunction; +use datafusion_expr::expr::{AggregateFunction, AggregateFunctionParams}; use datafusion_expr::logical_plan::{LogicalPlan, Values}; use datafusion_expr::{Aggregate, AggregateUDF, Expr}; use datafusion_functions_aggregate::count::Count; @@ -60,11 +60,13 @@ async fn count_only_nulls() -> Result<()> { vec![], vec![Expr::AggregateFunction(AggregateFunction { func: Arc::new(AggregateUDF::new_from_impl(Count::new())), + params: AggregateFunctionParams { args: vec![input_col_ref], distinct: false, filter: None, order_by: None, null_treatment: None, + } })], )?); diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index 42c226174932..bc1a847308e7 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -52,7 +52,7 @@ use datafusion::common::{ use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::SessionState; use datafusion::logical_expr::expr::{ - Alias, BinaryExpr, Case, Cast, GroupingSet, InList, InSubquery, WindowFunction, + AggregateFunctionParams, Alias, BinaryExpr, Case, Cast, GroupingSet, InList, InSubquery, WindowFunction }; use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; use datafusion::prelude::Expr; @@ -1208,11 +1208,13 @@ pub fn from_aggregate_function( ) -> Result { let expr::AggregateFunction { func, - args, - distinct, - filter, - order_by, - null_treatment: _null_treatment, + params: AggregateFunctionParams { + args, + distinct, + filter, + order_by, + null_treatment: _null_treatment, + } } = agg_fn; let sorts = if let Some(order_by) = order_by { order_by From f2940127238486729633297b7d8894d2cb58a132 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Sun, 16 Feb 2025 17:02:00 +0800 Subject: [PATCH 6/9] fmt --- .../core/tests/execution/logical_plan.rs | 12 ++++++------ .../substrait/src/logical_plan/producer.rs | 18 ++++++++++-------- 2 files changed, 16 insertions(+), 14 deletions(-) diff --git a/datafusion/core/tests/execution/logical_plan.rs b/datafusion/core/tests/execution/logical_plan.rs index 95d797ee7ca7..a17bb5eec8a3 100644 --- a/datafusion/core/tests/execution/logical_plan.rs +++ b/datafusion/core/tests/execution/logical_plan.rs @@ -61,12 +61,12 @@ async fn count_only_nulls() -> Result<()> { vec![Expr::AggregateFunction(AggregateFunction { func: Arc::new(AggregateUDF::new_from_impl(Count::new())), params: AggregateFunctionParams { - args: vec![input_col_ref], - distinct: false, - filter: None, - order_by: None, - null_treatment: None, - } + args: vec![input_col_ref], + distinct: false, + filter: None, + order_by: None, + null_treatment: None, + }, })], )?); diff --git a/datafusion/substrait/src/logical_plan/producer.rs b/datafusion/substrait/src/logical_plan/producer.rs index bc1a847308e7..d795a869568b 100644 --- a/datafusion/substrait/src/logical_plan/producer.rs +++ b/datafusion/substrait/src/logical_plan/producer.rs @@ -52,7 +52,8 @@ use datafusion::common::{ use datafusion::execution::registry::SerializerRegistry; use datafusion::execution::SessionState; use datafusion::logical_expr::expr::{ - AggregateFunctionParams, Alias, BinaryExpr, Case, Cast, GroupingSet, InList, InSubquery, WindowFunction + AggregateFunctionParams, Alias, BinaryExpr, Case, Cast, GroupingSet, InList, + InSubquery, WindowFunction, }; use datafusion::logical_expr::{expr, Between, JoinConstraint, LogicalPlan, Operator}; use datafusion::prelude::Expr; @@ -1208,13 +1209,14 @@ pub fn from_aggregate_function( ) -> Result { let expr::AggregateFunction { func, - params: AggregateFunctionParams { - args, - distinct, - filter, - order_by, - null_treatment: _null_treatment, - } + params: + AggregateFunctionParams { + args, + distinct, + filter, + order_by, + null_treatment: _null_treatment, + }, } = agg_fn; let sorts = if let Some(order_by) = order_by { order_by From 2fee41e3e59fb7fafe4a6bd3c334e5e0fd953972 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Mon, 17 Feb 2025 08:28:47 +0800 Subject: [PATCH 7/9] doc --- datafusion/expr/src/udaf.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 39ad8de526ce..0d568d677c1f 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -393,6 +393,8 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// Returns the name of the column this expression would create /// /// See [`Expr::schema_name`] for details + /// + /// Example of schema_name: count(DISTINCT column1) FILTER (WHERE column2 > 10) ORDER BY [column3] fn schema_name(&self, params: &AggregateFunctionParams) -> Result { let AggregateFunctionParams { args, From 156f6037a9392effb67ac1b2c3caf43783503443 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Mon, 17 Feb 2025 08:46:47 +0800 Subject: [PATCH 8/9] add displayname --- datafusion/expr/src/expr.rs | 29 ++++++--------------- datafusion/expr/src/udaf.rs | 51 +++++++++++++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 21 deletions(-) diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index f1cc365dfc88..84ff36a9317d 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -2648,28 +2648,15 @@ impl Display for Expr { )?; Ok(()) } - Expr::AggregateFunction(AggregateFunction { - func, - params: - AggregateFunctionParams { - ref args, - distinct, - filter, - order_by, - null_treatment, - }, - }) => { - fmt_function(f, func.name(), *distinct, args, true)?; - if let Some(nt) = null_treatment { - write!(f, " {}", nt)?; - } - if let Some(fe) = filter { - write!(f, " FILTER (WHERE {fe})")?; - } - if let Some(ob) = order_by { - write!(f, " ORDER BY [{}]", expr_vec_fmt!(ob))?; + Expr::AggregateFunction(AggregateFunction { func, params }) => { + match func.display_name(params) { + Ok(name) => { + write!(f, "{}", name) + } + Err(e) => { + write!(f, "got error from display_name {}", e) + } } - Ok(()) } Expr::Between(Between { expr, diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 0d568d677c1f..1ad1af96ff35 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -173,6 +173,11 @@ impl AggregateUDF { self.inner.schema_name(params) } + /// See [`AggregateUDFImpl::display_name`] for more details. + pub fn display_name(&self, params: &AggregateFunctionParams) -> Result { + self.inner.display_name(params) + } + pub fn is_nullable(&self) -> bool { self.inner.is_nullable() } @@ -431,6 +436,52 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { Ok(schema_name) } + /// Returns the user-defined display name of function, given the arguments + /// + /// This can be used to customize the output column name generated by this + /// function. + /// + /// Defaults to `function_name([DISTINCT] args[0], args[1], ...) [null_treatment] [filter] [order_by [..]]` + fn display_name(&self, params: &AggregateFunctionParams) -> Result { + let AggregateFunctionParams { + args, + distinct, + filter, + order_by, + null_treatment, + } = params; + + let mut schema_name = String::new(); + + schema_name.write_fmt(format_args!( + "{}({}{})", + self.name(), + if *distinct { "DISTINCT " } else { "" }, + args.iter() + .map(|arg| format!("{arg}")) + .collect::>() + .join(", ") + ))?; + + if let Some(nt) = null_treatment { + schema_name.write_fmt(format_args!(" {}", nt))?; + } + if let Some(fe) = filter { + schema_name.write_fmt(format_args!(" FILTER (WHERE {fe})"))?; + } + if let Some(ob) = order_by { + schema_name.write_fmt(format_args!( + " ORDER BY [{}]", + ob.iter() + .map(|o| format!("{o}")) + .collect::>() + .join(", ") + ))?; + } + + Ok(schema_name) + } + /// Returns the function's [`Signature`] for information about what input /// types are accepted and the function's Volatility. fn signature(&self) -> &Signature; From 52eeeafa9f2b42c60e48914bba2b88d76f1b9303 Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Mon, 17 Feb 2025 09:10:14 +0800 Subject: [PATCH 9/9] doc --- datafusion/expr/src/udaf.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/udaf.rs b/datafusion/expr/src/udaf.rs index 1ad1af96ff35..bf8f34f949e0 100644 --- a/datafusion/expr/src/udaf.rs +++ b/datafusion/expr/src/udaf.rs @@ -399,7 +399,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// /// See [`Expr::schema_name`] for details /// - /// Example of schema_name: count(DISTINCT column1) FILTER (WHERE column2 > 10) ORDER BY [column3] + /// Example of schema_name: count(DISTINCT column1) FILTER (WHERE column2 > 10) ORDER BY [..] fn schema_name(&self, params: &AggregateFunctionParams) -> Result { let AggregateFunctionParams { args, @@ -441,7 +441,7 @@ pub trait AggregateUDFImpl: Debug + Send + Sync { /// This can be used to customize the output column name generated by this /// function. /// - /// Defaults to `function_name([DISTINCT] args[0], args[1], ...) [null_treatment] [filter] [order_by [..]]` + /// Defaults to `function_name([DISTINCT] column1, column2, ..) [null_treatment] [filter] [order_by [..]]` fn display_name(&self, params: &AggregateFunctionParams) -> Result { let AggregateFunctionParams { args,