diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs
index d8a4fb254264..609e6f2240e1 100644
--- a/datafusion/sql/src/unparser/dialect.rs
+++ b/datafusion/sql/src/unparser/dialect.rs
@@ -18,12 +18,17 @@
use std::sync::Arc;
use arrow_schema::TimeUnit;
+use datafusion_expr::Expr;
use regex::Regex;
use sqlparser::{
- ast::{self, Ident, ObjectName, TimezoneInfo},
+ ast::{self, Function, Ident, ObjectName, TimezoneInfo},
keywords::ALL_KEYWORDS,
};
+use datafusion_common::Result;
+
+use super::{utils::date_part_to_sql, Unparser};
+
/// `Dialect` to use for Unparsing
///
/// The default dialect tries to avoid quoting identifiers unless necessary (e.g. `a` instead of `"a"`)
@@ -108,6 +113,18 @@ pub trait Dialect: Send + Sync {
fn supports_column_alias_in_table_alias(&self) -> bool {
true
}
+
+ /// Allows the dialect to override scalar function unparsing if the dialect has specific rules.
+ /// Returns None if the default unparsing should be used, or Some(ast::Expr) if there is
+ /// a custom implementation for the function.
+ fn scalar_function_to_sql_overrides(
+ &self,
+ _unparser: &Unparser,
+ _func_name: &str,
+ _args: &[Expr],
+ ) -> Result> {
+ Ok(None)
+ }
}
/// `IntervalStyle` to use for unparsing
@@ -171,6 +188,67 @@ impl Dialect for PostgreSqlDialect {
fn float64_ast_dtype(&self) -> sqlparser::ast::DataType {
sqlparser::ast::DataType::DoublePrecision
}
+
+ fn scalar_function_to_sql_overrides(
+ &self,
+ unparser: &Unparser,
+ func_name: &str,
+ args: &[Expr],
+ ) -> Result > {
+ if func_name == "round" {
+ return Ok(Some(
+ self.round_to_sql_enforce_numeric(unparser, func_name, args)?,
+ ));
+ }
+
+ Ok(None)
+ }
+}
+
+impl PostgreSqlDialect {
+ fn round_to_sql_enforce_numeric(
+ &self,
+ unparser: &Unparser,
+ func_name: &str,
+ args: &[Expr],
+ ) -> Result {
+ let mut args = unparser.function_args_to_sql(args)?;
+
+ // Enforce the first argument to be Numeric
+ if let Some(ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(expr))) =
+ args.first_mut()
+ {
+ if let ast::Expr::Cast { data_type, .. } = expr {
+ // Don't create an additional cast wrapper if we can update the existing one
+ *data_type = ast::DataType::Numeric(ast::ExactNumberInfo::None);
+ } else {
+ // Wrap the expression in a new cast
+ *expr = ast::Expr::Cast {
+ kind: ast::CastKind::Cast,
+ expr: Box::new(expr.clone()),
+ data_type: ast::DataType::Numeric(ast::ExactNumberInfo::None),
+ format: None,
+ };
+ }
+ }
+
+ Ok(ast::Expr::Function(Function {
+ name: ast::ObjectName(vec![Ident {
+ value: func_name.to_string(),
+ quote_style: None,
+ }]),
+ args: ast::FunctionArguments::List(ast::FunctionArgumentList {
+ duplicate_treatment: None,
+ args,
+ clauses: vec![],
+ }),
+ filter: None,
+ null_treatment: None,
+ over: None,
+ within_group: vec![],
+ parameters: ast::FunctionArguments::None,
+ }))
+ }
}
pub struct MySqlDialect {}
@@ -211,6 +289,19 @@ impl Dialect for MySqlDialect {
) -> ast::DataType {
ast::DataType::Datetime(None)
}
+
+ fn scalar_function_to_sql_overrides(
+ &self,
+ unparser: &Unparser,
+ func_name: &str,
+ args: &[Expr],
+ ) -> Result> {
+ if func_name == "date_part" {
+ return date_part_to_sql(unparser, self.date_field_extract_style(), args);
+ }
+
+ Ok(None)
+ }
}
pub struct SqliteDialect {}
@@ -231,6 +322,19 @@ impl Dialect for SqliteDialect {
fn supports_column_alias_in_table_alias(&self) -> bool {
false
}
+
+ fn scalar_function_to_sql_overrides(
+ &self,
+ unparser: &Unparser,
+ func_name: &str,
+ args: &[Expr],
+ ) -> Result > {
+ if func_name == "date_part" {
+ return date_part_to_sql(unparser, self.date_field_extract_style(), args);
+ }
+
+ Ok(None)
+ }
}
pub struct CustomDialect {
@@ -339,6 +443,19 @@ impl Dialect for CustomDialect {
fn supports_column_alias_in_table_alias(&self) -> bool {
self.supports_column_alias_in_table_alias
}
+
+ fn scalar_function_to_sql_overrides(
+ &self,
+ unparser: &Unparser,
+ func_name: &str,
+ args: &[Expr],
+ ) -> Result > {
+ if func_name == "date_part" {
+ return date_part_to_sql(unparser, self.date_field_extract_style(), args);
+ }
+
+ Ok(None)
+ }
}
/// `CustomDialectBuilder` to build `CustomDialect` using builder pattern
diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs
index b924268a7657..537ac2274424 100644
--- a/datafusion/sql/src/unparser/expr.rs
+++ b/datafusion/sql/src/unparser/expr.rs
@@ -15,16 +15,15 @@
// specific language governing permissions and limitations
// under the License.
-use datafusion_expr::ScalarUDF;
use sqlparser::ast::Value::SingleQuotedString;
use sqlparser::ast::{
- self, BinaryOperator, Expr as AstExpr, Function, FunctionArg, Ident, Interval,
- ObjectName, TimezoneInfo, UnaryOperator,
+ self, BinaryOperator, Expr as AstExpr, Function, Ident, Interval, ObjectName,
+ TimezoneInfo, UnaryOperator,
};
use std::sync::Arc;
use std::vec;
-use super::dialect::{DateFieldExtractStyle, IntervalStyle};
+use super::dialect::IntervalStyle;
use super::Unparser;
use arrow::datatypes::{Decimal128Type, Decimal256Type, DecimalType};
use arrow::util::display::array_value_to_string;
@@ -116,47 +115,14 @@ impl Unparser<'_> {
Expr::ScalarFunction(ScalarFunction { func, args }) => {
let func_name = func.name();
- if let Some(expr) =
- self.scalar_function_to_sql_overrides(func_name, func, args)
+ if let Some(expr) = self
+ .dialect
+ .scalar_function_to_sql_overrides(self, func_name, args)?
{
return Ok(expr);
}
- let args = args
- .iter()
- .map(|e| {
- if matches!(
- e,
- Expr::Wildcard {
- qualifier: None,
- ..
- }
- ) {
- Ok(FunctionArg::Unnamed(ast::FunctionArgExpr::Wildcard))
- } else {
- self.expr_to_sql_inner(e).map(|e| {
- FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(e))
- })
- }
- })
- .collect::>>()?;
-
- Ok(ast::Expr::Function(Function {
- name: ast::ObjectName(vec![Ident {
- value: func_name.to_string(),
- quote_style: None,
- }]),
- args: ast::FunctionArguments::List(ast::FunctionArgumentList {
- duplicate_treatment: None,
- args,
- clauses: vec![],
- }),
- filter: None,
- null_treatment: None,
- over: None,
- within_group: vec![],
- parameters: ast::FunctionArguments::None,
- }))
+ self.scalar_function_to_sql(func_name, args)
}
Expr::Between(Between {
expr,
@@ -508,6 +474,30 @@ impl Unparser<'_> {
}
}
+ pub fn scalar_function_to_sql(
+ &self,
+ func_name: &str,
+ args: &[Expr],
+ ) -> Result {
+ let args = self.function_args_to_sql(args)?;
+ Ok(ast::Expr::Function(Function {
+ name: ast::ObjectName(vec![Ident {
+ value: func_name.to_string(),
+ quote_style: None,
+ }]),
+ args: ast::FunctionArguments::List(ast::FunctionArgumentList {
+ duplicate_treatment: None,
+ args,
+ clauses: vec![],
+ }),
+ filter: None,
+ null_treatment: None,
+ over: None,
+ within_group: vec![],
+ parameters: ast::FunctionArguments::None,
+ }))
+ }
+
pub fn sort_to_sql(&self, sort: &Sort) -> Result {
let Sort {
expr,
@@ -530,87 +520,6 @@ impl Unparser<'_> {
})
}
- fn scalar_function_to_sql_overrides(
- &self,
- func_name: &str,
- _func: &Arc,
- args: &[Expr],
- ) -> Option {
- if func_name.to_lowercase() == "date_part" {
- match (self.dialect.date_field_extract_style(), args.len()) {
- (DateFieldExtractStyle::Extract, 2) => {
- let date_expr = self.expr_to_sql(&args[1]).ok()?;
-
- if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &args[0] {
- let field = match field.to_lowercase().as_str() {
- "year" => ast::DateTimeField::Year,
- "month" => ast::DateTimeField::Month,
- "day" => ast::DateTimeField::Day,
- "hour" => ast::DateTimeField::Hour,
- "minute" => ast::DateTimeField::Minute,
- "second" => ast::DateTimeField::Second,
- _ => return None,
- };
-
- return Some(ast::Expr::Extract {
- field,
- expr: Box::new(date_expr),
- syntax: ast::ExtractSyntax::From,
- });
- }
- }
- (DateFieldExtractStyle::Strftime, 2) => {
- let column = self.expr_to_sql(&args[1]).ok()?;
-
- if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &args[0] {
- let field = match field.to_lowercase().as_str() {
- "year" => "%Y",
- "month" => "%m",
- "day" => "%d",
- "hour" => "%H",
- "minute" => "%M",
- "second" => "%S",
- _ => return None,
- };
-
- return Some(ast::Expr::Function(ast::Function {
- name: ast::ObjectName(vec![ast::Ident {
- value: "strftime".to_string(),
- quote_style: None,
- }]),
- args: ast::FunctionArguments::List(
- ast::FunctionArgumentList {
- duplicate_treatment: None,
- args: vec![
- ast::FunctionArg::Unnamed(
- ast::FunctionArgExpr::Expr(ast::Expr::Value(
- ast::Value::SingleQuotedString(
- field.to_string(),
- ),
- )),
- ),
- ast::FunctionArg::Unnamed(
- ast::FunctionArgExpr::Expr(column),
- ),
- ],
- clauses: vec![],
- },
- ),
- filter: None,
- null_treatment: None,
- over: None,
- within_group: vec![],
- parameters: ast::FunctionArguments::None,
- }));
- }
- }
- _ => {} // no overrides for DateFieldExtractStyle::DatePart, because it's already a date_part
- }
- }
-
- None
- }
-
fn ast_type_for_date64_in_cast(&self) -> ast::DataType {
if self.dialect.use_timestamp_for_date64() {
ast::DataType::Timestamp(None, ast::TimezoneInfo::None)
@@ -665,7 +574,10 @@ impl Unparser<'_> {
}
}
- fn function_args_to_sql(&self, args: &[Expr]) -> Result> {
+ pub(crate) fn function_args_to_sql(
+ &self,
+ args: &[Expr],
+ ) -> Result> {
args.iter()
.map(|e| {
if matches!(
@@ -1554,7 +1466,10 @@ mod tests {
use datafusion_functions_aggregate::expr_fn::sum;
use datafusion_functions_window::row_number::row_number_udwf;
- use crate::unparser::dialect::{CustomDialect, CustomDialectBuilder};
+ use crate::unparser::dialect::{
+ CustomDialect, CustomDialectBuilder, DateFieldExtractStyle, Dialect,
+ PostgreSqlDialect,
+ };
use super::*;
@@ -2428,4 +2343,39 @@ mod tests {
assert_eq!(actual, expected);
}
}
+
+ #[test]
+ fn test_round_scalar_fn_to_expr() -> Result<()> {
+ let default_dialect: Arc = Arc::new(
+ CustomDialectBuilder::new()
+ .with_identifier_quote_style('"')
+ .build(),
+ );
+ let postgres_dialect: Arc = Arc::new(PostgreSqlDialect {});
+
+ for (dialect, identifier) in
+ [(default_dialect, "DOUBLE"), (postgres_dialect, "NUMERIC")]
+ {
+ let unparser = Unparser::new(dialect.as_ref());
+ let expr = Expr::ScalarFunction(ScalarFunction {
+ func: Arc::new(ScalarUDF::from(
+ datafusion_functions::math::round::RoundFunc::new(),
+ )),
+ args: vec![
+ Expr::Cast(Cast {
+ expr: Box::new(col("a")),
+ data_type: DataType::Float64,
+ }),
+ Expr::Literal(ScalarValue::Int64(Some(2))),
+ ],
+ });
+ let ast = unparser.expr_to_sql(&expr)?;
+
+ let actual = format!("{}", ast);
+ let expected = format!(r#"round(CAST("a" AS {identifier}), 2)"#);
+
+ assert_eq!(actual, expected);
+ }
+ Ok(())
+ }
}
diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs
index 0059aba25738..8b2530a7499b 100644
--- a/datafusion/sql/src/unparser/utils.rs
+++ b/datafusion/sql/src/unparser/utils.rs
@@ -18,11 +18,14 @@
use datafusion_common::{
internal_err,
tree_node::{Transformed, TreeNode},
- Column, DataFusionError, Result,
+ Column, DataFusionError, Result, ScalarValue,
};
use datafusion_expr::{
utils::grouping_set_to_exprlist, Aggregate, Expr, LogicalPlan, Window,
};
+use sqlparser::ast;
+
+use super::{dialect::DateFieldExtractStyle, Unparser};
/// Recursively searches children of [LogicalPlan] to find an Aggregate node if exists
/// prior to encountering a Join, TableScan, or a nested subquery (derived table factor).
@@ -187,3 +190,80 @@ fn find_window_expr<'a>(
.flat_map(|w| w.window_expr.iter())
.find(|expr| expr.schema_name().to_string() == column_name)
}
+
+/// Converts a date_part function to SQL, tailoring it to the supported date field extraction style.
+pub(crate) fn date_part_to_sql(
+ unparser: &Unparser,
+ style: DateFieldExtractStyle,
+ date_part_args: &[Expr],
+) -> Result> {
+ match (style, date_part_args.len()) {
+ (DateFieldExtractStyle::Extract, 2) => {
+ let date_expr = unparser.expr_to_sql(&date_part_args[1])?;
+ if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &date_part_args[0] {
+ let field = match field.to_lowercase().as_str() {
+ "year" => ast::DateTimeField::Year,
+ "month" => ast::DateTimeField::Month,
+ "day" => ast::DateTimeField::Day,
+ "hour" => ast::DateTimeField::Hour,
+ "minute" => ast::DateTimeField::Minute,
+ "second" => ast::DateTimeField::Second,
+ _ => return Ok(None),
+ };
+
+ return Ok(Some(ast::Expr::Extract {
+ field,
+ expr: Box::new(date_expr),
+ syntax: ast::ExtractSyntax::From,
+ }));
+ }
+ }
+ (DateFieldExtractStyle::Strftime, 2) => {
+ let column = unparser.expr_to_sql(&date_part_args[1])?;
+
+ if let Expr::Literal(ScalarValue::Utf8(Some(field))) = &date_part_args[0] {
+ let field = match field.to_lowercase().as_str() {
+ "year" => "%Y",
+ "month" => "%m",
+ "day" => "%d",
+ "hour" => "%H",
+ "minute" => "%M",
+ "second" => "%S",
+ _ => return Ok(None),
+ };
+
+ return Ok(Some(ast::Expr::Function(ast::Function {
+ name: ast::ObjectName(vec![ast::Ident {
+ value: "strftime".to_string(),
+ quote_style: None,
+ }]),
+ args: ast::FunctionArguments::List(ast::FunctionArgumentList {
+ duplicate_treatment: None,
+ args: vec![
+ ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(
+ ast::Expr::Value(ast::Value::SingleQuotedString(
+ field.to_string(),
+ )),
+ )),
+ ast::FunctionArg::Unnamed(ast::FunctionArgExpr::Expr(column)),
+ ],
+ clauses: vec![],
+ }),
+ filter: None,
+ null_treatment: None,
+ over: None,
+ within_group: vec![],
+ parameters: ast::FunctionArguments::None,
+ })));
+ }
+ }
+ (DateFieldExtractStyle::DatePart, _) => {
+ return Ok(Some(
+ unparser.scalar_function_to_sql("date_part", date_part_args)?,
+ ));
+ }
+ _ => {}
+ };
+
+ Ok(None)
+}