1717
1818//! Physical expression schema rewriting utilities
1919
20- use std:: sync:: Arc ;
2120use std:: cmp:: Ordering ;
21+ use std:: sync:: Arc ;
2222
2323use arrow:: compute:: can_cast_types;
2424use arrow:: datatypes:: {
@@ -230,7 +230,9 @@ impl<'a> PhysicalExprSchemaRewriter<'a> {
230230 left. as_any ( ) . downcast_ref :: < CastExpr > ( ) ,
231231 right. as_any ( ) . downcast_ref :: < Literal > ( ) ,
232232 ) {
233- if let Some ( optimized) = self . unwrap_cast_with_literal ( cast_expr, literal, * op) ? {
233+ if let Some ( optimized) =
234+ self . unwrap_cast_with_literal ( cast_expr, literal, * op) ?
235+ {
234236 return Ok ( Some ( Arc :: new ( BinaryExpr :: new (
235237 optimized. 0 ,
236238 * op,
@@ -244,7 +246,9 @@ impl<'a> PhysicalExprSchemaRewriter<'a> {
244246 left. as_any ( ) . downcast_ref :: < Literal > ( ) ,
245247 right. as_any ( ) . downcast_ref :: < CastExpr > ( ) ,
246248 ) {
247- if let Some ( optimized) = self . unwrap_cast_with_literal ( cast_expr, literal, * op) ? {
249+ if let Some ( optimized) =
250+ self . unwrap_cast_with_literal ( cast_expr, literal, * op) ?
251+ {
248252 return Ok ( Some ( Arc :: new ( BinaryExpr :: new (
249253 optimized. 1 ,
250254 * op,
@@ -265,32 +269,36 @@ impl<'a> PhysicalExprSchemaRewriter<'a> {
265269 ) -> Result < Option < ( Arc < dyn PhysicalExpr > , Arc < dyn PhysicalExpr > ) > > {
266270 // Get the inner expression (what's being cast)
267271 let inner_expr = cast_expr. expr ( ) ;
268-
272+
269273 // Handle the case where inner expression might be another cast (due to schema rewriting)
270274 // This can happen when the schema rewriter adds a cast to a column, and then we have
271275 // an original cast on top of that.
272- let ( final_inner_expr, column) = if let Some ( inner_cast) = inner_expr. as_any ( ) . downcast_ref :: < CastExpr > ( ) {
273- // We have a nested cast, check if the inner cast's expression is a column
274- let inner_inner_expr = inner_cast. expr ( ) ;
275- if let Some ( col) = inner_inner_expr. as_any ( ) . downcast_ref :: < Column > ( ) {
276- ( inner_inner_expr, col)
276+ let ( final_inner_expr, column) =
277+ if let Some ( inner_cast) = inner_expr. as_any ( ) . downcast_ref :: < CastExpr > ( ) {
278+ // We have a nested cast, check if the inner cast's expression is a column
279+ let inner_inner_expr = inner_cast. expr ( ) ;
280+ if let Some ( col) = inner_inner_expr. as_any ( ) . downcast_ref :: < Column > ( ) {
281+ ( inner_inner_expr, col)
282+ } else {
283+ return Ok ( None ) ;
284+ }
285+ } else if let Some ( col) = inner_expr. as_any ( ) . downcast_ref :: < Column > ( ) {
286+ ( inner_expr, col)
277287 } else {
278288 return Ok ( None ) ;
279- }
280- } else if let Some ( col) = inner_expr. as_any ( ) . downcast_ref :: < Column > ( ) {
281- ( inner_expr, col)
282- } else {
283- return Ok ( None ) ;
284- } ;
289+ } ;
285290
286291 // Get the column's data type from the physical schema
287- let column_data_type = match self . physical_file_schema . field_with_name ( column. name ( ) ) {
288- Ok ( field) => field. data_type ( ) ,
289- Err ( _) => return Ok ( None ) , // Column not found, can't optimize
290- } ;
292+ let column_data_type =
293+ match self . physical_file_schema . field_with_name ( column. name ( ) ) {
294+ Ok ( field) => field. data_type ( ) ,
295+ Err ( _) => return Ok ( None ) , // Column not found, can't optimize
296+ } ;
291297
292298 // Try to cast the literal to the column's data type
293- if let Some ( casted_literal) = try_cast_literal_to_type ( literal. value ( ) , column_data_type, op) {
299+ if let Some ( casted_literal) =
300+ try_cast_literal_to_type ( literal. value ( ) , column_data_type, op)
301+ {
294302 return Ok ( Some ( (
295303 Arc :: clone ( final_inner_expr) ,
296304 expressions:: lit ( casted_literal) ,
@@ -323,7 +331,6 @@ fn cast_literal_to_type_with_op(
323331 target_type : & DataType ,
324332 op : Operator ,
325333) -> Option < ScalarValue > {
326-
327334 match ( op, lit_value) {
328335 (
329336 Operator :: Eq | Operator :: NotEq ,
@@ -754,22 +761,27 @@ mod tests {
754761 let column_expr = Arc :: new ( Column :: new ( "a" , 0 ) ) ;
755762 let cast_expr = Arc :: new ( CastExpr :: new ( column_expr, DataType :: Int64 , None ) ) ;
756763 let literal_expr = expressions:: lit ( ScalarValue :: Int64 ( Some ( 123 ) ) ) ;
757- let binary_expr = Arc :: new ( BinaryExpr :: new (
758- cast_expr,
759- Operator :: Eq ,
760- literal_expr,
761- ) ) ;
764+ let binary_expr =
765+ Arc :: new ( BinaryExpr :: new ( cast_expr, Operator :: Eq , literal_expr) ) ;
762766
763767 let result = rewriter. rewrite ( binary_expr. clone ( ) as Arc < dyn PhysicalExpr > ) ?;
764768
765769 // The result should be a binary expression with the cast unwrapped
766770 let result_binary = result. as_any ( ) . downcast_ref :: < BinaryExpr > ( ) . unwrap ( ) ;
767-
771+
768772 // Left side should be the original column (no cast)
769- assert ! ( result_binary. left( ) . as_any( ) . downcast_ref:: <Column >( ) . is_some( ) ) ;
770-
773+ assert ! ( result_binary
774+ . left( )
775+ . as_any( )
776+ . downcast_ref:: <Column >( )
777+ . is_some( ) ) ;
778+
771779 // Right side should be a literal with the value cast to Int32
772- let right_literal = result_binary. right ( ) . as_any ( ) . downcast_ref :: < Literal > ( ) . unwrap ( ) ;
780+ let right_literal = result_binary
781+ . right ( )
782+ . as_any ( )
783+ . downcast_ref :: < Literal > ( )
784+ . unwrap ( ) ;
773785 assert_eq ! ( * right_literal. value( ) , ScalarValue :: Int32 ( Some ( 123 ) ) ) ;
774786
775787 Ok ( ( ) )
@@ -787,23 +799,28 @@ mod tests {
787799 let literal_expr = expressions:: lit ( ScalarValue :: Int64 ( Some ( 123 ) ) ) ;
788800 let column_expr = Arc :: new ( Column :: new ( "a" , 0 ) ) ;
789801 let cast_expr = Arc :: new ( CastExpr :: new ( column_expr, DataType :: Int64 , None ) ) ;
790- let binary_expr = Arc :: new ( BinaryExpr :: new (
791- literal_expr,
792- Operator :: Eq ,
793- cast_expr,
794- ) ) ;
802+ let binary_expr =
803+ Arc :: new ( BinaryExpr :: new ( literal_expr, Operator :: Eq , cast_expr) ) ;
795804
796805 let result = rewriter. rewrite ( binary_expr) ?;
797806
798807 // The result should be a binary expression with the cast unwrapped
799808 let result_binary = result. as_any ( ) . downcast_ref :: < BinaryExpr > ( ) . unwrap ( ) ;
800-
809+
801810 // Left side should be a literal with the value cast to Int32
802- let left_literal = result_binary. left ( ) . as_any ( ) . downcast_ref :: < Literal > ( ) . unwrap ( ) ;
811+ let left_literal = result_binary
812+ . left ( )
813+ . as_any ( )
814+ . downcast_ref :: < Literal > ( )
815+ . unwrap ( ) ;
803816 assert_eq ! ( * left_literal. value( ) , ScalarValue :: Int32 ( Some ( 123 ) ) ) ;
804-
817+
805818 // Right side should be the original column (no cast)
806- assert ! ( result_binary. right( ) . as_any( ) . downcast_ref:: <Column >( ) . is_some( ) ) ;
819+ assert ! ( result_binary
820+ . right( )
821+ . as_any( )
822+ . downcast_ref:: <Column >( )
823+ . is_some( ) ) ;
807824
808825 Ok ( ( ) )
809826 }
@@ -820,22 +837,27 @@ mod tests {
820837 let column_expr = Arc :: new ( Column :: new ( "a" , 0 ) ) ;
821838 let cast_expr = Arc :: new ( CastExpr :: new ( column_expr, DataType :: Utf8 , None ) ) ;
822839 let literal_expr = expressions:: lit ( ScalarValue :: Utf8 ( Some ( "123" . to_string ( ) ) ) ) ;
823- let binary_expr = Arc :: new ( BinaryExpr :: new (
824- cast_expr,
825- Operator :: Eq ,
826- literal_expr,
827- ) ) ;
840+ let binary_expr =
841+ Arc :: new ( BinaryExpr :: new ( cast_expr, Operator :: Eq , literal_expr) ) ;
828842
829843 let result = rewriter. rewrite ( binary_expr) ?;
830844
831845 // The result should be a binary expression with the cast unwrapped
832846 let result_binary = result. as_any ( ) . downcast_ref :: < BinaryExpr > ( ) . unwrap ( ) ;
833-
847+
834848 // Left side should be the original column (no cast)
835- assert ! ( result_binary. left( ) . as_any( ) . downcast_ref:: <Column >( ) . is_some( ) ) ;
836-
849+ assert ! ( result_binary
850+ . left( )
851+ . as_any( )
852+ . downcast_ref:: <Column >( )
853+ . is_some( ) ) ;
854+
837855 // Right side should be a literal with the value cast to Int32
838- let right_literal = result_binary. right ( ) . as_any ( ) . downcast_ref :: < Literal > ( ) . unwrap ( ) ;
856+ let right_literal = result_binary
857+ . right ( )
858+ . as_any ( )
859+ . downcast_ref :: < Literal > ( )
860+ . unwrap ( ) ;
839861 assert_eq ! ( * right_literal. value( ) , ScalarValue :: Int32 ( Some ( 123 ) ) ) ;
840862
841863 Ok ( ( ) )
@@ -844,7 +866,8 @@ mod tests {
844866 #[ test]
845867 fn test_no_unwrap_cast_optimization_when_not_applicable ( ) -> Result < ( ) > {
846868 // Test case where optimization should not apply - unsupported cast
847- let physical_schema = Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Float32 , false ) ] ) ;
869+ let physical_schema =
870+ Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Float32 , false ) ] ) ;
848871 let logical_schema = Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Int64 , false ) ] ) ;
849872
850873 let rewriter = PhysicalExprSchemaRewriter :: new ( & physical_schema, & logical_schema) ;
@@ -854,18 +877,19 @@ mod tests {
854877 let column_expr = Arc :: new ( Column :: new ( "a" , 0 ) ) ;
855878 let cast_expr = Arc :: new ( CastExpr :: new ( column_expr, DataType :: Int64 , None ) ) ;
856879 let literal_expr = expressions:: lit ( ScalarValue :: Int64 ( Some ( 123 ) ) ) ;
857- let binary_expr = Arc :: new ( BinaryExpr :: new (
858- cast_expr,
859- Operator :: Eq ,
860- literal_expr,
861- ) ) ;
880+ let binary_expr =
881+ Arc :: new ( BinaryExpr :: new ( cast_expr, Operator :: Eq , literal_expr) ) ;
862882
863883 let result = rewriter. rewrite ( binary_expr) ?;
864884
865885 // The result should still be a binary expression with a cast on the left side
866886 // since Float32 is not in our supported types for unwrap cast optimization
867887 let result_binary = result. as_any ( ) . downcast_ref :: < BinaryExpr > ( ) . unwrap ( ) ;
868- assert ! ( result_binary. left( ) . as_any( ) . downcast_ref:: <CastExpr >( ) . is_some( ) ) ;
888+ assert ! ( result_binary
889+ . left( )
890+ . as_any( )
891+ . downcast_ref:: <CastExpr >( )
892+ . is_some( ) ) ;
869893
870894 Ok ( ( ) )
871895 }
0 commit comments