Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions src/ast/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,22 @@ pub enum CeilFloorKind {
Scale(Value),
}

/// A WHEN clause in a CASE expression containing both
/// the condition and its corresponding result
#[derive(Debug, Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "visitor", derive(Visit, VisitMut))]
pub struct CaseWhen {
pub condition: Expr,
pub result: Expr,
}

impl fmt::Display for CaseWhen {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "WHEN {} THEN {}", self.condition, self.result)
}
}

/// An SQL expression of any type.
///
/// # Semantics / Type Checking
Expand Down Expand Up @@ -918,8 +934,7 @@ pub enum Expr {
/// <https://jakewheat.github.io/sql-overview/sql-2011-foundation-grammar.html#simple-when-clause>
Case {
operand: Option<Box<Expr>>,
conditions: Vec<Expr>,
results: Vec<Expr>,
conditions: Vec<CaseWhen>,
else_result: Option<Box<Expr>>,
},
/// An exists expression `[ NOT ] EXISTS(SELECT ...)`, used in expressions like
Expand Down Expand Up @@ -1621,17 +1636,15 @@ impl fmt::Display for Expr {
Expr::Case {
operand,
conditions,
results,
else_result,
} => {
write!(f, "CASE")?;
if let Some(operand) = operand {
write!(f, " {operand}")?;
}
for (c, r) in conditions.iter().zip(results) {
write!(f, " WHEN {c} THEN {r}")?;
for when in conditions {
write!(f, " {when}")?;
}

if let Some(else_result) = else_result {
write!(f, " ELSE {else_result}")?;
}
Expand Down
6 changes: 3 additions & 3 deletions src/ast/spans.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1445,15 +1445,15 @@ impl Spanned for Expr {
Expr::Case {
operand,
conditions,
results,
else_result,
} => union_spans(
operand
.as_ref()
.map(|i| i.span())
.into_iter()
.chain(conditions.iter().map(|i| i.span()))
.chain(results.iter().map(|i| i.span()))
.chain(conditions.iter().flat_map(|case_when| {
[case_when.condition.span(), case_when.result.span()]
}))
.chain(else_result.as_ref().map(|i| i.span())),
),
Expr::Exists { subquery, .. } => subquery.span(),
Expand Down
7 changes: 3 additions & 4 deletions src/parser/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2065,11 +2065,11 @@ impl<'a> Parser<'a> {
self.expect_keyword_is(Keyword::WHEN)?;
}
let mut conditions = vec![];
let mut results = vec![];
loop {
conditions.push(self.parse_expr()?);
let condition = self.parse_expr()?;
self.expect_keyword_is(Keyword::THEN)?;
results.push(self.parse_expr()?);
let result = self.parse_expr()?;
conditions.push(CaseWhen { condition, result });
if !self.parse_keyword(Keyword::WHEN) {
break;
}
Expand All @@ -2083,7 +2083,6 @@ impl<'a> Parser<'a> {
Ok(Expr::Case {
operand,
conditions,
results,
else_result,
})
}
Expand Down
107 changes: 70 additions & 37 deletions tests/sqlparser_common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6539,22 +6539,26 @@ fn parse_searched_case_expr() {
&Case {
operand: None,
conditions: vec![
IsNull(Box::new(Identifier(Ident::new("bar")))),
BinaryOp {
left: Box::new(Identifier(Ident::new("bar"))),
op: Eq,
right: Box::new(Expr::Value(number("0"))),
CaseWhen {
condition: IsNull(Box::new(Identifier(Ident::new("bar")))),
result: Expr::Value(Value::SingleQuotedString("null".to_string())),
},
BinaryOp {
left: Box::new(Identifier(Ident::new("bar"))),
op: GtEq,
right: Box::new(Expr::Value(number("0"))),
CaseWhen {
condition: BinaryOp {
left: Box::new(Identifier(Ident::new("bar"))),
op: Eq,
right: Box::new(Expr::Value(number("0"))),
},
result: Expr::Value(Value::SingleQuotedString("=0".to_string())),
},
CaseWhen {
condition: BinaryOp {
left: Box::new(Identifier(Ident::new("bar"))),
op: GtEq,
right: Box::new(Expr::Value(number("0"))),
},
result: Expr::Value(Value::SingleQuotedString(">=0".to_string())),
},
],
results: vec![
Expr::Value(Value::SingleQuotedString("null".to_string())),
Expr::Value(Value::SingleQuotedString("=0".to_string())),
Expr::Value(Value::SingleQuotedString(">=0".to_string())),
],
else_result: Some(Box::new(Expr::Value(Value::SingleQuotedString(
"<0".to_string()
Expand All @@ -6573,8 +6577,10 @@ fn parse_simple_case_expr() {
assert_eq!(
&Case {
operand: Some(Box::new(Identifier(Ident::new("foo")))),
conditions: vec![Expr::Value(number("1"))],
results: vec![Expr::Value(Value::SingleQuotedString("Y".to_string()))],
conditions: vec![CaseWhen {
condition: Expr::Value(number("1")),
result: Expr::Value(Value::SingleQuotedString("Y".to_string())),
}],
else_result: Some(Box::new(Expr::Value(Value::SingleQuotedString(
"N".to_string()
)))),
Expand Down Expand Up @@ -13734,6 +13740,31 @@ fn test_trailing_commas_in_from() {
);
}

#[test]
#[cfg(feature = "visitor")]
fn test_visit_order() {
let sql = "SELECT CASE a WHEN 1 THEN 2 WHEN 3 THEN 4 ELSE 5 END";
let stmt = verified_stmt(sql);
let mut visited = vec![];
sqlparser::ast::visit_expressions(&stmt, |expr| {
visited.push(expr.to_string());
core::ops::ControlFlow::<()>::Continue(())
});

assert_eq!(
visited,
[
"CASE a WHEN 1 THEN 2 WHEN 3 THEN 4 ELSE 5 END",
"a",
"1",
"2",
"3",
"4",
"5"
]
);
}

#[test]
fn test_lambdas() {
let dialects = all_dialects_where(|d| d.supports_lambda_functions());
Expand Down Expand Up @@ -13761,28 +13792,30 @@ fn test_lambdas() {
body: Box::new(Expr::Case {
operand: None,
conditions: vec![
Expr::BinaryOp {
left: Box::new(Expr::Identifier(Ident::new("p1"))),
op: BinaryOperator::Eq,
right: Box::new(Expr::Identifier(Ident::new("p2")))
CaseWhen {
condition: Expr::BinaryOp {
left: Box::new(Expr::Identifier(Ident::new("p1"))),
op: BinaryOperator::Eq,
right: Box::new(Expr::Identifier(Ident::new("p2")))
},
result: Expr::Value(number("0"))
},
Expr::BinaryOp {
left: Box::new(call(
"reverse",
[Expr::Identifier(Ident::new("p1"))]
)),
op: BinaryOperator::Lt,
right: Box::new(call(
"reverse",
[Expr::Identifier(Ident::new("p2"))]
))
}
],
results: vec![
Expr::Value(number("0")),
Expr::UnaryOp {
op: UnaryOperator::Minus,
expr: Box::new(Expr::Value(number("1")))
CaseWhen {
condition: Expr::BinaryOp {
left: Box::new(call(
"reverse",
[Expr::Identifier(Ident::new("p1"))]
)),
op: BinaryOperator::Lt,
right: Box::new(call(
"reverse",
[Expr::Identifier(Ident::new("p2"))]
))
},
result: Expr::UnaryOp {
op: UnaryOperator::Minus,
expr: Box::new(Expr::Value(number("1")))
}
}
],
else_result: Some(Box::new(Expr::Value(number("1"))))
Expand Down
65 changes: 65 additions & 0 deletions tests/sqlparser_databricks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,71 @@ fn test_databricks_exists() {
);
}

#[test]
fn test_databricks_lambdas() {
#[rustfmt::skip]
let sql = concat!(
"SELECT array_sort(array('Hello', 'World'), ",
"(p1, p2) -> CASE WHEN p1 = p2 THEN 0 ",
"WHEN reverse(p1) < reverse(p2) THEN -1 ",
"ELSE 1 END)",
);
pretty_assertions::assert_eq!(
SelectItem::UnnamedExpr(call(
"array_sort",
[
call(
"array",
[
Expr::Value(Value::SingleQuotedString("Hello".to_owned())),
Expr::Value(Value::SingleQuotedString("World".to_owned()))
]
),
Expr::Lambda(LambdaFunction {
params: OneOrManyWithParens::Many(vec![Ident::new("p1"), Ident::new("p2")]),
body: Box::new(Expr::Case {
operand: None,
conditions: vec![
CaseWhen {
condition: Expr::BinaryOp {
left: Box::new(Expr::Identifier(Ident::new("p1"))),
op: BinaryOperator::Eq,
right: Box::new(Expr::Identifier(Ident::new("p2")))
},
result: Expr::Value(number("0"))
},
CaseWhen {
condition: Expr::BinaryOp {
left: Box::new(call(
"reverse",
[Expr::Identifier(Ident::new("p1"))]
)),
op: BinaryOperator::Lt,
right: Box::new(call(
"reverse",
[Expr::Identifier(Ident::new("p2"))]
)),
},
result: Expr::UnaryOp {
op: UnaryOperator::Minus,
expr: Box::new(Expr::Value(number("1")))
}
},
],
else_result: Some(Box::new(Expr::Value(number("1"))))
})
})
]
)),
databricks().verified_only_select(sql).projection[0]
);

databricks().verified_expr(
"map_zip_with(map(1, 'a', 2, 'b'), map(1, 'x', 2, 'y'), (k, v1, v2) -> concat(v1, v2))",
);
databricks().verified_expr("transform(array(1, 2, 3), x -> x + 1)");
}

#[test]
fn test_values_clause() {
let values = Values {
Expand Down