From 5b4460b68fb3eb2caeaecb8696dad8418d6eb245 Mon Sep 17 00:00:00 2001 From: Jackson Newhouse Date: Mon, 29 May 2023 16:43:14 -0700 Subject: [PATCH] sql: Implement CASE statements. --- arroyo-sql-testing/src/lib.rs | 61 +++++++++++ arroyo-sql/src/expressions.rs | 191 ++++++++++++++++++++++++++++++++-- 2 files changed, 246 insertions(+), 6 deletions(-) diff --git a/arroyo-sql-testing/src/lib.rs b/arroyo-sql-testing/src/lib.rs index 8346ef16a..24b1064f4 100644 --- a/arroyo-sql-testing/src/lib.rs +++ b/arroyo-sql-testing/src/lib.rs @@ -1038,4 +1038,65 @@ mod tests { }, String::from("ThoXXs") ); + + // test CASE statements + single_test_codegen!( + "match_case_statement_non_nullable", + "CASE WHEN non_nullable_i32 = 1 THEN 'one' WHEN non_nullable_i32 = 2 THEN 'two' ELSE 'other' END", + arroyo_sql::TestStruct { + non_nullable_i32: 1, + ..Default::default() + }, + String::from("one") + ); + + single_test_codegen!( + "match_case_statement_nullable", + "CASE WHEN nullable_i32 = 1 THEN 'one' WHEN nullable_i32 = 2 THEN 'two' ELSE 'other' END", + arroyo_sql::TestStruct { + nullable_i32: Some(2), + ..Default::default() + }, + String::from("two") + ); + + single_test_codegen!( + "match_case_statement_no_default", + "CASE WHEN non_nullable_i32 = 1 THEN 'one' WHEN non_nullable_i32 = 2 THEN 'two' END", + arroyo_sql::TestStruct { + non_nullable_i32: 3, + ..Default::default() + }, + None + ); + + single_test_codegen!( + "search_case_statement_non_nullable", + "CASE non_nullable_i32 when 1 THEN 'one' WHEN 2 THEN 'two' ELSE 'other' END", + arroyo_sql::TestStruct { + non_nullable_i32: 31, + ..Default::default() + }, + String::from("other") + ); + + single_test_codegen!( + "seach_case_statement_no_default", + "CASE non_nullable_i32 when 1 THEN 'one' WHEN 2 THEN 'two' END", + arroyo_sql::TestStruct { + non_nullable_i32: 2, + ..Default::default() + }, + Some(String::from("two")) + ); + + single_test_codegen!( + "search_case_statement_nullable", + "CASE nullable_i32 when 1 THEN nullable_string WHEN 2 THEN 'two' ELSE 'other' END", + arroyo_sql::TestStruct { + nullable_i32: Some(2), + ..Default::default() + }, + Some(String::from("two")) + ); } diff --git a/arroyo-sql/src/expressions.rs b/arroyo-sql/src/expressions.rs index d077a57d9..11473a32e 100644 --- a/arroyo-sql/src/expressions.rs +++ b/arroyo-sql/src/expressions.rs @@ -41,6 +41,7 @@ pub enum Expression { Json(JsonExpression), RustUdf(RustUdfExpression), WrapType(WrapTypeExpression), + Case(CaseExpression), } impl Expression { @@ -71,6 +72,18 @@ impl Expression { Expression::Json(json_function) => json_function.to_syn_expression(), Expression::RustUdf(t) => t.to_syn_expression(), Expression::WrapType(t) => t.to_syn_expression(), + Expression::Case(case_expression) => case_expression.to_syn_expression(), + } + } + + fn syn_expression_with_nullity(&self, nullity: bool) -> syn::Expr { + let expr = self.to_syn_expression(); + match (self.nullable(), nullity) { + (true, true) | (false, false) => expr, + (false, true) => parse_quote!(Some(#expr)), + (true, false) => unreachable!( + "Should not be possible to have a nullable expression with nullity=false" + ), } } @@ -99,6 +112,7 @@ impl Expression { Expression::Json(json_function) => json_function.return_type(), Expression::RustUdf(t) => t.return_type(), Expression::WrapType(t) => t.return_type(), + Expression::Case(case_statement) => case_statement.return_type(), } } @@ -293,12 +307,31 @@ impl<'a> ExpressionContext<'a> { expr, when_then_expr, else_expr, - }) => bail!( - "case expressions not supported: expr:{:?}, when_then_expr:{:?}, else_expr:{:?}", - expr, - when_then_expr, - else_expr - ), + }) => { + let expr = expr + .as_ref() + .map(|e| Ok(Box::new(self.compile_expr(e)?))) + .transpose()?; + let when_then_expr = when_then_expr + .iter() + .map(|(when, then)| { + Ok(( + Box::new(self.compile_expr(when)?), + Box::new(self.compile_expr(then)?), + )) + }) + .collect::>>()?; + let else_expr = else_expr + .as_ref() + .map(|e| Ok(Box::new(self.compile_expr(e)?))) + .transpose()?; + + Ok(Expression::Case(CaseExpression::new( + expr, + when_then_expr, + else_expr, + ))) + } Expr::Cast(datafusion_expr::Cast { expr, data_type }) => Ok(CastExpression::new( Box::new(self.compile_expr(expr)?), data_type, @@ -2411,3 +2444,149 @@ impl WrapTypeExpression { self.ret_type.clone() } } + +#[derive(Debug, Clone)] +pub enum CaseExpression { + // match a single value to multiple potential matches + Match { + value: Box, + matches: Vec<(Box, Box)>, + default: Option>, + }, + // search for a true expression + When { + condition_pairs: Vec<(Box, Box)>, + default: Option>, + }, +} + +impl CaseExpression { + fn new( + primary_expr: Option>, + when_then_expr: Vec<(Box, Box)>, + else_expr: Option>, + ) -> Self { + match primary_expr { + Some(primary_expr) => Self::Match { + value: primary_expr, + matches: when_then_expr, + default: else_expr, + }, + None => { + // if there is no primary expression, then it's a when expression + Self::When { + condition_pairs: when_then_expr, + default: else_expr, + } + } + } + } + + fn to_syn_expression(&self) -> syn::Expr { + let nullable = self.nullable(); + match self { + CaseExpression::Match { + value, + matches, + default, + } => { + // It's easier to have value always be option and then return default if it is None. + // It is possible to have more efficient code when all of the expressions are + // not nullable and the default is not nullable, but it's not worth the complexity. + let value = value.syn_expression_with_nullity(true); + let if_clauses: Vec = matches + .iter() + .map(|(when_expr, then_expr)| { + let when_expr = when_expr.syn_expression_with_nullity(true); + let then_expr = then_expr.syn_expression_with_nullity(nullable); + parse_quote!(if #when_expr == value { #then_expr }) + }) + .collect(); + let default_expr = default + .as_ref() + .map(|d| d.syn_expression_with_nullity(nullable)) + // this is safe because if default is null the result is nullable. + .unwrap_or_else(|| parse_quote!(None)); + parse_quote!({ + let value = #value; + if value.is_none() { + #default_expr + } else #(#if_clauses else)* { + #default_expr + } + }) + } + CaseExpression::When { + condition_pairs, + default, + } => { + let if_clauses: Vec = condition_pairs + .iter() + .map(|(when_expr, then_expr)| { + let when_expr = when_expr.syn_expression_with_nullity(true); + let then_expr = then_expr.syn_expression_with_nullity(nullable); + parse_quote!(if #when_expr.unwrap_or(false) { #then_expr }) + }) + .collect(); + let default_expr = default + .as_ref() + .map(|d| d.syn_expression_with_nullity(nullable)) + // this is safe because if default is null the result is nullable. + .unwrap_or_else(|| parse_quote!(None)); + parse_quote!({ + #(#if_clauses else)* { + #default_expr + } + }) + } + } + } + + fn nullable(&self) -> bool { + match self { + CaseExpression::Match { + value: _, + matches: pairs, + default, + } + | CaseExpression::When { + condition_pairs: pairs, + default, + } => { + // if there is a nullable default or it is missing, then it is nullable. Otherwise, it is not nullable + match default { + Some(default) => { + if default.nullable() { + true + } else if pairs + .iter() + .any(|(_when_expr, then_expr)| then_expr.nullable()) + { + true + } else { + false + } + } + None => true, + } + } + } + } + + fn return_type(&self) -> TypeDef { + match self { + CaseExpression::Match { + value: _, + matches: pairs, + default: _, + } + | CaseExpression::When { + condition_pairs: pairs, + default: _, + } => { + // guaranteed to have at least one pair. + pairs[0].1.return_type().with_nullity(self.nullable()) + } + } + } +}