Skip to content
7 changes: 4 additions & 3 deletions datafusion/core/tests/sql/explain_analyze.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,8 @@ async fn csv_explain_verbose() {
async fn csv_explain_inlist_verbose() {
let ctx = SessionContext::new();
register_aggregate_csv_by_sql(&ctx).await;
let sql = "EXPLAIN VERBOSE SELECT c1 FROM aggregate_test_100 where c2 in (1,2,4)";
// Inlist len <=3 case will be transformed to OR List so we test with len=4
let sql = "EXPLAIN VERBOSE SELECT c1 FROM aggregate_test_100 where c2 in (1,2,4,5)";
let actual = execute(&ctx, sql).await;

// Optimized by PreCastLitInComparisonExpressions rule
Expand All @@ -368,12 +369,12 @@ async fn csv_explain_inlist_verbose() {
// before optimization (Int64 literals)
assert_contains!(
&actual,
"aggregate_test_100.c2 IN ([Int64(1), Int64(2), Int64(4)])"
"aggregate_test_100.c2 IN ([Int64(1), Int64(2), Int64(4), Int64(5)])"
);
// after optimization (casted to Int8)
assert_contains!(
&actual,
"aggregate_test_100.c2 IN ([Int8(1), Int8(2), Int8(4)])"
"aggregate_test_100.c2 IN ([Int8(1), Int8(2), Int8(4), Int8(5)])"
);
}

Expand Down
1 change: 0 additions & 1 deletion datafusion/optimizer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ pub mod replace_distinct_aggregate;
pub mod scalar_subquery_to_join;
pub mod simplify_expressions;
pub mod single_distinct_to_groupby;
pub mod unwrap_cast_in_comparison;
pub mod utils;

#[cfg(test)]
Expand Down
3 changes: 0 additions & 3 deletions datafusion/optimizer/src/optimizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@ use crate::replace_distinct_aggregate::ReplaceDistinctWithAggregate;
use crate::scalar_subquery_to_join::ScalarSubqueryToJoin;
use crate::simplify_expressions::SimplifyExpressions;
use crate::single_distinct_to_groupby::SingleDistinctToGroupBy;
use crate::unwrap_cast_in_comparison::UnwrapCastInComparison;
use crate::utils::log_plan;

/// `OptimizerRule`s transforms one [`LogicalPlan`] into another which
Expand Down Expand Up @@ -243,7 +242,6 @@ impl Optimizer {
let rules: Vec<Arc<dyn OptimizerRule + Sync + Send>> = vec![
Arc::new(EliminateNestedUnion::new()),
Arc::new(SimplifyExpressions::new()),
Arc::new(UnwrapCastInComparison::new()),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is great -- it will also reduce the number of times the entire plan tree gets walked/massaged

Arc::new(ReplaceDistinctWithAggregate::new()),
Arc::new(EliminateJoin::new()),
Arc::new(DecorrelatePredicateSubquery::new()),
Expand All @@ -266,7 +264,6 @@ impl Optimizer {
// The previous optimizations added expressions and projections,
// that might benefit from the following rules
Arc::new(SimplifyExpressions::new()),
Arc::new(UnwrapCastInComparison::new()),
Arc::new(CommonSubexprEliminate::new()),
Arc::new(EliminateGroupByConstant::new()),
Arc::new(OptimizeProjections::new()),
Expand Down
92 changes: 90 additions & 2 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ use datafusion_common::{
tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
};
use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue};
use datafusion_expr::simplify::ExprSimplifyResult;
use datafusion_expr::{
and, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, Operator, Volatility,
WindowFunctionDefinition,
Expand All @@ -42,14 +41,23 @@ use datafusion_expr::{
expr::{InList, InSubquery, WindowFunction},
utils::{iter_conjunction, iter_conjunction_owned},
};
use datafusion_expr::{simplify::ExprSimplifyResult, Cast, TryCast};
use datafusion_physical_expr::{create_physical_expr, execution_props::ExecutionProps};

use super::inlist_simplifier::ShortenInListSimplifier;
use super::utils::*;
use crate::analyzer::type_coercion::TypeCoercionRewriter;
use crate::simplify_expressions::guarantees::GuaranteeRewriter;
use crate::simplify_expressions::regex::simplify_regex_expr;
use crate::simplify_expressions::unwrap_cast::{
is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary,
is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist,
unwrap_cast_in_comparison_for_binary,
};
use crate::simplify_expressions::SimplifyInfo;
use crate::{
analyzer::type_coercion::TypeCoercionRewriter,
simplify_expressions::unwrap_cast::try_cast_literal_to_type,
};
use indexmap::IndexSet;
use regex::Regex;

Expand Down Expand Up @@ -1742,6 +1750,86 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
}
}

// =======================================
// unwrap_cast_in_comparison
// =======================================
//
// For case:
// try_cast/cast(expr as data_type) op literal
Expr::BinaryExpr(BinaryExpr { left, op, right })
if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
info, &left, &right,
) && op.supports_propagation() =>
{
unwrap_cast_in_comparison_for_binary(info, left, right, op)?
}
// literal op try_cast/cast(expr as data_type)
// -->
// try_cast/cast(expr as data_type) op_swap literal
Expr::BinaryExpr(BinaryExpr { left, op, right })
if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
info, &right, &left,
) && op.supports_propagation()
&& op.swap().is_some() =>
{
unwrap_cast_in_comparison_for_binary(
info,
right,
left,
op.swap().unwrap(),
)?
}
// For case:
// try_cast/cast(expr as left_type) in (expr1,expr2,expr3)
Expr::InList(InList {
expr: mut left,
list,
negated,
}) if is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist(
info, &left, &list,
) =>
{
let (Expr::TryCast(TryCast {
expr: left_expr, ..
})
| Expr::Cast(Cast {
expr: left_expr, ..
})) = left.as_mut()
else {
return internal_err!("Expect cast expr, but got {:?}", left)?;
};

let expr_type = info.get_data_type(left_expr)?;
let right_exprs = list
.into_iter()
.map(|right| {
match right {
Expr::Literal(right_lit_value) => {
// if the right_lit_value can be casted to the type of internal_left_expr
// we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
let Some(value) = try_cast_literal_to_type(&right_lit_value, &expr_type) else {
internal_err!(
"Can't cast the list expr {:?} to type {:?}",
right_lit_value, &expr_type
)?
};
Ok(lit(value))
}
other_expr => internal_err!(
"Only support literal expr to optimize, but the expr is {:?}",
&other_expr
),
}
})
.collect::<Result<Vec<_>>>()?;

Transformed::yes(Expr::InList(InList {
expr: std::mem::take(left_expr),
list: right_exprs,
negated,
}))
}

// no additional rewrites possible
expr => Transformed::no(expr),
})
Expand Down
1 change: 1 addition & 0 deletions datafusion/optimizer/src/simplify_expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ mod guarantees;
mod inlist_simplifier;
mod regex;
pub mod simplify_exprs;
mod unwrap_cast;
mod utils;

// backwards compatibility
Expand Down
Loading