Skip to content

Commit 37f0f14

Browse files
authored
Allow adding custom scalar to sql overrides for DuckDB (#68)
1 parent 2b93a30 commit 37f0f14

File tree

1 file changed

+32
-2
lines changed

1 file changed

+32
-2
lines changed

datafusion/sql/src/unparser/dialect.rs

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use std::sync::Arc;
18+
use std::{collections::HashMap, sync::Arc};
1919

2020
use arrow_schema::TimeUnit;
2121
use datafusion_expr::Expr;
@@ -29,6 +29,9 @@ use datafusion_common::Result;
2929

3030
use super::{utils::character_length_to_sql, utils::date_part_to_sql, Unparser};
3131

32+
pub type ScalarFnToSqlHandler =
33+
Box<dyn Fn(&Unparser, &[Expr]) -> Result<Option<ast::Expr>> + Send + Sync>;
34+
3235
/// `Dialect` to use for Unparsing
3336
///
3437
/// The default dialect tries to avoid quoting identifiers unless necessary (e.g. `a` instead of `"a"`)
@@ -286,7 +289,30 @@ impl PostgreSqlDialect {
286289
}
287290
}
288291

289-
pub struct DuckDBDialect {}
292+
#[derive(Default)]
293+
pub struct DuckDBDialect {
294+
custom_scalar_fn_overrides: HashMap<String, ScalarFnToSqlHandler>,
295+
}
296+
297+
impl DuckDBDialect {
298+
#[must_use]
299+
pub fn new() -> Self {
300+
Self {
301+
custom_scalar_fn_overrides: HashMap::new(),
302+
}
303+
}
304+
305+
pub fn with_custom_scalar_overrides(
306+
mut self,
307+
handlers: Vec<(&str, ScalarFnToSqlHandler)>,
308+
) -> Self {
309+
for (func_name, handler) in handlers {
310+
self.custom_scalar_fn_overrides
311+
.insert(func_name.to_string(), handler);
312+
}
313+
self
314+
}
315+
}
290316

291317
impl Dialect for DuckDBDialect {
292318
fn identifier_quote_style(&self, _: &str) -> Option<char> {
@@ -307,6 +333,10 @@ impl Dialect for DuckDBDialect {
307333
func_name: &str,
308334
args: &[Expr],
309335
) -> Result<Option<ast::Expr>> {
336+
if let Some(handler) = self.custom_scalar_fn_overrides.get(func_name) {
337+
return handler(unparser, args);
338+
}
339+
310340
if func_name == "character_length" {
311341
return character_length_to_sql(
312342
unparser,

0 commit comments

Comments
 (0)