Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sql: Implement CASE statements. #146

Merged
merged 2 commits into from
May 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions arroyo-sql-testing/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1038,4 +1038,65 @@ mod tests {
},
String::from("ThoXXs")
);

// test CASE statements
single_test_codegen!(
"match_case_statement_non_nullable",
"CASE WHEN non_nullable_i32 = 1 THEN 'one' WHEN non_nullable_i32 = 2 THEN 'two' ELSE 'other' END",
arroyo_sql::TestStruct {
non_nullable_i32: 1,
..Default::default()
},
String::from("one")
);

single_test_codegen!(
"match_case_statement_nullable",
"CASE WHEN nullable_i32 = 1 THEN 'one' WHEN nullable_i32 = 2 THEN 'two' ELSE 'other' END",
arroyo_sql::TestStruct {
nullable_i32: Some(2),
..Default::default()
},
String::from("two")
);

single_test_codegen!(
"match_case_statement_no_default",
"CASE WHEN non_nullable_i32 = 1 THEN 'one' WHEN non_nullable_i32 = 2 THEN 'two' END",
arroyo_sql::TestStruct {
non_nullable_i32: 3,
..Default::default()
},
None
);

single_test_codegen!(
"search_case_statement_non_nullable",
"CASE non_nullable_i32 when 1 THEN 'one' WHEN 2 THEN 'two' ELSE 'other' END",
arroyo_sql::TestStruct {
non_nullable_i32: 31,
..Default::default()
},
String::from("other")
);

single_test_codegen!(
"seach_case_statement_no_default",
"CASE non_nullable_i32 when 1 THEN 'one' WHEN 2 THEN 'two' END",
arroyo_sql::TestStruct {
non_nullable_i32: 2,
..Default::default()
},
Some(String::from("two"))
);

single_test_codegen!(
"search_case_statement_nullable",
"CASE nullable_i32 when 1 THEN nullable_string WHEN 2 THEN 'two' ELSE 'other' END",
arroyo_sql::TestStruct {
nullable_i32: Some(2),
..Default::default()
},
Some(String::from("two"))
);
}
191 changes: 185 additions & 6 deletions arroyo-sql/src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ pub enum Expression {
Json(JsonExpression),
RustUdf(RustUdfExpression),
WrapType(WrapTypeExpression),
Case(CaseExpression),
}

impl Expression {
Expand Down Expand Up @@ -71,6 +72,18 @@ impl Expression {
Expression::Json(json_function) => json_function.to_syn_expression(),
Expression::RustUdf(t) => t.to_syn_expression(),
Expression::WrapType(t) => t.to_syn_expression(),
Expression::Case(case_expression) => case_expression.to_syn_expression(),
}
}

