From 4e1290a600233661e28952a42d39563c653c286e Mon Sep 17 00:00:00 2001 From: Qianqian <130200611+Sevenannn@users.noreply.github.com> Date: Mon, 30 Dec 2024 18:31:04 -0600 Subject: [PATCH] Customize window frame support for dialect (#70) * Customize window frame support for dialect * fix: ignore frame only when frame implies no frame * Add comments. move the window frame determine logic to dialect method * Update datafusion/sql/src/unparser/dialect.rs * Update test case * fix --------- Co-authored-by: Phillip LeBlanc --- datafusion/sql/src/unparser/dialect.rs | 38 ++++++++++- datafusion/sql/src/unparser/expr.rs | 91 ++++++++++++++++++++------ 2 files changed, 107 insertions(+), 22 deletions(-) diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index c8f8c747160d..3015676ca15f 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -21,7 +21,9 @@ use arrow_schema::TimeUnit; use datafusion_expr::Expr; use regex::Regex; use sqlparser::{ - ast::{self, BinaryOperator, Function, Ident, ObjectName, TimezoneInfo}, + ast::{ + self, BinaryOperator, Function, Ident, ObjectName, TimezoneInfo, WindowFrameBound, + }, keywords::ALL_KEYWORDS, }; @@ -158,6 +160,18 @@ pub trait Dialect: Send + Sync { ) -> Result> { Ok(None) } + + /// Allows the dialect to choose to omit window frame in unparsing + /// based on function name and window frame bound + /// Returns false if specific function name / window frame bound indicates no window frame is needed in unparsing + fn window_func_support_window_frame( + &self, + _func_name: &str, + _start_bound: &WindowFrameBound, + _end_bound: &WindowFrameBound, + ) -> bool { + true + } } /// `IntervalStyle` to use for unparsing @@ -483,6 +497,7 @@ pub struct CustomDialect { supports_column_alias_in_table_alias: bool, requires_derived_table_alias: bool, division_operator: BinaryOperator, + window_func_support_window_frame: bool, } impl Default for CustomDialect { @@ -508,6 +523,7 @@ impl Default for CustomDialect { supports_column_alias_in_table_alias: true, requires_derived_table_alias: false, division_operator: BinaryOperator::Divide, + window_func_support_window_frame: true, } } } @@ -616,6 +632,15 @@ impl Dialect for CustomDialect { fn division_operator(&self) -> BinaryOperator { self.division_operator.clone() } + + fn window_func_support_window_frame( + &self, + _func_name: &str, + _start_bound: &WindowFrameBound, + _end_bound: &WindowFrameBound, + ) -> bool { + self.window_func_support_window_frame + } } /// `CustomDialectBuilder` to build `CustomDialect` using builder pattern @@ -650,6 +675,7 @@ pub struct CustomDialectBuilder { supports_column_alias_in_table_alias: bool, requires_derived_table_alias: bool, division_operator: BinaryOperator, + window_func_support_window_frame: bool, } impl Default for CustomDialectBuilder { @@ -681,6 +707,7 @@ impl CustomDialectBuilder { supports_column_alias_in_table_alias: true, requires_derived_table_alias: false, division_operator: BinaryOperator::Divide, + window_func_support_window_frame: true, } } @@ -704,6 +731,7 @@ impl CustomDialectBuilder { .supports_column_alias_in_table_alias, requires_derived_table_alias: self.requires_derived_table_alias, division_operator: self.division_operator, + window_func_support_window_frame: self.window_func_support_window_frame, } } @@ -825,4 +853,12 @@ impl CustomDialectBuilder { self.division_operator = division_operator; self } + + pub fn with_window_func_support_window_frame( + mut self, + window_func_support_window_frame: bool, + ) -> Self { + self.window_func_support_window_frame = window_func_support_window_frame; + self + } } diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 4bc38b104d34..59498f5aed29 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -215,6 +215,21 @@ impl Unparser<'_> { let start_bound = self.convert_bound(&window_frame.start_bound)?; let end_bound = self.convert_bound(&window_frame.end_bound)?; + + let window_frame = if self.dialect.window_func_support_window_frame( + func_name, + &start_bound, + &end_bound, + ) { + Some(ast::WindowFrame { + units, + start_bound, + end_bound: Some(end_bound), + }) + } else { + None + }; + let over = Some(ast::WindowType::WindowSpec(ast::WindowSpec { window_name: None, partition_by: partition_by @@ -222,11 +237,7 @@ impl Unparser<'_> { .map(|e| self.expr_to_sql_inner(e)) .collect::>>()?, order_by, - window_frame: Some(ast::WindowFrame { - units, - start_bound, - end_bound: Option::from(end_bound), - }), + window_frame, })); Ok(ast::Expr::Function(Function { @@ -1485,6 +1496,7 @@ mod tests { use datafusion_expr::{interval_month_day_nano_lit, ExprFunctionExt}; use datafusion_functions_aggregate::count::count_udaf; use datafusion_functions_aggregate::expr_fn::sum; + use datafusion_functions_window::rank::rank_udwf; use datafusion_functions_window::row_number::row_number_udwf; use crate::unparser::dialect::{ @@ -2196,22 +2208,26 @@ mod tests { #[test] fn test_cast_value_to_binary_expr() { - let tests = [( - Expr::Cast(Cast { - expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some( - "blah".to_string(), - )))), - data_type: DataType::Binary, - }), - "'blah'", - Expr::Cast(Cast { - expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some( - "blah".to_string(), - )))), - data_type: DataType::BinaryView, - }), - "'blah'", - )]; + let tests = [ + ( + Expr::Cast(Cast { + expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some( + "blah".to_string(), + )))), + data_type: DataType::Binary, + }), + "'blah'", + ), + ( + Expr::Cast(Cast { + expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some( + "blah".to_string(), + )))), + data_type: DataType::BinaryView, + }), + "'blah'", + ), + ]; for (value, expected) in tests { let dialect = CustomDialectBuilder::new().build(); let unparser = Unparser::new(&dialect); @@ -2512,4 +2528,37 @@ mod tests { } Ok(()) } + + #[test] + fn test_window_func_support_window_frame() -> Result<()> { + let default_dialect: Arc = + Arc::new(CustomDialectBuilder::new().build()); + + let test_dialect: Arc = Arc::new( + CustomDialectBuilder::new() + .with_window_func_support_window_frame(false) + .build(), + ); + + for (dialect, expected) in [ + ( + default_dialect, + "rank() OVER (ORDER BY a ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)", + ), + (test_dialect, "rank() OVER (ORDER BY a ASC NULLS FIRST)"), + ] { + let unparser = Unparser::new(dialect.as_ref()); + let func = WindowFunctionDefinition::WindowUDF(rank_udwf()); + let mut window_func = WindowFunction::new(func, vec![]); + window_func.order_by = vec![Sort::new(col("a"), true, true)]; + let expr = Expr::WindowFunction(window_func); + let ast = unparser.expr_to_sql(&expr)?; + + let actual = format!("{ast}"); + let expected = format!("{expected}"); + + assert_eq!(actual, expected); + } + Ok(()) + } }