Skip to content

Commit

Permalink
Support alternate format for Int64 unparsing (SIGNED for MySQL) (#22)
Browse files Browse the repository at this point in the history
  • Loading branch information
sgrebnov authored Jul 19, 2024
1 parent d0943b9 commit 790f9c6
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 4 deletions.
33 changes: 30 additions & 3 deletions datafusion/sql/src/unparser/dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
// under the License.

use regex::Regex;
use sqlparser::keywords::ALL_KEYWORDS;
use sqlparser::{
ast::{self, Ident, ObjectName},
keywords::ALL_KEYWORDS,
};

/// `Dialect` to use for Unparsing
///
Expand Down Expand Up @@ -61,6 +64,12 @@ pub trait Dialect: Send + Sync {
/// Datetime subfield extraction style this dialect uses when unparsing.
///
/// The default implementation returns `DateFieldExtractStyle::DatePart`;
/// dialects override this method to select a different style.
fn date_field_extract_style(&self) -> DateFieldExtractStyle {
DateFieldExtractStyle::DatePart
}

/// The SQL type to use when unparsing Arrow `Int64`.
///
/// Most dialects use `BIGINT`, but some, like MySQL, require `SIGNED`.
/// The default implementation returns `ast::DataType::BigInt(None)`.
fn int64_cast_dtype(&self) -> ast::DataType {
ast::DataType::BigInt(None)
}
}

/// `IntervalStyle` to use for unparsing
Expand All @@ -79,7 +88,7 @@ pub enum IntervalStyle {
}

/// Datetime subfield extraction style for unparsing
///
///
/// https://www.postgresql.org/docs/current/functions-datetime.html#FUNCTIONS-DATETIME-EXTRACT
/// Different DBMSs follow different standards; popular ones are:
/// date_part('YEAR', date '2001-02-16')
Expand All @@ -88,7 +97,7 @@ pub enum IntervalStyle {
#[derive(Clone, Copy, PartialEq)]
pub enum DateFieldExtractStyle {
DatePart,
Extract
Extract,
}

pub struct DefaultDialect {}
Expand Down Expand Up @@ -144,6 +153,10 @@ impl Dialect for MySqlDialect {
/// The MySQL dialect unparses datetime subfield extraction using
/// `DateFieldExtractStyle::Extract`.
fn date_field_extract_style(&self) -> DateFieldExtractStyle {
DateFieldExtractStyle::Extract
}

/// SQL type used when unparsing Arrow `Int64` for this dialect.
///
/// Returns a custom data type whose name is the single identifier `SIGNED`
/// (with no type modifiers).
fn int64_cast_dtype(&self) -> ast::DataType {
    let signed = ObjectName(vec![Ident::new("SIGNED")]);
    ast::DataType::Custom(signed, Vec::new())
}
}

pub struct SqliteDialect {}
Expand All @@ -162,6 +175,7 @@ pub struct CustomDialect {
use_double_precision_for_float64: bool,
use_char_for_utf8_cast: bool,
date_subfield_extract_style: DateFieldExtractStyle,
int64_cast_dtype: ast::DataType,
}

impl Default for CustomDialect {
Expand All @@ -174,6 +188,7 @@ impl Default for CustomDialect {
use_double_precision_for_float64: false,
use_char_for_utf8_cast: false,
date_subfield_extract_style: DateFieldExtractStyle::DatePart,
int64_cast_dtype: ast::DataType::BigInt(None),
}
}
}
Expand Down Expand Up @@ -216,6 +231,10 @@ impl Dialect for CustomDialect {
/// Returns the datetime subfield extraction style stored on this
/// `CustomDialect` (the `date_subfield_extract_style` field).
fn date_field_extract_style(&self) -> DateFieldExtractStyle {
self.date_subfield_extract_style
}

/// Returns the SQL type configured on this `CustomDialect` for unparsing
/// Arrow `Int64`.
fn int64_cast_dtype(&self) -> ast::DataType {
// The method returns an owned value while `self` is only borrowed,
// so the stored type is cloned on each call.
self.int64_cast_dtype.clone()
}
}

// create a CustomDialectBuilder
Expand All @@ -227,6 +246,7 @@ pub struct CustomDialectBuilder {
use_double_precision_for_float64: bool,
use_char_for_utf8_cast: bool,
date_subfield_extract_style: DateFieldExtractStyle,
int64_cast_dtype: ast::DataType,
}

impl CustomDialectBuilder {
Expand All @@ -239,6 +259,7 @@ impl CustomDialectBuilder {
use_double_precision_for_float64: false,
use_char_for_utf8_cast: false,
date_subfield_extract_style: DateFieldExtractStyle::DatePart,
int64_cast_dtype: ast::DataType::BigInt(None),
}
}

Expand All @@ -251,6 +272,7 @@ impl CustomDialectBuilder {
use_double_precision_for_float64: self.use_double_precision_for_float64,
use_char_for_utf8_cast: self.use_char_for_utf8_cast,
date_subfield_extract_style: self.date_subfield_extract_style,
int64_cast_dtype: self.int64_cast_dtype,
}
}