fn syn_expression_with_nullity(&self, nullity: bool) -> syn::Expr {
let expr = self.to_syn_expression();
match (self.nullable(), nullity) {
(true, true) | (false, false) => expr,
(false, true) => parse_quote!(Some(#expr)),
(true, false) => unreachable!(
"Should not be possible to have a nullable expression with nullity=false"
),
}
}

Expand Down Expand Up @@ -99,6 +112,7 @@ impl Expression {
Expression::Json(json_function) => json_function.return_type(),
Expression::RustUdf(t) => t.return_type(),
Expression::WrapType(t) => t.return_type(),
Expression::Case(case_statement) => case_statement.return_type(),
}
}

Expand Down Expand Up @@ -297,12 +311,31 @@ impl<'a> ExpressionContext<'a> {
expr,
when_then_expr,
else_expr,
}) => bail!(
"case expressions not supported: expr:{:?}, when_then_expr:{:?}, else_expr:{:?}",
expr,
when_then_expr,
else_expr
),
}) => {
let expr = expr
.as_ref()
.map(|e| Ok(Box::new(self.compile_expr(e)?)))
.transpose()?;
let when_then_expr = when_then_expr
.iter()
.map(|(when, then)| {
Ok((
Box::new(self.compile_expr(when)?),
Box::new(self.compile_expr(then)?),
))
})
.collect::<Result<Vec<_>>>()?;
let else_expr = else_expr
.as_ref()
.map(|e| Ok(Box::new(self.compile_expr(e)?)))
.transpose()?;

Ok(Expression::Case(CaseExpression::new(
expr,
when_then_expr,
else_expr,
)))
}
Expr::Cast(datafusion_expr::Cast { expr, data_type }) => Ok(CastExpression::new(
Box::new(self.compile_expr(expr)?),
data_type,
Expand Down Expand Up @@ -2420,3 +2453,149 @@ impl WrapTypeExpression {
self.ret_type.clone()
}
}

#[derive(Debug, Clone)]
pub enum CaseExpression {
// match a single value to multiple potential matches
Match {
value: Box<Expression>,
matches: Vec<(Box<Expression>, Box<Expression>)>,
default: Option<Box<Expression>>,
},
// search for a true expression
When {
condition_pairs: Vec<(Box<Expression>, Box<Expression>)>,
default: Option<Box<Expression>>,
},
}

impl CaseExpression {
fn new(
primary_expr: Option<Box<Expression>>,
when_then_expr: Vec<(Box<Expression>, Box<Expression>)>,
else_expr: Option<Box<Expression>>,
) -> Self {
match primary_expr {
Some(primary_expr) => Self::Match {
value: primary_expr,
matches: when_then_expr,
default: else_expr,
},
None => {
// if there is no primary expression, then it's a when expression
Self::When {
condition_pairs: when_then_expr,
default: else_expr,
}
}
}
}

fn to_syn_expression(&self) -> syn::Expr {
let nullable = self.nullable();
match self {
CaseExpression::Match {
value,
matches,
default,
} => {
// It's easier to have value always be option and then return default if it is None.
// It is possible to have more efficient code when all of the expressions are
// not nullable and the default is not nullable, but it's not worth the complexity.
let value = value.syn_expression_with_nullity(true);
let if_clauses: Vec<syn::ExprIf> = matches
.iter()
.map(|(when_expr, then_expr)| {
let when_expr = when_expr.syn_expression_with_nullity(true);
let then_expr = then_expr.syn_expression_with_nullity(nullable);
parse_quote!(if #when_expr == value { #then_expr })
})
.collect();
let default_expr = default
.as_ref()
.map(|d| d.syn_expression_with_nullity(nullable))
// this is safe because if default is null the result is nullable.
.unwrap_or_else(|| parse_quote!(None));
parse_quote!({
let value = #value;
if value.is_none() {
#default_expr
} else #(#if_clauses else)* {
#default_expr
}
})
}
CaseExpression::When {
condition_pairs,
default,
} => {
let if_clauses: Vec<syn::ExprIf> = condition_pairs
.iter()
.map(|(when_expr, then_expr)| {
let when_expr = when_expr.syn_expression_with_nullity(true);
let then_expr = then_expr.syn_expression_with_nullity(nullable);
parse_quote!(if #when_expr.unwrap_or(false) { #then_expr })
})
.collect();
let default_expr = default
.as_ref()
.map(|d| d.syn_expression_with_nullity(nullable))
// this is safe because if default is null the result is nullable.
.unwrap_or_else(|| parse_quote!(None));
parse_quote!({
#(#if_clauses else)* {
#default_expr
}
})
}
}
}

fn nullable(&self) -> bool {
match self {
CaseExpression::Match {
value: _,
matches: pairs,
default,
}
| CaseExpression::When {
condition_pairs: pairs,
default,
} => {
// if there is a nullable default or it is missing, then it is nullable. Otherwise, it is not nullable
match default {
Some(default) => {
if default.nullable() {
true
} else if pairs
.iter()
.any(|(_when_expr, then_expr)| then_expr.nullable())
{
true
} else {
false
}
}
None => true,
}
}
}
}

fn return_type(&self) -> TypeDef {
match self {
CaseExpression::Match {
value: _,
matches: pairs,
default: _,
}
| CaseExpression::When {
condition_pairs: pairs,
default: _,
} => {
// guaranteed to have at least one pair.
pairs[0].1.return_type().with_nullity(self.nullable())
}
}
}
}