Skip to content

Commit

Permalink
Preserve double colon casts (and simplify cast representations) (apac…
Browse files Browse the repository at this point in the history
  • Loading branch information
jmhain authored and JichaoS committed May 7, 2024
1 parent e5fb1e2 commit d1ee270
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 88 deletions.
90 changes: 45 additions & 45 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,26 @@ impl fmt::Display for MapAccessKey {
}
}

/// The syntax used to write a cast expression.
///
/// Stored in the `kind` field of `Expr::Cast` so that the cast can be
/// displayed with the same syntax it was written in.
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub enum CastKind {
/// The standard SQL cast syntax, e.g. `CAST(<expr> AS <datatype>)`.
Cast,
/// A cast that returns `NULL` on failure, e.g. `TRY_CAST(<expr> AS <datatype>)`.
///
/// See <https://docs.snowflake.com/en/sql-reference/functions/try_cast>.
/// See <https://learn.microsoft.com/en-us/sql/t-sql/functions/try-cast-transact-sql>.
TryCast,
/// A BigQuery-specific cast that returns `NULL` on failure,
/// e.g. `SAFE_CAST(<expr> AS <datatype>)`.
///
/// See <https://cloud.google.com/bigquery/docs/reference/standard-sql/functions-and-operators#safe_casting>.
SafeCast,
/// The double-colon cast syntax (PostgreSQL style), e.g. `<expr>::<datatype>`.
DoubleColon,
}

/// An SQL expression of any type.
///
/// The parser does not distinguish between expressions of different types
Expand Down Expand Up @@ -546,25 +566,7 @@ pub enum Expr {
},
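/// Cast an expression to a different data type, e.g. `CAST(foo AS VARCHAR(123))`, `TRY_CAST(foo AS VARCHAR(123))`, `SAFE_CAST(foo AS FLOAT64)` or `foo::VARCHAR(123)`; the syntax used is recorded in `kind`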
Cast {
expr: Box<Expr>,
data_type: DataType,
// Optional CAST(string_expression AS type FORMAT format_string_expression) as used by BigQuery
// https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#formatting_syntax
format: Option<CastFormat>,
},
/// `TRY_CAST` an expression to a different data type e.g. `TRY_CAST(foo AS VARCHAR(123))`
// this differs from CAST in the choice of how to implement invalid conversions
TryCast {
expr: Box<Expr>,
data_type: DataType,
// Optional CAST(string_expression AS type FORMAT format_string_expression) as used by BigQuery
// https://cloud.google.com/bigquery/docs/reference/standard-sql/format-elements#formatting_syntax
format: Option<CastFormat>,
},
/// `SAFE_CAST` an expression to a different data type e.g. `SAFE_CAST(foo AS FLOAT64)`
// only available for BigQuery: https://cloud.google.com/bigquery/docs/reference/standard-sql/functions-and-operators#safe_casting
// this works the same as `TRY_CAST`
SafeCast {
kind: CastKind,
expr: Box<Expr>,
data_type: DataType,
// Optional CAST(string_expression AS type FORMAT format_string_expression) as used by BigQuery
Expand Down Expand Up @@ -989,38 +991,36 @@ impl fmt::Display for Expr {
write!(f, ")")
}
Expr::Cast {
kind,
expr,
data_type,
format,
} => {
if let Some(format) = format {
write!(f, "CAST({expr} AS {data_type} FORMAT {format})")
} else {
write!(f, "CAST({expr} AS {data_type})")
} => match kind {
CastKind::Cast => {
if let Some(format) = format {
write!(f, "CAST({expr} AS {data_type} FORMAT {format})")
} else {
write!(f, "CAST({expr} AS {data_type})")
}
}
}
Expr::TryCast {
expr,
data_type,
format,
} => {
if let Some(format) = format {
write!(f, "TRY_CAST({expr} AS {data_type} FORMAT {format})")
} else {
write!(f, "TRY_CAST({expr} AS {data_type})")
CastKind::TryCast => {
if let Some(format) = format {
write!(f, "TRY_CAST({expr} AS {data_type} FORMAT {format})")
} else {
write!(f, "TRY_CAST({expr} AS {data_type})")
}
}
}
Expr::SafeCast {
expr,
data_type,
format,
} => {
if let Some(format) = format {
write!(f, "SAFE_CAST({expr} AS {data_type} FORMAT {format})")
} else {
write!(f, "SAFE_CAST({expr} AS {data_type})")
CastKind::SafeCast => {
if let Some(format) = format {
write!(f, "SAFE_CAST({expr} AS {data_type} FORMAT {format})")
} else {
write!(f, "SAFE_CAST({expr} AS {data_type})")
}
}
}
CastKind::DoubleColon => {
write!(f, "{expr}::{data_type}")
}
},
Expr::Extract { field, expr } => write!(f, "EXTRACT({field} FROM {expr})"),
Expr::Ceil { expr, field } => {
if field == &DateTimeField::NoDateTime {
Expand Down
47 changes: 12 additions & 35 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1004,9 +1004,9 @@ impl<'a> Parser<'a> {
}
Keyword::CASE => self.parse_case_expr(),
Keyword::CONVERT => self.parse_convert_expr(),
Keyword::CAST => self.parse_cast_expr(),
Keyword::TRY_CAST => self.parse_try_cast_expr(),
Keyword::SAFE_CAST => self.parse_safe_cast_expr(),
Keyword::CAST => self.parse_cast_expr(CastKind::Cast),
Keyword::TRY_CAST => self.parse_cast_expr(CastKind::TryCast),
Keyword::SAFE_CAST => self.parse_cast_expr(CastKind::SafeCast),
Keyword::EXISTS => self.parse_exists_expr(false),
Keyword::EXTRACT => self.parse_extract_expr(),
Keyword::CEIL => self.parse_ceil_floor_expr(true),
Expand Down Expand Up @@ -1491,44 +1491,15 @@ impl<'a> Parser<'a> {
}

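/// Parse a SQL CAST-style function e.g. `CAST(expr AS FLOAT)`, `TRY_CAST(expr AS FLOAT)` or `SAFE_CAST(expr AS FLOAT64)`; `kind` records which syntax was used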
pub fn parse_cast_expr(&mut self) -> Result<Expr, ParserError> {
pub fn parse_cast_expr(&mut self, kind: CastKind) -> Result<Expr, ParserError> {
self.expect_token(&Token::LParen)?;
let expr = self.parse_expr()?;
self.expect_keyword(Keyword::AS)?;
let data_type = self.parse_data_type()?;
let format = self.parse_optional_cast_format()?;
self.expect_token(&Token::RParen)?;
Ok(Expr::Cast {
expr: Box::new(expr),
data_type,
format,
})
}

/// Parse a SQL TRY_CAST function e.g. `TRY_CAST(expr AS FLOAT)`
pub fn parse_try_cast_expr(&mut self) -> Result<Expr, ParserError> {
self.expect_token(&Token::LParen)?;
let expr = self.parse_expr()?;
self.expect_keyword(Keyword::AS)?;
let data_type = self.parse_data_type()?;
let format = self.parse_optional_cast_format()?;
self.expect_token(&Token::RParen)?;
Ok(Expr::TryCast {
expr: Box::new(expr),
data_type,
format,
})
}

/// Parse a BigQuery SAFE_CAST function e.g. `SAFE_CAST(expr AS FLOAT64)`
pub fn parse_safe_cast_expr(&mut self) -> Result<Expr, ParserError> {
self.expect_token(&Token::LParen)?;
let expr = self.parse_expr()?;
self.expect_keyword(Keyword::AS)?;
let data_type = self.parse_data_type()?;
let format = self.parse_optional_cast_format()?;
self.expect_token(&Token::RParen)?;
Ok(Expr::SafeCast {
kind,
expr: Box::new(expr),
data_type,
format,
Expand Down Expand Up @@ -2528,7 +2499,12 @@ impl<'a> Parser<'a> {
),
}
} else if Token::DoubleColon == tok {
self.parse_pg_cast(expr)
Ok(Expr::Cast {
kind: CastKind::DoubleColon,
expr: Box::new(expr),
data_type: self.parse_data_type()?,
format: None,
})
} else if Token::ExclamationMark == tok {
// PostgreSQL factorial operation
Ok(Expr::UnaryOp {
Expand Down Expand Up @@ -2702,6 +2678,7 @@ impl<'a> Parser<'a> {
/// Parse a postgresql casting style which is in the form of `expr::datatype`
pub fn parse_pg_cast(&mut self, expr: Expr) -> Result<Expr, ParserError> {
Ok(Expr::Cast {
kind: CastKind::DoubleColon,
expr: Box::new(expr),
data_type: self.parse_data_type()?,
format: None,
Expand Down
13 changes: 12 additions & 1 deletion tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2107,6 +2107,7 @@ fn parse_cast() {
let select = verified_only_select(sql);
assert_eq!(
&Expr::Cast {
kind: CastKind::Cast,
expr: Box::new(Expr::Identifier(Ident::new("id"))),
data_type: DataType::BigInt(None),
format: None,
Expand All @@ -2118,6 +2119,7 @@ fn parse_cast() {
let select = verified_only_select(sql);
assert_eq!(
&Expr::Cast {
kind: CastKind::Cast,
expr: Box::new(Expr::Identifier(Ident::new("id"))),
data_type: DataType::TinyInt(None),
format: None,
Expand Down Expand Up @@ -2145,6 +2147,7 @@ fn parse_cast() {
let select = verified_only_select(sql);
assert_eq!(
&Expr::Cast {
kind: CastKind::Cast,
expr: Box::new(Expr::Identifier(Ident::new("id"))),
data_type: DataType::Nvarchar(Some(50)),
format: None,
Expand All @@ -2156,6 +2159,7 @@ fn parse_cast() {
let select = verified_only_select(sql);
assert_eq!(
&Expr::Cast {
kind: CastKind::Cast,
expr: Box::new(Expr::Identifier(Ident::new("id"))),
data_type: DataType::Clob(None),
format: None,
Expand All @@ -2167,6 +2171,7 @@ fn parse_cast() {
let select = verified_only_select(sql);
assert_eq!(
&Expr::Cast {
kind: CastKind::Cast,
expr: Box::new(Expr::Identifier(Ident::new("id"))),
data_type: DataType::Clob(Some(50)),
format: None,
Expand All @@ -2178,6 +2183,7 @@ fn parse_cast() {
let select = verified_only_select(sql);
assert_eq!(
&Expr::Cast {
kind: CastKind::Cast,
expr: Box::new(Expr::Identifier(Ident::new("id"))),
data_type: DataType::Binary(Some(50)),
format: None,
Expand All @@ -2189,6 +2195,7 @@ fn parse_cast() {
let select = verified_only_select(sql);
assert_eq!(
&Expr::Cast {
kind: CastKind::Cast,
expr: Box::new(Expr::Identifier(Ident::new("id"))),
data_type: DataType::Varbinary(Some(50)),
format: None,
Expand All @@ -2200,6 +2207,7 @@ fn parse_cast() {
let select = verified_only_select(sql);
assert_eq!(
&Expr::Cast {
kind: CastKind::Cast,
expr: Box::new(Expr::Identifier(Ident::new("id"))),
data_type: DataType::Blob(None),
format: None,
Expand All @@ -2211,6 +2219,7 @@ fn parse_cast() {
let select = verified_only_select(sql);
assert_eq!(
&Expr::Cast {
kind: CastKind::Cast,
expr: Box::new(Expr::Identifier(Ident::new("id"))),
data_type: DataType::Blob(Some(50)),
format: None,
Expand All @@ -2222,6 +2231,7 @@ fn parse_cast() {
let select = verified_only_select(sql);
assert_eq!(
&Expr::Cast {
kind: CastKind::Cast,
expr: Box::new(Expr::Identifier(Ident::new("details"))),
data_type: DataType::JSONB,
format: None,
Expand All @@ -2235,7 +2245,8 @@ fn parse_try_cast() {
let sql = "SELECT TRY_CAST(id AS BIGINT) FROM customer";
let select = verified_only_select(sql);
assert_eq!(
&Expr::TryCast {
&Expr::Cast {
kind: CastKind::TryCast,
expr: Box::new(Expr::Identifier(Ident::new("id"))),
data_type: DataType::BigInt(None),
format: None,
Expand Down
14 changes: 8 additions & 6 deletions tests/sqlparser_postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ fn parse_create_table_with_defaults() {
location: None,
..
} => {
use pretty_assertions::assert_eq;
assert_eq!("public.customer", name.to_string());
assert_eq!(
columns,
Expand Down Expand Up @@ -422,9 +423,7 @@ fn parse_create_table_with_defaults() {
options: vec![
ColumnOptionDef {
name: None,
option: ColumnOption::Default(
pg().verified_expr("CAST(now() AS TEXT)")
)
option: ColumnOption::Default(pg().verified_expr("now()::TEXT"))
},
ColumnOptionDef {
name: None,
Expand Down Expand Up @@ -498,15 +497,15 @@ fn parse_create_table_from_pg_dump() {
active int
)";
pg().one_statement_parses_to(sql, "CREATE TABLE public.customer (\
customer_id INTEGER DEFAULT nextval(CAST('public.customer_customer_id_seq' AS REGCLASS)) NOT NULL, \
customer_id INTEGER DEFAULT nextval('public.customer_customer_id_seq'::REGCLASS) NOT NULL, \
store_id SMALLINT NOT NULL, \
first_name CHARACTER VARYING(45) NOT NULL, \
last_name CHARACTER VARYING(45) NOT NULL, \
info TEXT[], \
address_id SMALLINT NOT NULL, \
activebool BOOLEAN DEFAULT true NOT NULL, \
create_date DATE DEFAULT CAST(now() AS DATE) NOT NULL, \
create_date1 DATE DEFAULT CAST(CAST('now' AS TEXT) AS DATE) NOT NULL, \
create_date DATE DEFAULT now()::DATE NOT NULL, \
create_date1 DATE DEFAULT 'now'::TEXT::DATE NOT NULL, \
last_update TIMESTAMP WITHOUT TIME ZONE DEFAULT now(), \
release_year public.year, \
active INT\
Expand Down Expand Up @@ -1448,11 +1447,13 @@ fn parse_execute() {
parameters: vec![],
using: vec![
Expr::Cast {
kind: CastKind::Cast,
expr: Box::new(Expr::Value(Value::Number("1337".parse().unwrap(), false))),
data_type: DataType::SmallInt(None),
format: None
},
Expr::Cast {
kind: CastKind::Cast,
expr: Box::new(Expr::Value(Value::Number("7331".parse().unwrap(), false))),
data_type: DataType::SmallInt(None),
format: None
Expand Down Expand Up @@ -1908,6 +1909,7 @@ fn parse_array_index_expr() {
assert_eq!(
&Expr::ArrayIndex {
obj: Box::new(Expr::Nested(Box::new(Expr::Cast {
kind: CastKind::Cast,
expr: Box::new(Expr::Array(Array {
elem: vec![Expr::Array(Array {
elem: vec![num[2].clone(), num[3].clone(),],
Expand Down
3 changes: 2 additions & 1 deletion tests/sqlparser_snowflake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ fn parse_array() {
let select = snowflake().verified_only_select(sql);
assert_eq!(
&Expr::Cast {
kind: CastKind::Cast,
expr: Box::new(Expr::Identifier(Ident::new("a"))),
data_type: DataType::Array(ArrayElemTypeDef::None),
format: None,
Expand Down Expand Up @@ -228,7 +229,7 @@ fn parse_json_using_colon() {
select.projection[0]
);

snowflake().one_statement_parses_to("SELECT a:b::int FROM t", "SELECT CAST(a:b AS INT) FROM t");
snowflake().verified_stmt("SELECT a:b::INT FROM t");

let sql = "SELECT a:start, a:end FROM t";
let select = snowflake().verified_only_select(sql);
Expand Down

0 comments on commit d1ee270

Please sign in to comment.