Skip to content

Commit

Permalink
Use ExprRewriter in pre_cast_lit_in_comparison (#3260)
Browse files Browse the repository at this point in the history
* Use ExprRewriter in pre_cast_lit_in_comparison.rs

* remove manual recursion and add a nested test case
  • Loading branch information
andygrove authored Aug 25, 2022
1 parent 92110dd commit 5ee52d0
Showing 1 changed file with 107 additions and 65 deletions.
172 changes: 107 additions & 65 deletions datafusion/optimizer/src/pre_cast_lit_in_comparison.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::DataType;
use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue};
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion};
use datafusion_expr::utils::from_plan;
use datafusion_expr::{binary_expr, lit, Expr, ExprSchemable, LogicalPlan, Operator};

Expand Down Expand Up @@ -74,79 +75,92 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
.collect::<Result<Vec<_>>>()?;

let schema = plan.schema();

let mut expr_rewriter = PreCastLitExprRewriter {
schema: schema.clone(),
};

let new_exprs = plan
.expressions()
.into_iter()
.map(|expr| visit_expr(expr, schema))
.map(|expr| expr.rewrite(&mut expr_rewriter))
.collect::<Result<Vec<_>>>()?;

from_plan(plan, new_exprs.as_slice(), new_inputs.as_slice())
}

// Visit all type of expr, if the current has child expr, the child expr needed to visit first.
fn visit_expr(expr: Expr, schema: &DFSchemaRef) -> Result<Expr> {
// traverse the expr by dfs
match &expr {
Expr::BinaryExpr { left, op, right } => {
// dfs visit the left and right expr
let left = visit_expr(*left.clone(), schema)?;
let right = visit_expr(*right.clone(), schema)?;
let left_type = left.get_type(schema);
let right_type = right.get_type(schema);
// can't get the data type, just return the expr
if left_type.is_err() || right_type.is_err() {
return Ok(expr.clone());
}
let left_type = left_type.unwrap();
let right_type = right_type.unwrap();
if !left_type.eq(&right_type)
&& is_support_data_type(&left_type)
&& is_support_data_type(&right_type)
&& is_comparison_op(op)
{
match (&left, &right) {
(Expr::Literal(_), Expr::Literal(_)) => {
// do nothing
}
(Expr::Literal(left_lit_value), _)
if can_integer_literal_cast_to_type(
left_lit_value,
&right_type,
)? =>
{
// cast the left literal to the right type
return Ok(binary_expr(
cast_to_other_scalar_expr(left_lit_value, &right_type)?,
*op,
right,
));
}
(_, Expr::Literal(right_lit_value))
if can_integer_literal_cast_to_type(
right_lit_value,
&left_type,
)
.unwrap() =>
{
// cast the right literal to the left type
return Ok(binary_expr(
left,
*op,
cast_to_other_scalar_expr(right_lit_value, &left_type)?,
));
}
(_, _) => {
// do nothing
}
};
struct PreCastLitExprRewriter {
schema: DFSchemaRef,
}

impl ExprRewriter for PreCastLitExprRewriter {
fn pre_visit(&mut self, _expr: &Expr) -> Result<RewriteRecursion> {
Ok(RewriteRecursion::Continue)
}

fn mutate(&mut self, expr: Expr) -> Result<Expr> {
// traverse the expr by dfs
match &expr {
Expr::BinaryExpr { left, op, right } => {
let left = left.as_ref().clone();
let right = right.as_ref().clone();
let left_type = left.get_type(&self.schema);
let right_type = right.get_type(&self.schema);
// can't get the data type, just return the expr
if left_type.is_err() || right_type.is_err() {
return Ok(expr.clone());
}
let left_type = left_type?;
let right_type = right_type?;
if !left_type.eq(&right_type)
&& is_support_data_type(&left_type)
&& is_support_data_type(&right_type)
&& is_comparison_op(op)
{
match (&left, &right) {
(Expr::Literal(_), Expr::Literal(_)) => {
// do nothing
}
(Expr::Literal(left_lit_value), _)
if can_integer_literal_cast_to_type(
left_lit_value,
&right_type,
)? =>
{
// cast the left literal to the right type
return Ok(binary_expr(
cast_to_other_scalar_expr(left_lit_value, &right_type)?,
*op,
right,
));
}
(_, Expr::Literal(right_lit_value))
if can_integer_literal_cast_to_type(
right_lit_value,
&left_type,
)
.unwrap() =>
{
// cast the right literal to the left type
return Ok(binary_expr(
left,
*op,
cast_to_other_scalar_expr(right_lit_value, &left_type)?,
));
}
(_, _) => {
// do nothing
}
};
}
// return the new binary op
Ok(binary_expr(left, *op, right))
}
// return the new binary op
Ok(binary_expr(left, *op, right))
// TODO: optimize in list
// Expr::InList { .. } => {}
// TODO: handle other expr type and dfs visit them
_ => Ok(expr),
}
// TODO: optimize in list
// Expr::InList { .. } => {}
// TODO: handle other expr type and dfs visit them
_ => Ok(expr),
}
}

Expand Down Expand Up @@ -245,9 +259,10 @@ fn can_integer_literal_cast_to_type(

#[cfg(test)]
mod tests {
use crate::pre_cast_lit_in_comparison::visit_expr;
use crate::pre_cast_lit_in_comparison::PreCastLitExprRewriter;
use arrow::datatypes::DataType;
use datafusion_common::{DFField, DFSchema, DFSchemaRef, ScalarValue};
use datafusion_expr::expr_rewriter::ExprRewritable;
use datafusion_expr::{col, lit, Expr};
use std::collections::HashMap;
use std::sync::Arc;
Expand Down Expand Up @@ -292,8 +307,35 @@ mod tests {
assert_eq!(optimize_test(c1_lt_lit_null, &schema), expected);
}

#[test]
fn aliased() {
let schema = expr_test_schema();
// c1 < INT64(16) -> c1 < cast(INT32(16))
// the 16 is within the range of MAX(int32) and MIN(int32), we can cast the 16 to int32(16)
let expr_lt = col("c1").lt(lit(ScalarValue::Int64(Some(16)))).alias("x");
let expected = col("c1").lt(lit(ScalarValue::Int32(Some(16)))).alias("x");
assert_eq!(optimize_test(expr_lt, &schema), expected);
}

#[test]
fn nested() {
let schema = expr_test_schema();
// c1 < INT64(16) OR c1 > INT64(32) -> c1 < INT32(16) OR c1 > INT32(32)
// the 16 and 32 are within the range of MAX(int32) and MIN(int32), we can cast them to int32
let expr_lt = col("c1")
.lt(lit(ScalarValue::Int64(Some(16))))
.or(col("c1").gt(lit(ScalarValue::Int64(Some(32)))));
let expected = col("c1")
.lt(lit(ScalarValue::Int32(Some(16))))
.or(col("c1").gt(lit(ScalarValue::Int32(Some(32)))));
assert_eq!(optimize_test(expr_lt, &schema), expected);
}

fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr {
visit_expr(expr, schema).unwrap()
let mut expr_rewriter = PreCastLitExprRewriter {
schema: schema.clone(),
};
expr.rewrite(&mut expr_rewriter).unwrap()
}

fn expr_test_schema() -> DFSchemaRef {
Expand Down

0 comments on commit 5ee52d0

Please sign in to comment.