From 023509db1d5175ae266e35ff9a28b17ed6d2fc66 Mon Sep 17 00:00:00 2001 From: Jiayu Liu Date: Thu, 3 Jun 2021 18:17:55 +0800 Subject: [PATCH] closing up type checks --- ballista/rust/core/Cargo.toml | 2 +- ballista/rust/core/proto/ballista.proto | 6 +- .../core/src/serde/logical_plan/from_proto.rs | 35 +-- .../core/src/serde/logical_plan/to_proto.rs | 46 ++-- .../src/serde/physical_plan/from_proto.rs | 4 + datafusion/src/logical_plan/expr.rs | 16 +- datafusion/src/optimizer/utils.rs | 5 +- datafusion/src/physical_plan/mod.rs | 1 + datafusion/src/physical_plan/planner.rs | 9 +- datafusion/src/physical_plan/window_frames.rs | 200 ++++++++++++++++++ datafusion/src/sql/planner.rs | 51 ++++- datafusion/src/sql/utils.rs | 12 ++ 12 files changed, 340 insertions(+), 47 deletions(-) create mode 100644 datafusion/src/physical_plan/window_frames.rs diff --git a/ballista/rust/core/Cargo.toml b/ballista/rust/core/Cargo.toml index 99822cfe2aee..1f23a2a42e2a 100644 --- a/ballista/rust/core/Cargo.toml +++ b/ballista/rust/core/Cargo.toml @@ -35,7 +35,7 @@ futures = "0.3" log = "0.4" prost = "0.7" serde = {version = "1", features = ["derive"]} -sqlparser = "0.8" +sqlparser = "0.9.0" tokio = "1.0" tonic = "0.4" uuid = { version = "0.8", features = ["v4"] } diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index d21cbf694b9d..a2b93bb10abb 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -177,9 +177,9 @@ message WindowExprNode { // repeated LogicalExprNode partition_by = 5; repeated LogicalExprNode order_by = 6; // repeated LogicalExprNode filter = 7; - // oneof window_frame { - // WindowFrame frame = 8; - // } + oneof window_frame { + WindowFrame frame = 8; + } } message BetweenNode { diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index 522d60cb8a05..8812247776b5 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -83,20 +83,6 @@ impl TryInto for &protobuf::LogicalPlanNode { .iter() .map(|expr| expr.try_into()) .collect::, _>>()?; - - // let partition_by_expr = window - // .partition_by_expr - // .iter() - // .map(|expr| expr.try_into()) - // .collect::, _>>()?; - // let order_by_expr = window - // .order_by_expr - // .iter() - // .map(|expr| expr.try_into()) - // .collect::, _>>()?; - // // FIXME: add filter by expr - // // FIXME: parse the window_frame data - // let window_frame = None; LogicalPlanBuilder::from(&input) .window(window_expr)? .build() @@ -928,6 +914,15 @@ impl TryInto for &protobuf::LogicalExprNode { .map(|e| e.try_into()) .into_iter() .collect::, _>>()?; + let window_frame: Option = expr + .window_frame + .as_ref() + .map::, _>(|e| match e { + window_expr_node::WindowFrame::Frame(frame) => { + frame.clone().try_into() + } + }) + .transpose()?; match window_function { window_expr_node::WindowFunction::AggrFunction(i) => { let aggr_function = protobuf::AggregateFunction::from_i32(*i) @@ -944,6 +939,7 @@ impl TryInto for &protobuf::LogicalExprNode { ), args: vec![parse_required_expr(&expr.expr)?], order_by, + window_frame, }) } window_expr_node::WindowFunction::BuiltInFunction(i) => { @@ -963,6 +959,7 @@ impl TryInto for &protobuf::LogicalExprNode { ), args: vec![parse_required_expr(&expr.expr)?], order_by, + window_frame, }) } } @@ -1332,8 +1329,14 @@ impl TryFrom for WindowFrame { ) })? .try_into()?; - // FIXME parse end bound - let end_bound = None; + let end_bound = window + .end_bound + .map(|end_bound| match end_bound { + protobuf::window_frame::EndBound::Bound(end_bound) => { + end_bound.try_into() + } + }) + .transpose()?; Ok(WindowFrame { units, start_bound, diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index 088e93120e4f..1a5bcc94a3a0 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -1006,6 +1006,7 @@ impl TryInto for &Expr { ref fun, ref args, ref order_by, + ref window_frame, .. } => { let window_function = match fun { @@ -1025,10 +1026,15 @@ impl TryInto for &Expr { .iter() .map(|e| e.try_into()) .collect::, _>>()?; + let window_frame = window_frame.as_ref().map(|window_frame| { + let window_frame: protobuf::WindowFrame = window_frame.clone().into(); + protobuf::window_expr_node::WindowFrame::Frame(window_frame) + }); let window_expr = Box::new(protobuf::WindowExprNode { expr: Some(Box::new(arg.try_into()?)), window_function: Some(window_function), order_by, + window_frame, }); Ok(protobuf::LogicalExprNode { expr_type: Some(ExprType::WindowExpr(window_expr)), @@ -1255,23 +1261,35 @@ impl From for protobuf::WindowFrameUnits { } } -impl TryFrom for protobuf::WindowFrameBound { - type Error = BallistaError; - - fn try_from(_bound: WindowFrameBound) -> Result { - Err(BallistaError::NotImplemented( - "WindowFrameBound => protobuf::WindowFrameBound".to_owned(), - )) +impl From for protobuf::WindowFrameBound { + fn from(bound: WindowFrameBound) -> Self { + match bound { + WindowFrameBound::CurrentRow => protobuf::WindowFrameBound { + window_frame_bound_type: protobuf::WindowFrameBoundType::CurrentRow + .into(), + bound_value: None, + }, + WindowFrameBound::Preceding(v) => protobuf::WindowFrameBound { + window_frame_bound_type: protobuf::WindowFrameBoundType::Preceding.into(), + bound_value: v.map(protobuf::window_frame_bound::BoundValue::Value), + }, + WindowFrameBound::Following(v) => protobuf::WindowFrameBound { + window_frame_bound_type: protobuf::WindowFrameBoundType::Following.into(), + bound_value: v.map(protobuf::window_frame_bound::BoundValue::Value), + }, + } } } -impl TryFrom for protobuf::WindowFrame { - type Error = BallistaError; - - fn try_from(_window: WindowFrame) -> Result { - Err(BallistaError::NotImplemented( - "WindowFrame => protobuf::WindowFrame".to_owned(), - )) +impl From for protobuf::WindowFrame { + fn from(window: WindowFrame) -> Self { + protobuf::WindowFrame { + window_frame_units: protobuf::WindowFrameUnits::from(window.units).into(), + start_bound: Some(window.start_bound.into()), + end_bound: window.end_bound.map(|end_bound| { + protobuf::window_frame::EndBound::Bound(end_bound.into()) + }), + } } } diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index c19739a6b061..584e2e251bc1 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -237,6 +237,7 @@ impl TryInto> for &protobuf::PhysicalPlanNode { fun, args, order_by, + window_frame, } => { let arg = df_planner .create_physical_expr( @@ -250,6 +251,9 @@ impl TryInto> for &protobuf::PhysicalPlanNode { if !order_by.is_empty() { return Err(BallistaError::NotImplemented("Window function with order by is not yet implemented".to_owned())); } + if window_frame.is_some() { + return Err(BallistaError::NotImplemented("Window function with window frame is not yet implemented".to_owned())); + } let window_expr = create_window_expr( &fun, &[arg], diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs index 5103d5dc5051..8b0a8a537305 100644 --- a/datafusion/src/logical_plan/expr.rs +++ b/datafusion/src/logical_plan/expr.rs @@ -19,13 +19,6 @@ //! such as `col = 5` or `SUM(col)`. See examples on the [`Expr`] struct. pub use super::Operator; - -use std::fmt; -use std::sync::Arc; - -use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction}; -use arrow::{compute::can_cast_types, datatypes::DataType}; - use crate::error::{DataFusionError, Result}; use crate::logical_plan::{DFField, DFSchema}; use crate::physical_plan::{ @@ -33,8 +26,13 @@ use crate::physical_plan::{ window_functions, }; use crate::{physical_plan::udaf::AggregateUDF, scalar::ScalarValue}; +use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction}; +use arrow::{compute::can_cast_types, datatypes::DataType}; use functions::{ReturnTypeFunction, ScalarFunctionImplementation, Signature}; +use sqlparser::ast::WindowFrame; use std::collections::HashSet; +use std::fmt; +use std::sync::Arc; /// `Expr` is a central struct of DataFusion's query API, and /// represent logical expressions such as `A + 1`, or `CAST(c1 AS @@ -199,6 +197,8 @@ pub enum Expr { args: Vec, /// List of order by expressions order_by: Vec, + /// Window frame + window_frame: Option, }, /// aggregate function AggregateUDF { @@ -735,10 +735,12 @@ impl Expr { args, fun, order_by, + window_frame, } => Expr::WindowFunction { args: rewrite_vec(args, rewriter)?, fun, order_by: rewrite_vec(order_by, rewriter)?, + window_frame, }, Expr::AggregateFunction { args, diff --git a/datafusion/src/optimizer/utils.rs b/datafusion/src/optimizer/utils.rs index 2cb65066feb9..d1e2f73017de 100644 --- a/datafusion/src/optimizer/utils.rs +++ b/datafusion/src/optimizer/utils.rs @@ -337,7 +337,9 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result { fun: fun.clone(), args: expressions.to_vec(), }), - Expr::WindowFunction { fun, .. } => { + Expr::WindowFunction { + fun, window_frame, .. + } => { let index = expressions .iter() .position(|expr| { @@ -353,6 +355,7 @@ pub fn rewrite_expression(expr: &Expr, expressions: &[Expr]) -> Result { fun: fun.clone(), args: expressions[..index].to_vec(), order_by: expressions[index + 1..].to_vec(), + window_frame: window_frame.clone(), }) } Expr::AggregateFunction { fun, distinct, .. } => Ok(Expr::AggregateFunction { diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index af6969c43cbd..490e02875c42 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -617,5 +617,6 @@ pub mod udf; #[cfg(feature = "unicode_expressions")] pub mod unicode_expressions; pub mod union; +pub mod window_frames; pub mod window_functions; pub mod windows; diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index b77850f9d67f..356f9b314fd9 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -39,6 +39,7 @@ use crate::physical_plan::projection::ProjectionExec; use crate::physical_plan::repartition::RepartitionExec; use crate::physical_plan::sort::SortExec; use crate::physical_plan::udf; +use crate::physical_plan::window_frames; use crate::physical_plan::windows::WindowAggExec; use crate::physical_plan::{hash_utils, Partitioning}; use crate::physical_plan::{ @@ -746,7 +747,12 @@ impl DefaultPhysicalPlanner { }; match e { - Expr::WindowFunction { fun, args, .. } => { + Expr::WindowFunction { + fun, + args, + window_frame, + .. + } => { let args = args .iter() .map(|e| { @@ -758,6 +764,7 @@ impl DefaultPhysicalPlanner { // "Window function with order by is not yet implemented".to_owned(), // )); // } + let _window_frame = window_frames::validate_window_frame(window_frame)?; windows::create_window_expr(fun, &args, physical_input_schema, name) } other => Err(DataFusionError::Internal(format!( diff --git a/datafusion/src/physical_plan/window_frames.rs b/datafusion/src/physical_plan/window_frames.rs new file mode 100644 index 000000000000..7262ebf2681e --- /dev/null +++ b/datafusion/src/physical_plan/window_frames.rs @@ -0,0 +1,200 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Window frame +//! +//! The frame-spec determines which output rows are read by an aggregate window function. The frame-spec consists of four parts: +//! - A frame type - either ROWS, RANGE or GROUPS, +//! - A starting frame boundary, +//! - An ending frame boundary, +//! - An EXCLUDE clause. + +use crate::error::{DataFusionError, Result}; +use sqlparser::ast::{WindowFrame, WindowFrameBound, WindowFrameUnits}; + +const DEFAULT_WINDOW_FRAME: WindowFrame = WindowFrame { + units: WindowFrameUnits::Range, + start_bound: WindowFrameBound::Preceding(None), + end_bound: Some(WindowFrameBound::CurrentRow), +}; + +fn get_bound_rank(bound: &WindowFrameBound) -> (u8, u64) { + match bound { + WindowFrameBound::Preceding(None) => (0, 0), + WindowFrameBound::Following(None) => (4, 0), + WindowFrameBound::Preceding(Some(0)) + | WindowFrameBound::CurrentRow + | WindowFrameBound::Following(Some(0)) => (2, 0), + WindowFrameBound::Preceding(Some(v)) => (1, u64::MAX - *v), + WindowFrameBound::Following(Some(v)) => (3, *v), + } +} + +/// Validate a window frame if present, otherwise return the default window frame. +pub(crate) fn validate_window_frame( + window_frame: &Option, +) -> Result<&WindowFrame> { + let window_frame: &WindowFrame = + window_frame.as_ref().unwrap_or(&DEFAULT_WINDOW_FRAME); + let start_bound = &window_frame.start_bound; + let end_bound = window_frame + .end_bound + .as_ref() + .unwrap_or(&WindowFrameBound::CurrentRow); + + if let WindowFrameBound::Following(None) = *start_bound { + Err(DataFusionError::Execution( + "Invalid window frame: start bound cannot be unbounded following".to_owned(), + )) + } else if let WindowFrameBound::Preceding(None) = *end_bound { + Err(DataFusionError::Execution( + "Invalid window frame: end bound cannot be unbounded preceding".to_owned(), + )) + } else if get_bound_rank(start_bound) > get_bound_rank(end_bound) { + Err(DataFusionError::Execution(format!( + "Invalid window frame: start bound ({}) cannot be larger than end bound ({})", + start_bound, end_bound + ))) + } else { + Ok(window_frame) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_validate_window_frame() -> Result<()> { + let default_value = validate_window_frame(&None)?; + assert_eq!(default_value, &DEFAULT_WINDOW_FRAME); + + let window_frame = Some(WindowFrame { + units: WindowFrameUnits::Range, + start_bound: WindowFrameBound::Following(None), + end_bound: None, + }); + let result = validate_window_frame(&window_frame); + assert_eq!( + result.err().unwrap().to_string(), + "Execution error: Invalid window frame: start bound cannot be unbounded following".to_owned() + ); + + let window_frame = Some(WindowFrame { + units: WindowFrameUnits::Range, + start_bound: WindowFrameBound::Preceding(None), + end_bound: Some(WindowFrameBound::Preceding(None)), + }); + let result = validate_window_frame(&window_frame); + assert_eq!( + result.err().unwrap().to_string(), + "Execution error: Invalid window frame: end bound cannot be unbounded preceding".to_owned() + ); + let window_frame = Some(WindowFrame { + units: WindowFrameUnits::Range, + start_bound: WindowFrameBound::Preceding(Some(1)), + end_bound: Some(WindowFrameBound::Preceding(Some(2))), + }); + let result = validate_window_frame(&window_frame); + assert_eq!( + result.err().unwrap().to_string(), + "Execution error: Invalid window frame: start bound (1 PRECEDING) cannot be larger than end bound (2 PRECEDING)".to_owned() + ); + Ok(()) + } + + #[test] + fn test_get_bound_rank_eq() { + assert_eq!( + get_bound_rank(&WindowFrameBound::CurrentRow), + get_bound_rank(&WindowFrameBound::CurrentRow) + ); + assert_eq!( + get_bound_rank(&WindowFrameBound::Preceding(Some(0))), + get_bound_rank(&WindowFrameBound::CurrentRow) + ); + assert_eq!( + get_bound_rank(&WindowFrameBound::CurrentRow), + get_bound_rank(&WindowFrameBound::Following(Some(0))) + ); + assert_eq!( + get_bound_rank(&WindowFrameBound::Following(Some(2))), + get_bound_rank(&WindowFrameBound::Following(Some(2))) + ); + assert_eq!( + get_bound_rank(&WindowFrameBound::Following(None)), + get_bound_rank(&WindowFrameBound::Following(None)) + ); + assert_eq!( + get_bound_rank(&WindowFrameBound::Preceding(Some(2))), + get_bound_rank(&WindowFrameBound::Preceding(Some(2))) + ); + assert_eq!( + get_bound_rank(&WindowFrameBound::Preceding(None)), + get_bound_rank(&WindowFrameBound::Preceding(None)) + ); + } + + #[test] + fn test_get_bound_rank_cmp() { + assert!( + get_bound_rank(&WindowFrameBound::Preceding(Some(1))) + < get_bound_rank(&WindowFrameBound::CurrentRow) + ); + // ! yes this is correct! + assert!( + get_bound_rank(&WindowFrameBound::Preceding(Some(2))) + < get_bound_rank(&WindowFrameBound::Preceding(Some(1))) + ); + assert!( + get_bound_rank(&WindowFrameBound::Preceding(Some(u64::MAX))) + < get_bound_rank(&WindowFrameBound::Preceding(Some(u64::MAX - 1))) + ); + assert!( + get_bound_rank(&WindowFrameBound::Preceding(None)) + < get_bound_rank(&WindowFrameBound::Preceding(Some(1000000))) + ); + assert!( + get_bound_rank(&WindowFrameBound::Preceding(None)) + < get_bound_rank(&WindowFrameBound::Preceding(Some(u64::MAX))) + ); + assert!( + get_bound_rank(&WindowFrameBound::Preceding(None)) + < get_bound_rank(&WindowFrameBound::Following(Some(0))) + ); + assert!( + get_bound_rank(&WindowFrameBound::Preceding(Some(1))) + < get_bound_rank(&WindowFrameBound::Following(Some(1))) + ); + assert!( + get_bound_rank(&WindowFrameBound::CurrentRow) + < get_bound_rank(&WindowFrameBound::Following(Some(1))) + ); + assert!( + get_bound_rank(&WindowFrameBound::Following(Some(1))) + < get_bound_rank(&WindowFrameBound::Following(Some(2))) + ); + assert!( + get_bound_rank(&WindowFrameBound::Following(Some(2))) + < get_bound_rank(&WindowFrameBound::Following(None)) + ); + assert!( + get_bound_rank(&WindowFrameBound::Following(Some(u64::MAX))) + < get_bound_rank(&WindowFrameBound::Following(None)) + ); + } +} diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index 3b8acc67ccb2..e8a94110a43d 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -43,7 +43,7 @@ use sqlparser::ast::{ SetExpr, SetOperator, ShowStatementFilter, TableFactor, TableWithJoins, UnaryOperator, Value, }; -use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption}; +use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption, WindowFrame}; use sqlparser::ast::{OrderByExpr, Statement}; use sqlparser::parser::ParserError::ParserError; use std::str::FromStr; @@ -1109,13 +1109,15 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // then, window function if let Some(window) = &function.over { - if window.partition_by.is_empty() && window.window_frame.is_none() { + if window.partition_by.is_empty() { let order_by = window .order_by .iter() .map(|e| self.order_by_to_sort_expr(e)) .into_iter() .collect::>>()?; + let window_frame: Option = + window.window_frame.clone(); let fun = window_functions::WindowFunction::from_str(&name); if let Ok(window_functions::WindowFunction::AggregateFunction( aggregate_fun, @@ -1128,6 +1130,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { args: self .aggregate_fn_to_expr(&aggregate_fun, function)?, order_by, + window_frame, }); } else if let Ok( window_functions::WindowFunction::BuiltInWindowFunction( @@ -1139,8 +1142,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { fun: window_functions::WindowFunction::BuiltInWindowFunction( window_fun, ), - args:self.function_args_to_expr(function)?, - order_by + args: self.function_args_to_expr(function)?, + order_by, + window_frame, }); } } @@ -2797,6 +2801,45 @@ mod tests { quick_test(sql, expected); } + #[test] + fn over_order_by_with_window_frame_double_end() { + let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id RANGE BETWEEN 3 PRECEDING and 3 FOLLOWING), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; + let expected = "\ + Projection: #order_id, #MAX(qty), #MIN(qty)\ + \n WindowAggr: windowExpr=[[MAX(#qty)]] partitionBy=[]\ + \n Sort: #order_id ASC NULLS FIRST\ + \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\ + \n Sort: #order_id DESC NULLS FIRST\ + \n TableScan: orders projection=None"; + quick_test(sql, expected); + } + + #[test] + fn over_order_by_with_window_frame_single_end() { + let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id RANGE 3 PRECEDING), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; + let expected = "\ + Projection: #order_id, #MAX(qty), #MIN(qty)\ + \n WindowAggr: windowExpr=[[MAX(#qty)]] partitionBy=[]\ + \n Sort: #order_id ASC NULLS FIRST\ + \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\ + \n Sort: #order_id DESC NULLS FIRST\ + \n TableScan: orders projection=None"; + quick_test(sql, expected); + } + + #[test] + fn over_order_by_with_window_frame_single_end_groups() { + let sql = "SELECT order_id, MAX(qty) OVER (ORDER BY order_id GROUPS 3 PRECEDING), MIN(qty) OVER (ORDER BY order_id DESC) from orders"; + let expected = "\ + Projection: #order_id, #MAX(qty), #MIN(qty)\ + \n WindowAggr: windowExpr=[[MAX(#qty)]] partitionBy=[]\ + \n Sort: #order_id ASC NULLS FIRST\ + \n WindowAggr: windowExpr=[[MIN(#qty)]] partitionBy=[]\ + \n Sort: #order_id DESC NULLS FIRST\ + \n TableScan: orders projection=None"; + quick_test(sql, expected); + } + /// psql result /// ``` /// QUERY PLAN diff --git a/datafusion/src/sql/utils.rs b/datafusion/src/sql/utils.rs index 80a25d04468f..043fe7546c48 100644 --- a/datafusion/src/sql/utils.rs +++ b/datafusion/src/sql/utils.rs @@ -239,6 +239,7 @@ where fun, args, order_by, + window_frame, } => Ok(Expr::WindowFunction { fun: fun.clone(), args: args @@ -249,6 +250,7 @@ where .iter() .map(|e| clone_with_replacement(e, replacement_fn)) .collect::>>()?, + window_frame: window_frame.clone(), }), Expr::AggregateUDF { fun, args } => Ok(Expr::AggregateUDF { fun: fun.clone(), @@ -453,21 +455,25 @@ mod tests { fun: WindowFunction::AggregateFunction(AggregateFunction::Max), args: vec![col("name")], order_by: vec![], + window_frame: None, }; let max2 = Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Max), args: vec![col("name")], order_by: vec![], + window_frame: None, }; let min3 = Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Min), args: vec![col("name")], order_by: vec![], + window_frame: None, }; let sum4 = Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Sum), args: vec![col("age")], order_by: vec![], + window_frame: None, }; // FIXME use as_ref let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; @@ -500,21 +506,25 @@ mod tests { fun: WindowFunction::AggregateFunction(AggregateFunction::Max), args: vec![col("name")], order_by: vec![age_asc.clone(), name_desc.clone()], + window_frame: None, }; let max2 = Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Max), args: vec![col("name")], order_by: vec![], + window_frame: None, }; let min3 = Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Min), args: vec![col("name")], order_by: vec![age_asc.clone(), name_desc.clone()], + window_frame: None, }; let sum4 = Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Sum), args: vec![col("age")], order_by: vec![name_desc.clone(), age_asc.clone(), created_at_desc.clone()], + window_frame: None, }; // FIXME use as_ref let exprs = &[max1.clone(), max2.clone(), min3.clone(), sum4.clone()]; @@ -551,6 +561,7 @@ mod tests { nulls_first: true, }, ], + window_frame: None, }, Expr::WindowFunction { fun: WindowFunction::AggregateFunction(AggregateFunction::Sum), @@ -572,6 +583,7 @@ mod tests { nulls_first: true, }, ], + window_frame: None, }, ]; let expected = vec![