Skip to content

Commit

Permalink
remove manual recursion and add a nested test case
Browse files Browse the repository at this point in the history
  • Loading branch information
andygrove committed Aug 25, 2022
1 parent bc3e967 commit 6e0f926
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions datafusion/optimizer/src/pre_cast_lit_in_comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,8 @@ impl ExprRewriter for PreCastLitExprRewriter {
// traverse the expr by dfs
match &expr {
Expr::BinaryExpr { left, op, right } => {
// dfs visit the left and right expr
let left = self.mutate(*left.clone())?;
let right = self.mutate(*right.clone())?;
let left = left.as_ref().clone();
let right = right.as_ref().clone();
let left_type = left.get_type(&self.schema);
let right_type = right.get_type(&self.schema);
// can't get the data type, just return the expr
Expand Down Expand Up @@ -309,7 +308,7 @@ mod tests {
}

#[test]
fn nested() {
fn aliased() {
let schema = expr_test_schema();
// c1 < INT64(16) -> c1 < cast(INT32(16))
// the 16 is within the range of MAX(int32) and MIN(int32), we can cast the 16 to int32(16)
Expand All @@ -318,6 +317,20 @@ mod tests {
assert_eq!(optimize_test(expr_lt, &schema), expected);
}

#[test]
fn nested() {
let schema = expr_test_schema();
// c1 < INT64(16) OR c1 > INT64(32) -> c1 < INT32(16) OR c1 > INT32(32)
// the 16 and 32 are within the range of MAX(int32) and MIN(int32), we can cast them to int32
let expr_lt = col("c1")
.lt(lit(ScalarValue::Int64(Some(16))))
.or(col("c1").gt(lit(ScalarValue::Int64(Some(32)))));
let expected = col("c1")
.lt(lit(ScalarValue::Int32(Some(16))))
.or(col("c1").gt(lit(ScalarValue::Int32(Some(32)))));
assert_eq!(optimize_test(expr_lt, &schema), expected);
}

fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr {
let mut expr_rewriter = PreCastLitExprRewriter {
schema: schema.clone(),
Expand Down

0 comments on commit 6e0f926

Please sign in to comment.