Skip to content

Commit

Permalink
Customize window frame support for dialect (#70)
Browse files Browse the repository at this point in the history
* Customize window frame support for dialect

* fix: ignore frame only when frame implies no frame

* Add comments. move the window frame determine logic to dialect method

* Update datafusion/sql/src/unparser/dialect.rs

* Update test case

* fix

---------

Co-authored-by: Phillip LeBlanc <phillip@leblanc.tech>
  • Loading branch information
Sevenannn and phillipleblanc authored Dec 31, 2024
1 parent ffc3ca6 commit 4e1290a
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 22 deletions.
38 changes: 37 additions & 1 deletion datafusion/sql/src/unparser/dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ use arrow_schema::TimeUnit;
use datafusion_expr::Expr;
use regex::Regex;
use sqlparser::{
ast::{self, BinaryOperator, Function, Ident, ObjectName, TimezoneInfo},
ast::{
self, BinaryOperator, Function, Ident, ObjectName, TimezoneInfo, WindowFrameBound,
},
keywords::ALL_KEYWORDS,
};

Expand Down Expand Up @@ -158,6 +160,18 @@ pub trait Dialect: Send + Sync {
) -> Result<Option<ast::Expr>> {
Ok(None)
}

/// Allows the dialect to choose to omit window frame in unparsing
/// based on function name and window frame bound
/// Returns false if specific function name / window frame bound indicates no window frame is needed in unparsing
fn window_func_support_window_frame(
&self,
_func_name: &str,
_start_bound: &WindowFrameBound,
_end_bound: &WindowFrameBound,
) -> bool {
true
}
}

/// `IntervalStyle` to use for unparsing
Expand Down Expand Up @@ -483,6 +497,7 @@ pub struct CustomDialect {
supports_column_alias_in_table_alias: bool,
requires_derived_table_alias: bool,
division_operator: BinaryOperator,
window_func_support_window_frame: bool,
}

impl Default for CustomDialect {
Expand All @@ -508,6 +523,7 @@ impl Default for CustomDialect {
supports_column_alias_in_table_alias: true,
requires_derived_table_alias: false,
division_operator: BinaryOperator::Divide,
window_func_support_window_frame: true,
}
}
}
Expand Down Expand Up @@ -616,6 +632,15 @@ impl Dialect for CustomDialect {
fn division_operator(&self) -> BinaryOperator {
self.division_operator.clone()
}

fn window_func_support_window_frame(
&self,
_func_name: &str,
_start_bound: &WindowFrameBound,
_end_bound: &WindowFrameBound,
) -> bool {
self.window_func_support_window_frame
}
}

/// `CustomDialectBuilder` to build `CustomDialect` using builder pattern
Expand Down Expand Up @@ -650,6 +675,7 @@ pub struct CustomDialectBuilder {
supports_column_alias_in_table_alias: bool,
requires_derived_table_alias: bool,
division_operator: BinaryOperator,
window_func_support_window_frame: bool,
}

impl Default for CustomDialectBuilder {
Expand Down Expand Up @@ -681,6 +707,7 @@ impl CustomDialectBuilder {
supports_column_alias_in_table_alias: true,
requires_derived_table_alias: false,
division_operator: BinaryOperator::Divide,
window_func_support_window_frame: true,
}
}

Expand All @@ -704,6 +731,7 @@ impl CustomDialectBuilder {
.supports_column_alias_in_table_alias,
requires_derived_table_alias: self.requires_derived_table_alias,
division_operator: self.division_operator,
window_func_support_window_frame: self.window_func_support_window_frame,
}
}

Expand Down Expand Up @@ -825,4 +853,12 @@ impl CustomDialectBuilder {
self.division_operator = division_operator;
self
}

pub fn with_window_func_support_window_frame(
mut self,
window_func_support_window_frame: bool,
) -> Self {
self.window_func_support_window_frame = window_func_support_window_frame;
self
}
}
91 changes: 70 additions & 21 deletions datafusion/sql/src/unparser/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,18 +215,29 @@ impl Unparser<'_> {

let start_bound = self.convert_bound(&window_frame.start_bound)?;
let end_bound = self.convert_bound(&window_frame.end_bound)?;

let window_frame = if self.dialect.window_func_support_window_frame(
func_name,
&start_bound,
&end_bound,
) {
Some(ast::WindowFrame {
units,
start_bound,
end_bound: Some(end_bound),
})
} else {
None
};

let over = Some(ast::WindowType::WindowSpec(ast::WindowSpec {
window_name: None,
partition_by: partition_by
.iter()
.map(|e| self.expr_to_sql_inner(e))
.collect::<Result<Vec<_>>>()?,
order_by,
window_frame: Some(ast::WindowFrame {
units,
start_bound,
end_bound: Option::from(end_bound),
}),
window_frame,
}));

Ok(ast::Expr::Function(Function {
Expand Down Expand Up @@ -1485,6 +1496,7 @@ mod tests {
use datafusion_expr::{interval_month_day_nano_lit, ExprFunctionExt};
use datafusion_functions_aggregate::count::count_udaf;
use datafusion_functions_aggregate::expr_fn::sum;
use datafusion_functions_window::rank::rank_udwf;
use datafusion_functions_window::row_number::row_number_udwf;

use crate::unparser::dialect::{
Expand Down Expand Up @@ -2196,22 +2208,26 @@ mod tests {

#[test]
fn test_cast_value_to_binary_expr() {
let tests = [(
Expr::Cast(Cast {
expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some(
"blah".to_string(),
)))),
data_type: DataType::Binary,
}),
"'blah'",
Expr::Cast(Cast {
expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some(
"blah".to_string(),
)))),
data_type: DataType::BinaryView,
}),
"'blah'",
)];
let tests = [
(
Expr::Cast(Cast {
expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some(
"blah".to_string(),
)))),
data_type: DataType::Binary,
}),
"'blah'",
),
(
Expr::Cast(Cast {
expr: Box::new(Expr::Literal(ScalarValue::Utf8(Some(
"blah".to_string(),
)))),
data_type: DataType::BinaryView,
}),
"'blah'",
),
];
for (value, expected) in tests {
let dialect = CustomDialectBuilder::new().build();
let unparser = Unparser::new(&dialect);
Expand Down Expand Up @@ -2512,4 +2528,37 @@ mod tests {
}
Ok(())
}

#[test]
fn test_window_func_support_window_frame() -> Result<()> {
let default_dialect: Arc<dyn Dialect> =
Arc::new(CustomDialectBuilder::new().build());

let test_dialect: Arc<dyn Dialect> = Arc::new(
CustomDialectBuilder::new()
.with_window_func_support_window_frame(false)
.build(),
);

for (dialect, expected) in [
(
default_dialect,
"rank() OVER (ORDER BY a ASC NULLS FIRST ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)",
),
(test_dialect, "rank() OVER (ORDER BY a ASC NULLS FIRST)"),
] {
let unparser = Unparser::new(dialect.as_ref());
let func = WindowFunctionDefinition::WindowUDF(rank_udwf());
let mut window_func = WindowFunction::new(func, vec![]);
window_func.order_by = vec![Sort::new(col("a"), true, true)];
let expr = Expr::WindowFunction(window_func);
let ast = unparser.expr_to_sql(&expr)?;

let actual = format!("{ast}");
let expected = format!("{expected}");

assert_eq!(actual, expected);
}
Ok(())
}
}

0 comments on commit 4e1290a

Please sign in to comment.