@@ -81,6 +81,16 @@ pub(super) fn unwrap_cast_in_comparison_for_binary<S: SimplifyInfo>(
8181 let Ok ( expr_type) = info. get_data_type ( & expr) else {
8282 return internal_err ! ( "Can't get the data type of the expr {:?}" , & expr) ;
8383 } ;
84+
85+ if let Some ( value) = cast_literal_to_type_with_op ( & lit_value, & expr_type, op)
86+ {
87+ return Ok ( Transformed :: yes ( Expr :: BinaryExpr ( BinaryExpr {
88+ left : expr,
89+ op,
90+ right : Box :: new ( lit ( value) ) ,
91+ } ) ) ) ;
92+ } ;
93+
8494 // if the lit_value can be casted to the type of internal_left_expr
8595 // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
8696 let Some ( value) = try_cast_literal_to_type ( & lit_value, & expr_type) else {
@@ -106,6 +116,7 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary<
106116 info : & S ,
107117 expr : & Expr ,
108118 literal : & Expr ,
119+ op : Operator ,
109120) -> bool {
110121 match ( expr, literal) {
111122 (
@@ -125,6 +136,10 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary<
125136 return false ;
126137 } ;
127138
139+ if cast_literal_to_type_with_op ( lit_val, & expr_type, op) . is_some ( ) {
140+ return true ;
141+ }
142+
128143 try_cast_literal_to_type ( lit_val, & expr_type) . is_some ( )
129144 && is_supported_type ( & expr_type)
130145 && is_supported_type ( & lit_type)
@@ -177,6 +192,33 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_inlist<
177192 true
178193}
179194
195+ fn cast_literal_to_type_with_op (
196+ lit_value : & ScalarValue ,
197+ target_type : & DataType ,
198+ op : Operator ,
199+ ) -> Option < ScalarValue > {
200+ dbg ! ( lit_value, target_type, op) ;
201+ match ( op, lit_value) {
202+ (
203+ Operator :: Eq | Operator :: NotEq ,
204+ ScalarValue :: Utf8 ( Some ( ref str) )
205+ | ScalarValue :: Utf8View ( Some ( ref str) )
206+ | ScalarValue :: LargeUtf8 ( Some ( ref str) ) ,
207+ ) => match target_type {
208+ DataType :: Int8 => str. parse :: < i8 > ( ) . ok ( ) . map ( ScalarValue :: from) ,
209+ DataType :: Int16 => str. parse :: < i16 > ( ) . ok ( ) . map ( ScalarValue :: from) ,
210+ DataType :: Int32 => str. parse :: < i32 > ( ) . ok ( ) . map ( ScalarValue :: from) ,
211+ DataType :: Int64 => str. parse :: < i64 > ( ) . ok ( ) . map ( ScalarValue :: from) ,
212+ DataType :: UInt8 => str. parse :: < u8 > ( ) . ok ( ) . map ( ScalarValue :: from) ,
213+ DataType :: UInt16 => str. parse :: < u16 > ( ) . ok ( ) . map ( ScalarValue :: from) ,
214+ DataType :: UInt32 => str. parse :: < u32 > ( ) . ok ( ) . map ( ScalarValue :: from) ,
215+ DataType :: UInt64 => str. parse :: < u64 > ( ) . ok ( ) . map ( ScalarValue :: from) ,
216+ _ => None ,
217+ } ,
218+ _ => None ,
219+ }
220+ }
221+
180222/// Returns true if unwrap_cast_in_comparison supports this data type
181223fn is_supported_type ( data_type : & DataType ) -> bool {
182224 is_supported_numeric_type ( data_type)
@@ -468,6 +510,10 @@ mod tests {
468510 // the 99999999999 is not within the range of MAX(int32) and MIN(int32), we don't cast the lit(99999999999) to int32 type
469511 let expr_lt = cast ( col ( "c1" ) , DataType :: Int64 ) . lt ( lit ( 99999999999i64 ) ) ;
470512 assert_eq ! ( optimize_test( expr_lt. clone( ) , & schema) , expr_lt) ;
513+
514+ // cast(c1, UTF8) < 123, only eq/not_eq should be optimized
515+ let expr_lt = cast ( col ( "c1" ) , DataType :: Utf8 ) . lt ( lit ( "123" ) ) ;
516+ assert_eq ! ( optimize_test( expr_lt. clone( ) , & schema) , expr_lt) ;
471517 }
472518
473519 #[ test]
@@ -496,6 +542,16 @@ mod tests {
496542 let lit_lt_lit = cast ( null_i8 ( ) , DataType :: Int32 ) . lt ( lit ( 12i32 ) ) ;
497543 let expected = null_bool ( ) ;
498544 assert_eq ! ( optimize_test( lit_lt_lit, & schema) , expected) ;
545+
546+ // cast(c1, UTF8) = "123" => c1 = 123
547+ let expr_input = cast ( col ( "c1" ) , DataType :: Utf8 ) . eq ( lit ( "123" ) ) ;
548+ let expected = col ( "c1" ) . eq ( lit ( 123i32 ) ) ;
549+ assert_eq ! ( optimize_test( expr_input, & schema) , expected) ;
550+
551+ // cast(c1, UTF8) != "123" => c1 != 123
552+ let expr_input = cast ( col ( "c1" ) , DataType :: Utf8 ) . not_eq ( lit ( "123" ) ) ;
553+ let expected = col ( "c1" ) . not_eq ( lit ( 123i32 ) ) ;
554+ assert_eq ! ( optimize_test( expr_input, & schema) , expected) ;
499555 }
500556
501557 #[ test]
@@ -505,6 +561,16 @@ mod tests {
505561 let expr_input = cast ( col ( "c6" ) , DataType :: UInt64 ) . eq ( lit ( 0u64 ) ) ;
506562 let expected = col ( "c6" ) . eq ( lit ( 0u32 ) ) ;
507563 assert_eq ! ( optimize_test( expr_input, & schema) , expected) ;
564+
565+ // cast(c6, UTF8) = "123" => c6 = 123
566+ let expr_input = cast ( col ( "c6" ) , DataType :: Utf8 ) . eq ( lit ( "123" ) ) ;
567+ let expected = col ( "c6" ) . eq ( lit ( 123u32 ) ) ;
568+ assert_eq ! ( optimize_test( expr_input, & schema) , expected) ;
569+
570+ // cast(c6, UTF8) != "123" => c6 != 123
571+ let expr_input = cast ( col ( "c6" ) , DataType :: Utf8 ) . not_eq ( lit ( "123" ) ) ;
572+ let expected = col ( "c6" ) . not_eq ( lit ( 123u32 ) ) ;
573+ assert_eq ! ( optimize_test( expr_input, & schema) , expected) ;
508574 }
509575
510576 #[ test]
0 commit comments