Skip to content

Commit 808d6ab

Browse files
committed
perf: unwrap cast for comparing ints =/!= strings
1 parent 618880e commit 808d6ab

File tree

2 files changed

+68
-2
lines changed

2 files changed

+68
-2
lines changed

datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1758,7 +1758,7 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
17581758
// try_cast/cast(expr as data_type) op literal
17591759
Expr::BinaryExpr(BinaryExpr { left, op, right })
17601760
if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
1761-
info, &left, &right,
1761+
info, &left, &right, op,
17621762
) && op.supports_propagation() =>
17631763
{
17641764
unwrap_cast_in_comparison_for_binary(info, left, right, op)?
@@ -1768,7 +1768,7 @@ impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> {
17681768
// try_cast/cast(expr as data_type) op_swap literal
17691769
Expr::BinaryExpr(BinaryExpr { left, op, right })
17701770
if is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary(
1771-
info, &right, &left,
1771+
info, &right, &left, op,
17721772
) && op.supports_propagation()
17731773
&& op.swap().is_some() =>
17741774
{

datafusion/optimizer/src/simplify_expressions/unwrap_cast.rs

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,16 @@ pub(super) fn unwrap_cast_in_comparison_for_binary<S: SimplifyInfo>(
8181
let Ok(expr_type) = info.get_data_type(&expr) else {
8282
return internal_err!("Can't get the data type of the expr {:?}", &expr);
8383
};
84+
85+
if let Some(value) = cast_literal_to_type_with_op(&lit_value, &expr_type, op)
86+
{
87+
return Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr {
88+
left: expr,
89+
op,
90+
right: Box::new(lit(value)),
91+
})));
92+
};
93+
8494
// if the lit_value can be casted to the type of internal_left_expr
8595
// we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
8696
let Some(value) = try_cast_literal_to_type(&lit_value, &expr_type) else {
@@ -106,6 +116,7 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary<
106116
info: &S,
107117
expr: &Expr,
108118
literal: &Expr,
119+
op: Operator,
109120
) -> bool {
110121
match (expr, literal) {
111122
(
@@ -125,6 +136,10 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary<
125136
return false;
126137
};
127138

139+
if cast_literal_to_type_with_op(lit_val, &expr_type, op).is_some() {
140+
return true;
141+
}
142+
128143
try_cast_literal_to_type(lit_val, &expr_type).is_some()
129144
&& is_supported_type(&expr_type)
130145
&& is_supported_type(&lit_type)
@@ -177,6 +192,33 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist<
177192
true
178193
}
179194

195+
fn cast_literal_to_type_with_op(
196+
lit_value: &ScalarValue,
197+
target_type: &DataType,
198+
op: Operator,
199+
) -> Option<ScalarValue> {
200+
dbg!(lit_value, target_type, op);
201+
match (op, lit_value) {
202+
(
203+
Operator::Eq | Operator::NotEq,
204+
ScalarValue::Utf8(Some(ref str))
205+
| ScalarValue::Utf8View(Some(ref str))
206+
| ScalarValue::LargeUtf8(Some(ref str)),
207+
) => match target_type {
208+
DataType::Int8 => str.parse::<i8>().ok().map(ScalarValue::from),
209+
DataType::Int16 => str.parse::<i16>().ok().map(ScalarValue::from),
210+
DataType::Int32 => str.parse::<i32>().ok().map(ScalarValue::from),
211+
DataType::Int64 => str.parse::<i64>().ok().map(ScalarValue::from),
212+
DataType::UInt8 => str.parse::<u8>().ok().map(ScalarValue::from),
213+
DataType::UInt16 => str.parse::<u16>().ok().map(ScalarValue::from),
214+
DataType::UInt32 => str.parse::<u32>().ok().map(ScalarValue::from),
215+
DataType::UInt64 => str.parse::<u64>().ok().map(ScalarValue::from),
216+
_ => None,
217+
},
218+
_ => None,
219+
}
220+
}
221+
180222
/// Returns true if unwrap_cast_in_comparison supports this data type
181223
fn is_supported_type(data_type: &DataType) -> bool {
182224
is_supported_numeric_type(data_type)
@@ -468,6 +510,10 @@ mod tests {
468510
// the 99999999999 is not within the range of MAX(int32) and MIN(int32), we don't cast the lit(99999999999) to int32 type
469511
let expr_lt = cast(col("c1"), DataType::Int64).lt(lit(99999999999i64));
470512
assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
513+
514+
// cast(c1, UTF8) < 123, only eq/not_eq should be optimized
515+
let expr_lt = cast(col("c1"), DataType::Utf8).lt(lit("123"));
516+
assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
471517
}
472518

473519
#[test]
@@ -496,6 +542,16 @@ mod tests {
496542
let lit_lt_lit = cast(null_i8(), DataType::Int32).lt(lit(12i32));
497543
let expected = null_bool();
498544
assert_eq!(optimize_test(lit_lt_lit, &schema), expected);
545+
546+
// cast(c1, UTF8) = "123" => c1 = 123
547+
let expr_input = cast(col("c1"), DataType::Utf8).eq(lit("123"));
548+
let expected = col("c1").eq(lit(123i32));
549+
assert_eq!(optimize_test(expr_input, &schema), expected);
550+
551+
// cast(c1, UTF8) != "123" => c1 != 123
552+
let expr_input = cast(col("c1"), DataType::Utf8).not_eq(lit("123"));
553+
let expected = col("c1").not_eq(lit(123i32));
554+
assert_eq!(optimize_test(expr_input, &schema), expected);
499555
}
500556

501557
#[test]
@@ -505,6 +561,16 @@ mod tests {
505561
let expr_input = cast(col("c6"), DataType::UInt64).eq(lit(0u64));
506562
let expected = col("c6").eq(lit(0u32));
507563
assert_eq!(optimize_test(expr_input, &schema), expected);
564+
565+
// cast(c6, UTF8) = "123" => c6 = 123
566+
let expr_input = cast(col("c6"), DataType::Utf8).eq(lit("123"));
567+
let expected = col("c6").eq(lit(123u32));
568+
assert_eq!(optimize_test(expr_input, &schema), expected);
569+
570+
// cast(c6, UTF8) != "123" => c6 != 123
571+
let expr_input = cast(col("c6"), DataType::Utf8).not_eq(lit("123"));
572+
let expected = col("c6").not_eq(lit(123u32));
573+
assert_eq!(optimize_test(expr_input, &schema), expected);
508574
}
509575

510576
#[test]

0 commit comments

Comments
 (0)