From 76e7e6db248ab14e4747ee3aa070ad197cb1d402 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Mon, 31 May 2021 22:17:30 +0800 Subject: [PATCH] fix window expression with alias --- ballista/rust/core/proto/ballista.proto | 14 +- .../core/src/serde/logical_plan/from_proto.rs | 12 +- .../core/src/serde/logical_plan/to_proto.rs | 39 ++-- .../src/serde/physical_plan/from_proto.rs | 14 +- datafusion/src/logical_plan/builder.rs | 24 +- datafusion/src/logical_plan/expr.rs | 23 +- datafusion/src/logical_plan/plan.rs | 28 +-- .../src/optimizer/projection_push_down.rs | 45 ++-- datafusion/src/optimizer/utils.rs | 39 ++-- datafusion/src/physical_plan/planner.rs | 7 +- datafusion/src/sql/mod.rs | 2 +- datafusion/src/sql/planner.rs | 220 +++++++++++++++--- datafusion/src/sql/utils.rs | 211 ++++++++++++++++- 13 files changed, 508 insertions(+), 170 deletions(-) diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index 03872147b797..d21cbf694b9d 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -174,6 +174,12 @@ message WindowExprNode { // udaf = 3 } LogicalExprNode expr = 4; + // repeated LogicalExprNode partition_by = 5; + repeated LogicalExprNode order_by = 6; + // repeated LogicalExprNode filter = 7; + // oneof window_frame { + // WindowFrame frame = 8; + // } } message BetweenNode { @@ -317,14 +323,6 @@ message AggregateNode { message WindowNode { LogicalPlanNode input = 1; repeated LogicalExprNode window_expr = 2; - repeated LogicalExprNode partition_by_expr = 3; - repeated LogicalExprNode order_by_expr = 4; - // "optional" keyword is stable in protoc 3.15 but prost is still on 3.14 (see https://github.com/danburkert/prost/issues/430) - // this syntax is ugly but is binary compatible with the "optional" keyword (see https://stackoverflow.com/questions/42622015/how-to-define-an-optional-field-in-protobuf-3) - oneof window_frame { - WindowFrame frame = 5; - } - // TODO add filter by expr } enum WindowFrameUnits { diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index 48471263885f..522d60cb8a05 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -98,9 +98,7 @@ impl TryInto for &protobuf::LogicalPlanNode { // // FIXME: parse the window_frame data // let window_frame = None; LogicalPlanBuilder::from(&input) - .window( - window_expr, /* filter_by_expr, partition_by_expr, order_by_expr, window_frame*/ - )? + .window(window_expr)? .build() .map_err(|e| e.into()) } @@ -924,6 +922,12 @@ impl TryInto for &protobuf::LogicalExprNode { .window_function .as_ref() .ok_or_else(|| proto_error("Received empty window function"))?; + let order_by = expr + .order_by + .iter() + .map(|e| e.try_into()) + .into_iter() + .collect::, _>>()?; match window_function { window_expr_node::WindowFunction::AggrFunction(i) => { let aggr_function = protobuf::AggregateFunction::from_i32(*i) @@ -939,6 +943,7 @@ impl TryInto for &protobuf::LogicalExprNode { AggregateFunction::from(aggr_function), ), args: vec![parse_required_expr(&expr.expr)?], + order_by, }) } window_expr_node::WindowFunction::BuiltInFunction(i) => { @@ -957,6 +962,7 @@ impl TryInto for &protobuf::LogicalExprNode { BuiltInWindowFunction::from(built_in_function), ), args: vec![parse_required_expr(&expr.expr)?], + order_by, }) } } diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index e1c0c5e44df6..088e93120e4f 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -761,27 +761,9 @@ impl TryInto for &LogicalPlan { }) } LogicalPlan::Window { - input, - window_expr, - // FIXME implement next - // filter_by_expr, - // FIXME implement next - // partition_by_expr, - // FIXME implement next - // order_by_expr, - // FIXME implement next - // window_frame, - .. + input, window_expr, .. } => { let input: protobuf::LogicalPlanNode = input.as_ref().try_into()?; - // FIXME: implement - // let filter_by_expr = vec![]; - // FIXME: implement - let partition_by_expr = vec![]; - // FIXME: implement - let order_by_expr = vec![]; - // FIXME: implement - let window_frame = None; Ok(protobuf::LogicalPlanNode { logical_plan_type: Some(LogicalPlanType::Window(Box::new( protobuf::WindowNode { @@ -789,10 +771,7 @@ impl TryInto for &LogicalPlan { window_expr: window_expr .iter() .map(|expr| expr.try_into()) - .collect::, BallistaError>>()?, - partition_by_expr, - order_by_expr, - window_frame, + .collect::, _>>()?, }, ))), }) @@ -811,11 +790,11 @@ impl TryInto for &LogicalPlan { group_expr: group_expr .iter() .map(|expr| expr.try_into()) - .collect::, BallistaError>>()?, + .collect::, _>>()?, aggr_expr: aggr_expr .iter() .map(|expr| expr.try_into()) - .collect::, BallistaError>>()?, + .collect::, _>>()?, }, ))), }) @@ -1024,7 +1003,10 @@ impl TryInto for &Expr { }) } Expr::WindowFunction { - ref fun, ref args, .. + ref fun, + ref args, + ref order_by, + .. } => { let window_function = match fun { WindowFunction::AggregateFunction(fun) => { @@ -1039,9 +1021,14 @@ impl TryInto for &Expr { } }; let arg = &args[0]; + let order_by = order_by + .iter() + .map(|e| e.try_into()) + .collect::, _>>()?; let window_expr = Box::new(protobuf::WindowExprNode { expr: Some(Box::new(arg.try_into()?)), window_function: Some(window_function), + order_by, }); Ok(protobuf::LogicalExprNode { expr_type: Some(ExprType::WindowExpr(window_expr)), diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index 7f98a8378b0b..c19739a6b061 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -233,7 +233,11 @@ impl TryInto> for &protobuf::PhysicalPlanNode { for (expr, name) in &window_agg_expr { match expr { - Expr::WindowFunction { fun, args } => { + Expr::WindowFunction { + fun, + args, + order_by, + } => { let arg = df_planner .create_physical_expr( &args[0], @@ -243,12 +247,16 @@ impl TryInto> for &protobuf::PhysicalPlanNode { .map_err(|e| { BallistaError::General(format!("{:?}", e)) })?; - physical_window_expr.push(create_window_expr( + if !order_by.is_empty() { + return Err(BallistaError::NotImplemented("Window function with order by is not yet implemented".to_owned())); + } + let window_expr = create_window_expr( &fun, &[arg], &physical_schema, name.to_owned(), - )?); + )?; + physical_window_expr.push(window_expr); } _ => { return Err(BallistaError::General( diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs index 71de48cdb8f8..dc80a41c0c01 100644 --- a/datafusion/src/logical_plan/builder.rs +++ b/datafusion/src/logical_plan/builder.rs @@ -297,23 +297,7 @@ impl LogicalPlanBuilder { /// - https://github.com/apache/arrow-datafusion/issues/299 with partition clause /// - https://github.com/apache/arrow-datafusion/issues/360 with order by /// - https://github.com/apache/arrow-datafusion/issues/361 with window frame - pub fn window( - &self, - window_expr: impl IntoIterator, - // FIXME: implement next - // filter_by_expr: impl IntoIterator, - // FIXME: implement next - // partition_by_expr: impl IntoIterator, - // FIXME: implement next - // order_by_expr: impl IntoIterator, - // FIXME: implement next - // window_frame: Option, - ) -> Result { - let window_expr = window_expr.into_iter().collect::>(); - // FIXME: implement next - // let partition_by_expr = partition_by_expr.into_iter().collect::>(); - // FIXME: implement next - // let order_by_expr = order_by_expr.into_iter().collect::>(); + pub fn window(&self, window_expr: Vec) -> Result { let all_expr = window_expr.iter(); validate_unique_names("Windows", all_expr.clone(), self.plan.schema())?; @@ -323,12 +307,6 @@ impl LogicalPlanBuilder { Ok(Self::from(&LogicalPlan::Window { input: Arc::new(self.plan.clone()), - // FIXME implement next - // partition_by_expr, - // FIXME implement next - // order_by_expr, - // FIXME implement next - // window_frame, window_expr, schema: Arc::new(DFSchema::new(window_fields)?), })) diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index 29723e73d25c..5103d5dc5051 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -197,6 +197,8 @@ pub enum Expr { fun: window_functions::WindowFunction, /// List of expressions to feed to the functions as arguments args: Vec, + /// List of order by expressions + order_by: Vec, }, /// aggregate function AggregateUDF { @@ -587,9 +589,15 @@ impl Expr { Expr::ScalarUDF { args, .. } => args .iter() .try_fold(visitor, |visitor, arg| arg.accept(visitor)), - Expr::WindowFunction { args, .. } => args - .iter() - .try_fold(visitor, |visitor, arg| arg.accept(visitor)), + Expr::WindowFunction { args, order_by, .. } => { + let visitor = args + .iter() + .try_fold(visitor, |visitor, arg| arg.accept(visitor))?; + let visitor = order_by + .iter() + .try_fold(visitor, |visitor, arg| arg.accept(visitor))?; + Ok(visitor) + } Expr::AggregateFunction { args, .. } => args .iter() .try_fold(visitor, |visitor, arg| arg.accept(visitor)), @@ -723,9 +731,14 @@ impl Expr { args: rewrite_vec(args, rewriter)?, fun, }, - Expr::WindowFunction { args, fun } => Expr::WindowFunction { + Expr::WindowFunction { + args, + fun, + order_by, + } => Expr::WindowFunction { args: rewrite_vec(args, rewriter)?, fun, + order_by: rewrite_vec(order_by, rewriter)?, }, Expr::AggregateFunction { args, @@ -1388,7 +1401,7 @@ fn create_name(e: &Expr, input_schema: &DFSchema) -> Result { Expr::ScalarUDF { fun, args, .. } => { create_function_name(&fun.name, false, args, input_schema) } - Expr::WindowFunction { fun, args } => { + Expr::WindowFunction { fun, args, .. } => { create_function_name(&fun.to_string(), false, args, input_schema) } Expr::AggregateFunction { diff --git a/datafusion/src/logical_plan/plan.rs b/datafusion/src/logical_plan/plan.rs index 5cb94be405e7..fe1dfb6de990 100644 --- a/datafusion/src/logical_plan/plan.rs +++ b/datafusion/src/logical_plan/plan.rs @@ -92,8 +92,6 @@ pub enum LogicalPlan { // filter_by_expr: Vec, /// Partition by expressions // partition_by_expr: Vec, - /// Order by expressions - // order_by_expr: Vec, /// Window Frame // window_frame: Option, /// The schema description of the window output @@ -306,25 +304,12 @@ impl LogicalPlan { Partitioning::Hash(expr, _) => expr.clone(), _ => vec![], }, - LogicalPlan::Window { - window_expr, - // FIXME implement next - // filter_by_expr, - // FIXME implement next - // partition_by_expr, - // FIXME implement next - // order_by_expr, - .. - } => window_expr.clone(), + LogicalPlan::Window { window_expr, .. } => window_expr.clone(), LogicalPlan::Aggregate { group_expr, aggr_expr, .. - } => { - let mut result = group_expr.clone(); - result.extend(aggr_expr.clone()); - result - } + } => group_expr.iter().chain(aggr_expr.iter()).cloned().collect(), LogicalPlan::Join { on, .. } => { on.iter().flat_map(|(l, r)| vec![col(l), col(r)]).collect() } @@ -698,16 +683,11 @@ impl LogicalPlan { .. } => write!(f, "Filter: {:?}", expr), LogicalPlan::Window { - ref window_expr, - // FIXME implement next - // ref partition_by_expr, - // FIXME implement next - // ref order_by_expr, - .. + ref window_expr, .. } => { write!( f, - "WindowAggr: windowExpr=[{:?}] partitionBy=[], orderBy=[]", + "WindowAggr: windowExpr=[{:?}] partitionBy=[]", window_expr ) } diff --git a/datafusion/src/optimizer/projection_push_down.rs b/datafusion/src/optimizer/projection_push_down.rs index e47832b07f92..f0b364ab9852 100644 --- a/datafusion/src/optimizer/projection_push_down.rs +++ b/datafusion/src/optimizer/projection_push_down.rs @@ -23,6 +23,7 @@ use crate::execution::context::ExecutionProps; use crate::logical_plan::{DFField, DFSchema, DFSchemaRef, LogicalPlan, ToDFSchema}; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; +use crate::sql::utils::find_sort_exprs; use arrow::datatypes::Schema; use arrow::error::Result as ArrowResult; use std::{collections::HashSet, sync::Arc}; @@ -197,29 +198,29 @@ fn optimize_plan( schema, window_expr, input, - // FIXME implement next - // filter_by_expr, - // FIXME implement next - // partition_by_expr, - // FIXME implement next - // order_by_expr, - // FIXME implement next - // window_frame, .. } => { // Gather all columns needed for expressions in this Window let mut new_window_expr = Vec::new(); - window_expr.iter().try_for_each(|expr| { - let name = &expr.name(&schema)?; - if required_columns.contains(name) { - new_window_expr.push(expr.clone()); - new_required_columns.insert(name.clone()); - // add to the new set of required columns - utils::expr_to_column_names(expr, &mut new_required_columns) - } else { - Ok(()) - } - })?; + { + window_expr.iter().try_for_each(|expr| { + let name = &expr.name(&schema)?; + if required_columns.contains(name) { + new_window_expr.push(expr.clone()); + new_required_columns.insert(name.clone()); + // add to the new set of required columns + utils::expr_to_column_names(expr, &mut new_required_columns) + } else { + Ok(()) + } + })?; + } + + // for all the retained window expr, find their sort expressions if any, and retain these + utils::exprlist_to_column_names( + &find_sort_exprs(&new_window_expr), + &mut new_required_columns, + )?; let new_schema = DFSchema::new( schema @@ -232,12 +233,6 @@ fn optimize_plan( Ok(LogicalPlan::Window { window_expr: new_window_expr, - // FIXME implement next - // partition_by_expr: partition_by_expr.clone(), - // FIXME implement next - // order_by_expr: order_by_expr.clone(), - // FIXME implement next - // window_frame: window_frame.clone(), input: Arc::new(optimize_plan( optimizer, &input, diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs index 284ead252ac6..2cb65066feb9 100644 --- a/datafusion/src/optimizer/utils.rs +++ b/datafusion/src/optimizer/utils.rs @@ -36,6 +36,7 @@ use crate::{ const CASE_EXPR_MARKER: &str = "__DATAFUSION_CASE_EXPR__"; const CASE_ELSE_MARKER: &str = "__DATAFUSION_CASE_ELSE__"; +const WINDOW_SORT_MARKER: &str = "__DATAFUSION_WINDOW_SORT__"; /// Recursively walk a list of expression trees, collecting the unique set of column /// names referenced in the expression @@ -190,14 +191,6 @@ pub fn from_plan( }), }, LogicalPlan::Window { - // FIXME implement next - // filter_by_expr, - // FIXME implement next - // partition_by_expr, - // FIXME implement next - // order_by_expr, - // FIXME implement next - // window_frame, window_expr, schema, .. @@ -265,7 +258,13 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result> { Expr::IsNotNull(e) => Ok(vec![e.as_ref().to_owned()]), Expr::ScalarFunction { args, .. } => Ok(args.clone()), Expr::ScalarUDF { args, .. } => Ok(args.clone()), - Expr::WindowFunction { args, .. } => Ok(args.clone()), + Expr::WindowFunction { args, order_by, .. } => { + let mut expr_list: Vec = vec![]; + expr_list.extend(args.clone()); + expr_list.push(lit(WINDOW_SORT_MARKER)); + expr_list.extend(order_by.clone()); + Ok(expr_list) + } Expr::AggregateFunction { args, .. } => Ok(args.clone()), Expr::AggregateUDF { args, .. } => Ok(args.clone()), Expr::Case { @@ -338,10 +337,24 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result { fun: fun.clone(), args: expressions.to_vec(), }), - Expr::WindowFunction { fun, .. } => Ok(Expr::WindowFunction { - fun: fun.clone(), - args: expressions.to_vec(), - }), + Expr::WindowFunction { fun, .. } => { + let index = expressions + .iter() + .position(|expr| { + matches!(expr, Expr::Literal(ScalarValue::Utf8(Some(str))) + if str == WINDOW_SORT_MARKER) + }) + .ok_or_else(|| { + DataFusionError::Internal( + "Ill-formed window function expressions".to_owned(), + ) + })?; + Ok(Expr::WindowFunction { + fun: fun.clone(), + args: expressions[..index].to_vec(), + order_by: expressions[index + 1..].to_vec(), + }) + } Expr::AggregateFunction { fun, distinct, .. } => Ok(Expr::AggregateFunction { fun: fun.clone(), args: expressions.to_vec(), diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index 4971a027ef1e..b77850f9d67f 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -746,13 +746,18 @@ impl DefaultPhysicalPlanner { }; match e { - Expr::WindowFunction { fun, args } => { + Expr::WindowFunction { fun, args, .. } => { let args = args .iter() .map(|e| { self.create_physical_expr(e, physical_input_schema, ctx_state) }) .collect::>>()?; + // if !order_by.is_empty() { + // return Err(DataFusionError::NotImplemented( + // "Window function with order by is not yet implemented".to_owned(), + // )); + // } windows::create_window_expr(fun, &args, physical_input_schema, name) } other => Err(DataFusionError::Internal(format!( diff --git a/datafusion/src/sql/mod.rs b/datafusion/src/sql/mod.rs index 456ad4c2e361..cc8b004505fb 100644 --- a/datafusion/src/sql/mod.rs +++ b/datafusion/src/sql/mod.rs @@ -20,4 +20,4 @@ pub mod parser; pub mod planner; -mod utils; +pub(crate) mod utils; diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index 63499aa1abe2..3b8acc67ccb2 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -54,8 +54,8 @@ use super::{ parser::DFParser, utils::{ can_columns_satisfy_exprs, expand_wildcard, expr_as_column_expr, extract_aliases, - find_aggregate_exprs, find_column_exprs, find_window_exprs, rebase_expr, - resolve_aliases_to_exprs, + find_aggregate_exprs, find_column_exprs, find_window_exprs, + group_window_expr_by_sort_keys, rebase_expr, resolve_aliases_to_exprs, }, }; @@ -628,7 +628,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let (plan, exprs) = if window_func_exprs.is_empty() { (plan, select_exprs_post_aggr) } else { - self.window(&plan, window_func_exprs, &select_exprs_post_aggr)? + self.window(plan, window_func_exprs, &select_exprs_post_aggr)? }; let plan = if select.distinct { @@ -670,13 +670,28 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { /// Wrap a plan in a window fn window( &self, - input: &LogicalPlan, + input: LogicalPlan, window_exprs: Vec, select_exprs: &[Expr], ) -> Result<(LogicalPlan, Vec)> { - let plan = LogicalPlanBuilder::from(input) - .window(window_exprs.clone())? - .build()?; + let mut plan = input; + let mut groups = group_window_expr_by_sort_keys(&window_exprs)?; + // sort by sort_key len descending, so that more deeply sorted plans gets nested further + // down as children; to further minic the behavior of PostgreSQL, we want stable sort + // and a reverse so that tieing sort keys are reversed in order; note that by this rule + // if there's an empty over, it'll be at the top level + groups.sort_by(|(key_a, _), (key_b, _)| key_a.len().cmp(&key_b.len())); + groups.reverse(); + for (sort_keys, exprs) in groups { + if !sort_keys.is_empty() { + let sort_keys: Vec = sort_keys.to_vec(); + plan = LogicalPlanBuilder::from(&plan).sort(sort_keys)?.build()?; + } + let window_exprs: Vec = exprs.into_iter().cloned().collect(); + plan = LogicalPlanBuilder::from(&plan) + .window(window_exprs)? + .build()?; + } let select_exprs = select_exprs .iter() .map(|expr| rebase_expr(expr, &window_exprs, &plan)) @@ -779,21 +794,24 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { return Ok(plan.clone()); } - let input_schema = plan.schema(); - let order_by_rex: Result> = order_by + let order_by_rex = order_by .iter() - .map(|e| { - Ok(Expr::Sort { - expr: Box::new(self.sql_to_rex(&e.expr, &input_schema)?), - // by default asc - asc: e.asc.unwrap_or(true), - // by default nulls first to be consistent with spark - nulls_first: e.nulls_first.unwrap_or(true), - }) - }) - .collect(); + .map(|e| self.order_by_to_sort_expr(e)) + .into_iter() + .collect::>>()?; - LogicalPlanBuilder::from(&plan).sort(order_by_rex?)?.build() + LogicalPlanBuilder::from(&plan).sort(order_by_rex)?.build() + } + + /// convert sql OrderByExpr to Expr::Sort + fn order_by_to_sort_expr(&self, e: &OrderByExpr) -> Result { + Ok(Expr::Sort { + expr: Box::new(self.sql_expr_to_logical_expr(&e.expr)?), + // by default asc + asc: e.asc.unwrap_or(true), + // by default nulls first to be consistent with spark + nulls_first: e.nulls_first.unwrap_or(true), + }) } /// Validate the schema provides all of the columns referenced in the expressions. @@ -982,7 +1000,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { UnaryOperator::Plus => Ok(self.sql_expr_to_logical_expr(expr)?), UnaryOperator::Minus => { match expr.as_ref() { - // optimization: if it's a number literal, we applly the negative operator + // optimization: if it's a number literal, we apply the negative operator // here directly to calculate the new literal. SQLExpr::Value(Value::Number(n,_)) => match n.parse::() { Ok(n) => Ok(lit(-n)), @@ -1091,10 +1109,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // then, window function if let Some(window) = &function.over { - if window.partition_by.is_empty() - && window.order_by.is_empty() - && window.window_frame.is_none() - { + if window.partition_by.is_empty() && window.window_frame.is_none() { + let order_by = window + .order_by + .iter() + .map(|e| self.order_by_to_sort_expr(e)) + .into_iter() + .collect::>>()?; let fun = window_functions::WindowFunction::from_str(&name); if let Ok(window_functions::WindowFunction::AggregateFunction( aggregate_fun, @@ -1106,6 +1127,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ), args: self .aggregate_fn_to_expr(&aggregate_fun, function)?, + order_by, }); } else if let Ok( window_functions::WindowFunction::BuiltInWindowFunction( @@ -1118,6 +1140,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { window_fun, ), args:self.function_args_to_expr(function)?, + order_by }); } } @@ -2702,7 +2725,7 @@ mod tests { let sql = "SELECT order_id, MAX(order_id) OVER () from orders"; let expected = "\ Projection: #order_id, #MAX(order_id)\ - \n WindowAggr: windowExpr=[[MAX(#order_id)]] partitionBy=[], orderBy=[]\ + \n WindowAggr: windowExpr=[[MAX(#order_id)]] partitionBy=[]\ \n TableScan: orders projection=None"; quick_test(sql, expected); } @@ -2712,7 +2735,7 @@ mod tests { let sql = "SELECT order_id oid, MAX(order_id) OVER () max_oid from orders"; let expected = "\ Projection: #order_id AS oid, #MAX(order_id) AS max_oid\ - \n WindowAggr: windowExpr=[[MAX(#order_id)]] partitionBy=[], orderBy=[]\ + \n WindowAggr: windowExpr=[[MAX(#order_id)]] partitionBy=[]\ \n TableScan: orders projection=None"; quick_test(sql, expected); } @@ -2722,7 +2745,7 @@ mod tests { let sql = "SELECT order_id, MAX(qty * 1.1) OVER () from orders"; let expected = "\ Projection: #order_id, #MAX(qty Multiply Float64(1.1))\ - \n WindowAggr: windowExpr=[[MAX(#qty Multiply Float64(1.1))]] partitionBy=[], orderBy=[]\ + \n WindowAggr: windowExpr=[[MAX(#qty Multiply Float64(1.1))]] partitionBy=[]\ \n TableScan: orders projection=None"; quick_test(sql, expected); } @@ -2733,7 +2756,7 @@ mod tests { "SELECT order_id, MAX(qty) OVER (), min(qty) over (), aVg(qty) OVER () from orders"; let expected = "\ Projection: #order_id, #MAX(qty), #MIN(qty), #AVG(qty)\ - \n WindowAggr: windowExpr=[[MAX(#qty), MIN(#qty), AVG(#qty)]] partitionBy=[], orderBy=[]\ + \n WindowAggr: windowExpr=[[MAX(#qty), MIN(#qty), AVG(#qty)]] partitionBy=[]\ \n TableScan: orders projection=None"; quick_test(sql, expected); } @@ -2749,14 +2772,139 @@ mod tests { ); } + /// psql result + /// ``` + /// QUERY PLAN + /// ---------------------------------------------------------------------------------- + /// WindowAgg (cost=137.16..154.66 rows=1000 width=12) + /// -> Sort (cost=137.16..139.66 rows=1000 width=12) + /// Sort Key: order_id + /// -> WindowAgg (cost=69.83..87.33 rows=1000 width=12) + /// -> Sort (cost=69.83..72.33 rows=1000 width=8) + /// Sort Key: order_id DESC + /// -> Seq Scan on orders (cost=0.00..20.00 rows=1000 width=8) + /// ``` + #[test] + fn over_order_by() { + let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; + let expected = "\ + Projection: #order_id, #MAX(qty), #MIN(qty)\ + \n WindowAggr: windowExpr=[[MAX(#qty)]] partitionBy=[]\ + \n Sort: #order_id ASC NULLS FIRST\ + \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\ + \n Sort: #order_id DESC NULLS FIRST\ + \n TableScan: orders projection=None"; + quick_test(sql, expected); + } + + /// psql result + /// ``` + /// QUERY PLAN + /// ----------------------------------------------------------------------------------- + /// WindowAgg (cost=142.16..162.16 rows=1000 width=16) + /// -> Sort (cost=142.16..144.66 rows=1000 width=16) + /// Sort Key: order_id + /// -> WindowAgg (cost=72.33..92.33 rows=1000 width=16) + /// -> Sort (cost=72.33..74.83 rows=1000 width=12) + /// Sort Key: ((order_id + 1)) + /// -> Seq Scan on orders (cost=0.00..22.50 rows=1000 width=12) + /// ``` + #[test] + fn over_order_by_two_sort_keys() { + let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id), MIN(qty) OVER (ORDER BY (order_id + 1)) from orders"; + let expected = "\ + Projection: #order_id, #MAX(qty), #MIN(qty)\ + \n WindowAggr: windowExpr=[[MAX(#qty)]] partitionBy=[]\ + \n Sort: #order_id ASC NULLS FIRST\ + \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\ + \n Sort: #order_id Plus Int64(1) ASC NULLS FIRST\ + \n TableScan: orders projection=None"; + quick_test(sql, expected); + } + + /// psql result + /// ``` + /// QUERY PLAN + /// ---------------------------------------------------------------------------------------- + /// WindowAgg (cost=139.66..172.16 rows=1000 width=24) + /// -> WindowAgg (cost=139.66..159.66 rows=1000 width=16) + /// -> Sort (cost=139.66..142.16 rows=1000 width=12) + /// Sort Key: qty, order_id + /// -> WindowAgg (cost=69.83..89.83 rows=1000 width=12) + /// -> Sort (cost=69.83..72.33 rows=1000 width=8) + /// Sort Key: order_id, qty + /// -> Seq Scan on orders (cost=0.00..20.00 rows=1000 width=8) + /// ``` + #[test] + fn over_order_by_sort_keys_sorting() { + let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY qty, order_id), SUM(qty) OVER (), MIN(qty) OVER (ORDER BY order_id, qty) from orders"; + let expected = "\ + Projection: #order_id, #MAX(qty), #SUM(qty), #MIN(qty)\ + \n WindowAggr: windowExpr=[[SUM(#qty)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MAX(#qty)]] partitionBy=[]\ + \n Sort: #qty ASC NULLS FIRST, #order_id ASC NULLS FIRST\ + \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\ + \n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST\ + \n TableScan: orders projection=None"; + quick_test(sql, expected); + } + + /// psql result + /// ``` + /// QUERY PLAN + /// ---------------------------------------------------------------------------------- + /// WindowAgg (cost=69.83..117.33 rows=1000 width=24) + /// -> WindowAgg (cost=69.83..104.83 rows=1000 width=16) + /// -> WindowAgg (cost=69.83..89.83 rows=1000 width=12) + /// -> Sort (cost=69.83..72.33 rows=1000 width=8) + /// Sort Key: order_id, qty + /// -> Seq Scan on orders (cost=0.00..20.00 rows=1000 width=8) + /// ``` + /// + /// FIXME: for now we are not detecting prefix of sorting keys in order to save one sort exec phase #[test] - fn over_order_by_not_supported() { - let sql = "SELECT order_id, MAX(delivered) OVER (order BY order_id) from orders"; - let err = logical_plan(sql).expect_err("query should have failed"); - assert_eq!( - "NotImplemented(\"Unsupported OVER clause (ORDER BY order_id)\")", - format!("{:?}", err) - ); + fn over_order_by_sort_keys_sorting_prefix_compacting() { + let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id), SUM(qty) OVER (), MIN(qty) OVER (ORDER BY order_id, qty) from orders"; + let expected = "\ + Projection: #order_id, #MAX(qty), #SUM(qty), #MIN(qty)\ + \n WindowAggr: windowExpr=[[SUM(#qty)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MAX(#qty)]] partitionBy=[]\ + \n Sort: #order_id ASC NULLS FIRST\ + \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\ + \n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST\ + \n TableScan: orders projection=None"; + quick_test(sql, expected); + } + + /// psql result + /// ``` + /// QUERY PLAN + /// ---------------------------------------------------------------------------------------- + /// WindowAgg (cost=139.66..172.16 rows=1000 width=24) + /// -> WindowAgg (cost=139.66..159.66 rows=1000 width=16) + /// -> Sort (cost=139.66..142.16 rows=1000 width=12) + /// Sort Key: order_id, qty + /// -> WindowAgg (cost=69.83..89.83 rows=1000 width=12) + /// -> Sort (cost=69.83..72.33 rows=1000 width=8) + /// Sort Key: qty, order_id + /// -> Seq Scan on orders (cost=0.00..20.00 rows=1000 width=8) + /// ``` + /// + /// FIXME: for now we are not detecting prefix of sorting keys in order to re-arrange with global + /// sort + #[test] + fn over_order_by_sort_keys_sorting_global_order_compacting() { + let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY qty, order_id), SUM(qty) OVER (), MIN(qty) OVER (ORDER BY order_id, qty) from orders ORDER BY order_id"; + let expected = "\ + Sort: #order_id ASC NULLS FIRST\ + \n Projection: #order_id, #MAX(qty), #SUM(qty), #MIN(qty)\ + \n WindowAggr: windowExpr=[[SUM(#qty)]] partitionBy=[]\ + \n WindowAggr: windowExpr=[[MAX(#qty)]] partitionBy=[]\ + \n Sort: #qty ASC NULLS FIRST, #order_id ASC NULLS FIRST\ + \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\ + \n Sort: #order_id ASC NULLS FIRST, #qty ASC NULLS FIRST\ + \n TableScan: orders projection=None"; + quick_test(sql, expected); } #[test] diff --git a/datafusion/src/sql/utils.rs b/datafusion/src/sql/utils.rs index 70b9df060839..80a25d04468f 100644 --- a/datafusion/src/sql/utils.rs +++ b/datafusion/src/sql/utils.rs @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +//! SQL Utility Functions + use crate::logical_plan::{DFSchema, Expr, LogicalPlan}; use crate::{ error::{DataFusionError, Result}, @@ -46,6 +48,14 @@ pub(crate) fn find_aggregate_exprs(exprs: &[Expr]) -> Vec { }) } +/// Collect all deeply nested `Expr::Sort`. They are returned in order of occurrence +/// (depth first), with duplicates omitted. +pub(crate) fn find_sort_exprs(exprs: &[Expr]) -> Vec { + find_exprs_in_exprs(exprs, &|nested_expr| { + matches!(nested_expr, Expr::Sort { .. }) + }) +} + /// Collect all deeply nested `Expr::WindowFunction`. They are returned in order of occurrence /// (depth first), with duplicates omitted. pub(crate) fn find_window_exprs(exprs: &[Expr]) -> Vec { @@ -225,12 +235,20 @@ where .collect::>>()?, distinct: *distinct, }), - Expr::WindowFunction { fun, args } => Ok(Expr::WindowFunction { + Expr::WindowFunction { + fun, + args, + order_by, + } => Ok(Expr::WindowFunction { fun: fun.clone(), args: args .iter() .map(|e| clone_with_replacement(e, replacement_fn)) - .collect::>>()?, + .collect::>>()?, + order_by: order_by + .iter() + .map(|e| clone_with_replacement(e, replacement_fn)) + .collect::>>()?, }), Expr::AggregateUDF { fun, args } => Ok(Expr::AggregateUDF { fun: fun.clone(), @@ -389,3 +407,192 @@ pub(crate) fn resolve_aliases_to_exprs( _ => Ok(None), }) } + +/// group a slice of window expression expr by their order by expressions +pub(crate) fn group_window_expr_by_sort_keys( + window_expr: &[Expr], +) -> Result)>> { + let mut result = vec![]; + window_expr.iter().try_for_each(|expr| match expr { + Expr::WindowFunction { order_by, .. } => { + if let Some((_, values)) = result.iter_mut().find( + |group: &&mut (&[Expr], Vec<&Expr>)| matches!(group, (key, _) if key == order_by), + ) { + values.push(expr); + } else { + result.push((order_by, vec![expr])) + } + Ok(()) + } + other => Err(DataFusionError::Internal(format!( + "Impossibly got non-window expr {:?}", + other, + ))), + })?; + Ok(result) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::logical_plan::col; + use crate::physical_plan::aggregates::AggregateFunction; + use crate::physical_plan::window_functions::WindowFunction; + + #[test] + fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> { + let result = group_window_expr_by_sort_keys(&[])?; + let expected: Vec<(&[Expr], Vec<&Expr>)> = vec![]; + assert_eq!(expected, result); + Ok(()) + } + + #[test] + fn test_group_window_expr_by_sort_keys_empty_window() -> Result<()> { + let max1 = Expr::WindowFunction { + fun: WindowFunction::AggregateFunction(AggregateFunction::Max), + args: vec![col("name")], + order_by: vec![], + }; + let max2 = Expr::WindowFunction { + fun: WindowFunction::AggregateFunction(AggregateFunction::Max), + args: vec![col("name")], + order_by: vec![], + }; + let min3 = Expr::WindowFunction { + fun: WindowFunction::AggregateFunction(AggregateFunction::Min), + args: vec![col("name")], + order_by: vec![], + }; + let sum4 = Expr::WindowFunction { + fun: WindowFunction::AggregateFunction(AggregateFunction::Sum), + args: vec![col("age")], + order_by: vec![], + }; + // FIXME use as_ref + let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; + let result = group_window_expr_by_sort_keys(exprs)?; + let key = &[]; + let expected: Vec<(&[Expr], Vec<&Expr>)> = + vec![(key, vec![&max1, &max2, &min3, &sum4])]; + assert_eq!(expected, result); + Ok(()) + } + + #[test] + fn test_group_window_expr_by_sort_keys() -> Result<()> { + let age_asc = Expr::Sort { + expr: Box::new(col("age")), + asc: true, + nulls_first: true, + }; + let name_desc = Expr::Sort { + expr: Box::new(col("name")), + asc: false, + nulls_first: true, + }; + let created_at_desc = Expr::Sort { + expr: Box::new(col("created_at")), + asc: false, + nulls_first: true, + }; + let max1 = Expr::WindowFunction { + fun: WindowFunction::AggregateFunction(AggregateFunction::Max), + args: vec![col("name")], + order_by: vec![age_asc.clone(), name_desc.clone()], + }; + let max2 = Expr::WindowFunction { + fun: WindowFunction::AggregateFunction(AggregateFunction::Max), + args: vec![col("name")], + order_by: vec![], + }; + let min3 = Expr::WindowFunction { + fun: WindowFunction::AggregateFunction(AggregateFunction::Min), + args: vec![col("name")], + order_by: vec![age_asc.clone(), name_desc.clone()], + }; + let sum4 = Expr::WindowFunction { + fun: WindowFunction::AggregateFunction(AggregateFunction::Sum), + args: vec![col("age")], + order_by: vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()], + }; + // FIXME use as_ref + let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; + let result = group_window_expr_by_sort_keys(exprs)?; + + let key1 = &[age_asc.clone(), name_desc.clone()]; + let key2 = &[]; + let key3 = &[name_desc, age_asc, created_at_desc]; + + let expected: Vec<(&[Expr], Vec<&Expr>)> = vec![ + (key1, vec![&max1, &min3]), + (key2, vec![&max2]), + (key3, vec![&sum4]), + ]; + assert_eq!(expected, result); + Ok(()) + } + + #[test] + fn test_find_sort_exprs() -> Result<()> { + let exprs = &[ + Expr::WindowFunction { + fun: WindowFunction::AggregateFunction(AggregateFunction::Max), + args: vec![col("name")], + order_by: vec![ + Expr::Sort { + expr: Box::new(col("age")), + asc: true, + nulls_first: true, + }, + Expr::Sort { + expr: Box::new(col("name")), + asc: false, + nulls_first: true, + }, + ], + }, + Expr::WindowFunction { + fun: WindowFunction::AggregateFunction(AggregateFunction::Sum), + args: vec![col("age")], + order_by: vec![ + Expr::Sort { + expr: Box::new(col("name")), + asc: false, + nulls_first: true, + }, + Expr::Sort { + expr: Box::new(col("age")), + asc: true, + nulls_first: true, + }, + Expr::Sort { + expr: Box::new(col("created_at")), + asc: false, + nulls_first: true, + }, + ], + }, + ]; + let expected = vec![ + Expr::Sort { + expr: Box::new(col("age")), + asc: true, + nulls_first: true, + }, + Expr::Sort { + expr: Box::new(col("name")), + asc: false, + nulls_first: true, + }, + Expr::Sort { + expr: Box::new(col("created_at")), + asc: false, + nulls_first: true, + }, + ]; + let result = find_sort_exprs(exprs); + assert_eq!(expected, result); + Ok(()) + } +}