Expand Down Expand Up @@ -300,4 +322,9 @@ impl CustomDialectBuilder {
self.date_subfield_extract_style = date_subfield_extract_style;
self
}

pub fn with_int64_cast_dtype(mut self, int64_cast_dtype: ast::DataType) -> Self {
self.int64_cast_dtype = int64_cast_dtype;
self
}
}
31 changes: 30 additions & 1 deletion datafusion/sql/src/unparser/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1272,7 +1272,7 @@ impl Unparser<'_> {
DataType::Int8 => Ok(ast::DataType::TinyInt(None)),
DataType::Int16 => Ok(ast::DataType::SmallInt(None)),
DataType::Int32 => Ok(ast::DataType::Integer(None)),
DataType::Int64 => Ok(ast::DataType::BigInt(None)),
DataType::Int64 => Ok(self.dialect.int64_cast_dtype()),
DataType::UInt8 => Ok(ast::DataType::UnsignedTinyInt(None)),
DataType::UInt16 => Ok(ast::DataType::UnsignedSmallInt(None)),
DataType::UInt32 => Ok(ast::DataType::UnsignedInteger(None)),
Expand Down Expand Up @@ -1388,6 +1388,7 @@ mod tests {
use arrow::datatypes::TimeUnit;
use arrow::datatypes::{Field, Schema};
use arrow_schema::DataType::Int8;
use ast::ObjectName;
use datafusion_common::TableReference;
use datafusion_expr::{
case, col, cube, exists, grouping_set, interval_datetime_lit,
Expand Down Expand Up @@ -2129,4 +2130,32 @@ mod tests {
}
Ok(())
}

/// Checks that the `Int64` cast type configured through
/// `CustomDialectBuilder::with_int64_cast_dtype` is the type name emitted
/// when an `Int64` cast expression is unparsed, and that an unconfigured
/// builder yields `BIGINT`.
#[test]
fn custom_dialect_with_int64_cast_dtype() -> Result<()> {
    // Custom type consisting of the single identifier `SIGNED`.
    let signed_dtype =
        ast::DataType::Custom(ObjectName(vec![Ident::new("SIGNED")]), vec![]);

    // (dialect under test, type name expected in the generated SQL)
    let cases = [
        (CustomDialectBuilder::new().build(), "BIGINT"),
        (
            CustomDialectBuilder::new()
                .with_int64_cast_dtype(signed_dtype)
                .build(),
            "SIGNED",
        ),
    ];

    for (dialect, identifier) in cases {
        let unparser = Unparser::new(&dialect);
        let cast_expr = Expr::Cast(Cast {
            expr: Box::new(col("a")),
            data_type: DataType::Int64,
        });

        let actual = unparser.expr_to_sql(&cast_expr)?.to_string();
        let expected = format!("CAST(a AS {identifier})");

        assert_eq!(actual, expected);
    }
    Ok(())
}
}

0 comments on commit 790f9c6

Please sign in to comment.