@@ -81,6 +81,16 @@ pub(super) fn unwrap_cast_in_comparison_for_binary<S: SimplifyInfo>(
8181 let Ok ( expr_type) = info. get_data_type ( & expr) else {
8282 return internal_err ! ( "Can't get the data type of the expr {:?}" , & expr) ;
8383 } ;
84+
85+ if let Some ( value) = cast_literal_to_type_with_op ( & lit_value, & expr_type, op)
86+ {
87+ return Ok ( Transformed :: yes ( Expr :: BinaryExpr ( BinaryExpr {
88+ left : expr,
89+ op,
90+ right : Box :: new ( lit ( value) ) ,
91+ } ) ) ) ;
92+ } ;
93+
8494 // if the lit_value can be casted to the type of internal_left_expr
8595 // we need to unwrap the cast for cast/try_cast expr, and add cast to the literal
8696 let Some ( value) = try_cast_literal_to_type ( & lit_value, & expr_type) else {
@@ -105,6 +115,7 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary<
105115> (
106116 info : & S ,
107117 expr : & Expr ,
118+ op : Operator ,
108119 literal : & Expr ,
109120) -> bool {
110121 match ( expr, literal) {
@@ -125,6 +136,10 @@ pub(super) fn is_cast_expr_and_support_unwrap_cast_in_comparison_for_binary<
125136 return false ;
126137 } ;
127138
139+ if cast_literal_to_type_with_op ( lit_val, & expr_type, op) . is_some ( ) {
140+ return true ;
141+ }
142+
128143 try_cast_literal_to_type ( lit_val, & expr_type) . is_some ( )
129144 && is_supported_type ( & expr_type)
130145 && is_supported_type ( & lit_type)
@@ -215,6 +230,52 @@ fn is_supported_dictionary_type(data_type: &DataType) -> bool {
215230 DataType :: Dictionary ( _, inner) if is_supported_type( inner) )
216231}
217232
233+ ///// Tries to move a cast from an expression (such as column) to the literal other side of a comparison operator./
234+ ///
235+ /// Specifically, rewrites
236+ /// ```sql
237+ /// cast(col) <op> <literal>
238+ /// ```
239+ ///
240+ /// To
241+ ///
242+ /// ```sql
243+ /// col <op> cast(<literal>)
244+ /// col <op> <casted_literal>
245+ /// ```
246+ fn cast_literal_to_type_with_op (
247+ lit_value : & ScalarValue ,
248+ target_type : & DataType ,
249+ op : Operator ,
250+ ) -> Option < ScalarValue > {
251+ match ( op, lit_value) {
252+ (
253+ Operator :: Eq | Operator :: NotEq ,
254+ ScalarValue :: Utf8 ( Some ( _) )
255+ | ScalarValue :: Utf8View ( Some ( _) )
256+ | ScalarValue :: LargeUtf8 ( Some ( _) ) ,
257+ ) => {
258+ // Only try for integer types (TODO can we do this for other types
259+ // like timestamps)?
260+ use DataType :: * ;
261+ if matches ! (
262+ target_type,
263+ Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32 | UInt64
264+ ) {
265+ let casted = lit_value. cast_to ( target_type) . ok ( ) ?;
266+ let round_tripped = casted. cast_to ( & lit_value. data_type ( ) ) . ok ( ) ?;
267+ if lit_value != & round_tripped {
268+ return None ;
269+ }
270+ Some ( casted)
271+ } else {
272+ None
273+ }
274+ }
275+ _ => None ,
276+ }
277+ }
278+
218279/// Convert a literal value from one data type to another
219280pub ( super ) fn try_cast_literal_to_type (
220281 lit_value : & ScalarValue ,
@@ -468,6 +529,24 @@ mod tests {
468529 // the 99999999999 is not within the range of MAX(int32) and MIN(int32), we don't cast the lit(99999999999) to int32 type
469530 let expr_lt = cast ( col ( "c1" ) , DataType :: Int64 ) . lt ( lit ( 99999999999i64 ) ) ;
470531 assert_eq ! ( optimize_test( expr_lt. clone( ) , & schema) , expr_lt) ;
532+
533+ // cast(c1, UTF8) < '123', only eq/not_eq should be optimized
534+ let expr_lt = cast ( col ( "c1" ) , DataType :: Utf8 ) . lt ( lit ( "123" ) ) ;
535+ assert_eq ! ( optimize_test( expr_lt. clone( ) , & schema) , expr_lt) ;
536+
537+ // cast(c1, UTF8) = '0123', cast(cast('0123', Int32), UTF8) != '0123', so '0123' should not
538+ // be casted
539+ let expr_lt = cast ( col ( "c1" ) , DataType :: Utf8 ) . lt ( lit ( "0123" ) ) ;
540+ assert_eq ! ( optimize_test( expr_lt. clone( ) , & schema) , expr_lt) ;
541+
542+ // cast(c1, UTF8) = 'not a number', should not be able to cast to column type
543+ let expr_input = cast ( col ( "c1" ) , DataType :: Utf8 ) . eq ( lit ( "not a number" ) ) ;
544+ assert_eq ! ( optimize_test( expr_input. clone( ) , & schema) , expr_input) ;
545+
546+ // cast(c1, UTF8) = '99999999999', where '99999999999' does not fit into int32, so it will
547+ // not be optimized to integer comparison
548+ let expr_input = cast ( col ( "c1" ) , DataType :: Utf8 ) . eq ( lit ( "99999999999" ) ) ;
549+ assert_eq ! ( optimize_test( expr_input. clone( ) , & schema) , expr_input) ;
471550 }
472551
473552 #[ test]
@@ -496,6 +575,21 @@ mod tests {
496575 let lit_lt_lit = cast ( null_i8 ( ) , DataType :: Int32 ) . lt ( lit ( 12i32 ) ) ;
497576 let expected = null_bool ( ) ;
498577 assert_eq ! ( optimize_test( lit_lt_lit, & schema) , expected) ;
578+
579+ // cast(c1, UTF8) = '123' => c1 = 123
580+ let expr_input = cast ( col ( "c1" ) , DataType :: Utf8 ) . eq ( lit ( "123" ) ) ;
581+ let expected = col ( "c1" ) . eq ( lit ( 123i32 ) ) ;
582+ assert_eq ! ( optimize_test( expr_input, & schema) , expected) ;
583+
584+ // cast(c1, UTF8) != '123' => c1 != 123
585+ let expr_input = cast ( col ( "c1" ) , DataType :: Utf8 ) . not_eq ( lit ( "123" ) ) ;
586+ let expected = col ( "c1" ) . not_eq ( lit ( 123i32 ) ) ;
587+ assert_eq ! ( optimize_test( expr_input, & schema) , expected) ;
588+
589+ // cast(c1, UTF8) = NULL => c1 = NULL
590+ let expr_input = cast ( col ( "c1" ) , DataType :: Utf8 ) . eq ( lit ( ScalarValue :: Utf8 ( None ) ) ) ;
591+ let expected = col ( "c1" ) . eq ( lit ( ScalarValue :: Int32 ( None ) ) ) ;
592+ assert_eq ! ( optimize_test( expr_input, & schema) , expected) ;
499593 }
500594
501595 #[ test]
@@ -505,6 +599,16 @@ mod tests {
505599 let expr_input = cast ( col ( "c6" ) , DataType :: UInt64 ) . eq ( lit ( 0u64 ) ) ;
506600 let expected = col ( "c6" ) . eq ( lit ( 0u32 ) ) ;
507601 assert_eq ! ( optimize_test( expr_input, & schema) , expected) ;
602+
603+ // cast(c6, UTF8) = "123" => c6 = 123
604+ let expr_input = cast ( col ( "c6" ) , DataType :: Utf8 ) . eq ( lit ( "123" ) ) ;
605+ let expected = col ( "c6" ) . eq ( lit ( 123u32 ) ) ;
606+ assert_eq ! ( optimize_test( expr_input, & schema) , expected) ;
607+
608+ // cast(c6, UTF8) != "123" => c6 != 123
609+ let expr_input = cast ( col ( "c6" ) , DataType :: Utf8 ) . not_eq ( lit ( "123" ) ) ;
610+ let expected = col ( "c6" ) . not_eq ( lit ( 123u32 ) ) ;
611+ assert_eq ! ( optimize_test( expr_input, & schema) , expected) ;
508612 }
509613
510614 #[ test]
0 commit comments