diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index cdbc36254e5a..f668119b1ac2 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -68,7 +68,6 @@ use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::projection_push_down::ProjectionPushDown; use crate::optimizer::simplify_expressions::SimplifyExpressions; use crate::optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy; -use crate::optimizer::to_approx_perc::ToApproxPerc; use crate::physical_optimizer::coalesce_batches::CoalesceBatches; use crate::physical_optimizer::merge_exec::AddCoalescePartitionsExec; @@ -1151,10 +1150,6 @@ impl SessionState { Arc::new(FilterPushDown::new()), Arc::new(LimitPushDown::new()), Arc::new(SingleDistinctToGroupBy::new()), - // ToApproxPerc must be applied last because - // it rewrites only the function and may interfere with - // other rules - Arc::new(ToApproxPerc::new()), ], physical_optimizers: vec![ Arc::new(AggregateStatistics::new()), diff --git a/datafusion/core/src/optimizer/mod.rs b/datafusion/core/src/optimizer/mod.rs index cddedfc8a3ee..9f12ecea81df 100644 --- a/datafusion/core/src/optimizer/mod.rs +++ b/datafusion/core/src/optimizer/mod.rs @@ -28,5 +28,4 @@ pub mod optimizer; pub mod projection_push_down; pub mod simplify_expressions; pub mod single_distinct_to_groupby; -pub mod to_approx_perc; pub mod utils; diff --git a/datafusion/core/src/optimizer/to_approx_perc.rs b/datafusion/core/src/optimizer/to_approx_perc.rs deleted file mode 100644 index c33c3f67602a..000000000000 --- a/datafusion/core/src/optimizer/to_approx_perc.rs +++ /dev/null @@ -1,161 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! espression/function to approx_percentile optimizer rule - -use crate::error::Result; -use crate::execution::context::ExecutionProps; -use crate::logical_plan::plan::Aggregate; -use crate::logical_plan::{Expr, LogicalPlan}; -use crate::optimizer::optimizer::OptimizerRule; -use crate::optimizer::utils; -use crate::physical_plan::aggregates; -use crate::scalar::ScalarValue; - -/// espression/function to approx_percentile optimizer rule -/// ```text -/// SELECT F1(s) -/// ... -/// -/// Into -/// -/// SELECT APPROX_PERCENTILE_CONT(s, lit(n)) as "F1(s)" -/// ... -/// ``` -pub struct ToApproxPerc {} - -impl ToApproxPerc { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} - } -} - -impl Default for ToApproxPerc { - fn default() -> Self { - Self::new() - } -} - -fn optimize(plan: &LogicalPlan) -> Result { - match plan { - LogicalPlan::Aggregate(Aggregate { - input, - aggr_expr, - schema, - group_expr, - }) => { - let new_aggr_expr = aggr_expr - .iter() - .map(|agg_expr| replace_with_percentile(agg_expr).unwrap()) - .collect::>(); - - Ok(LogicalPlan::Aggregate(Aggregate { - input: input.clone(), - aggr_expr: new_aggr_expr, - schema: schema.clone(), - group_expr: group_expr.clone(), - })) - } - _ => optimize_children(plan), - } -} - -fn optimize_children(plan: &LogicalPlan) -> Result { - let expr = plan.expressions(); - let inputs = plan.inputs(); - let new_inputs = inputs - .iter() - .map(|plan| optimize(plan)) - .collect::>>()?; - utils::from_plan(plan, &expr, &new_inputs) -} - -fn replace_with_percentile(expr: &Expr) -> Result { - match expr { - Expr::AggregateFunction { - fun, - args, - distinct, - } => { - let mut new_args = args.clone(); - let mut new_func = fun.clone(); - if fun == &aggregates::AggregateFunction::ApproxMedian { - new_args.push(Expr::Literal(ScalarValue::Float64(Some(0.5_f64)))); - new_func = aggregates::AggregateFunction::ApproxPercentileCont; - } - - Ok(Expr::AggregateFunction { - fun: new_func, - args: new_args, - distinct: *distinct, - }) - } - _ => Ok(expr.clone()), - } -} - -impl OptimizerRule for ToApproxPerc { - fn optimize( - &self, - plan: &LogicalPlan, - _execution_props: &ExecutionProps, - ) -> Result { - optimize(plan) - } - fn name(&self) -> &str { - "ToApproxPerc" - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::logical_plan::{col, LogicalPlanBuilder}; - use crate::physical_plan::aggregates; - use crate::test::*; - - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { - let rule = ToApproxPerc::new(); - let optimized_plan = rule - .optimize(plan, &ExecutionProps::new()) - .expect("failed to optimize plan"); - let formatted_plan = format!("{}", optimized_plan.display_indent_schema()); - assert_eq!(formatted_plan, expected); - } - - #[test] - fn median_1() -> Result<()> { - let table_scan = test_table_scan()?; - let expr = Expr::AggregateFunction { - fun: aggregates::AggregateFunction::ApproxMedian, - distinct: false, - args: vec![col("b")], - }; - - let plan = LogicalPlanBuilder::from(table_scan) - .aggregate(Vec::::new(), vec![expr])? - .build()?; - - // Rewrite to use approx_percentile - let expected = "Aggregate: groupBy=[[]], aggr=[[APPROXPERCENTILECONT(#test.b, Float64(0.5))]] [APPROXMEDIAN(test.b):UInt32;N]\ - \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]"; - - assert_optimized_plan_eq(&plan, expected); - Ok(()) - } -} diff --git a/datafusion/core/src/sql/planner.rs b/datafusion/core/src/sql/planner.rs index 6404e4bd087d..a9ad914543ab 100644 --- a/datafusion/core/src/sql/planner.rs +++ b/datafusion/core/src/sql/planner.rs @@ -2046,15 +2046,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { window_functions::WindowFunction::AggregateFunction( aggregate_fun, ) => { + let (aggregate_fun, args) = self.aggregate_fn_to_expr( + aggregate_fun, + function, + schema, + )?; + return Ok(Expr::WindowFunction { fun: window_functions::WindowFunction::AggregateFunction( - aggregate_fun.clone(), - ), - args: self.aggregate_fn_to_expr( aggregate_fun, - function, - schema, - )?, + ), + args, partition_by, order_by, window_frame, @@ -2079,7 +2081,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // next, aggregate built-ins if let Ok(fun) = aggregates::AggregateFunction::from_str(&name) { let distinct = function.distinct; - let args = self.aggregate_fn_to_expr(fun.clone(), function, schema)?; + let (fun, args) = self.aggregate_fn_to_expr(fun, function, schema)?; return Ok(Expr::AggregateFunction { fun, distinct, @@ -2173,9 +2175,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fun: aggregates::AggregateFunction, function: sqlparser::ast::Function, schema: &DFSchema, - ) -> Result> { - if fun == aggregates::AggregateFunction::Count { - function + ) -> Result<(aggregates::AggregateFunction, Vec)> { + let args = match fun { + aggregates::AggregateFunction::Count => function .args .into_iter() .map(|a| match a { @@ -2185,10 +2187,24 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => Ok(lit(1_u8)), _ => self.sql_fn_arg_to_logical_expr(a, schema), }) - .collect::>>() - } else { - self.function_args_to_expr(function.args, schema) - } + .collect::>>()?, + aggregates::AggregateFunction::ApproxMedian => function + .args + .into_iter() + .map(|a| self.sql_fn_arg_to_logical_expr(a, schema)) + .chain(iter::once(Ok(lit(0.5_f64)))) + .collect::>>()?, + _ => self.function_args_to_expr(function.args, schema)?, + }; + + let fun = match fun { + aggregates::AggregateFunction::ApproxMedian => { + aggregates::AggregateFunction::ApproxPercentileCont + } + _ => fun, + }; + + Ok((fun, args)) } fn sql_interval_to_literal( @@ -3590,6 +3606,15 @@ mod tests { quick_test(sql, expected); } + #[test] + fn select_approx_median() { + let sql = "SELECT approx_median(age) FROM person"; + let expected = "Projection: #APPROXPERCENTILECONT(person.age,Float64(0.5))\ + \n Aggregate: groupBy=[[]], aggr=[[APPROXPERCENTILECONT(#person.age, Float64(0.5))]]\ + \n TableScan: person projection=None"; + quick_test(sql, expected); + } + #[test] fn select_scalar_func() { let sql = "SELECT sqrt(age) FROM person"; @@ -4344,6 +4369,17 @@ mod tests { quick_test(sql, expected); } + #[test] + fn approx_median_window() { + let sql = + "SELECT order_id, APPROX_MEDIAN(qty) OVER(PARTITION BY order_id) from orders"; + let expected = "\ + Projection: #orders.order_id, #APPROXPERCENTILECONT(orders.qty,Float64(0.5)) PARTITION BY [#orders.order_id]\ + \n WindowAggr: windowExpr=[[APPROXPERCENTILECONT(#orders.qty, Float64(0.5)) PARTITION BY [#orders.order_id]]]\ + \n TableScan: orders projection=None"; + quick_test(sql, expected); + } + #[test] fn select_typedstring() { let sql = "SELECT date '2020-12-10' AS date FROM person";