@@ -39,27 +39,15 @@ impl<'a> PhysicalExprSimplifier<'a> {
3939 }
4040
4141 /// Simplify a physical expression
42- pub fn simplify ( & self , expr : Arc < dyn PhysicalExpr > ) -> Result < Arc < dyn PhysicalExpr > > {
43- let mut simplifier = Simplifier {
44- schema : self . schema ,
45- } ;
46- Ok ( expr. rewrite ( & mut simplifier) ?. data )
47- }
48-
49- /// Apply unwrap cast optimization to physical expressions
50- pub fn unwrap_casts (
51- & self ,
42+ pub fn simplify (
43+ & mut self ,
5244 expr : Arc < dyn PhysicalExpr > ,
53- ) -> Result < Transformed < Arc < dyn PhysicalExpr > > > {
54- unwrap_cast :: unwrap_cast_in_comparison ( expr, self . schema )
45+ ) -> Result < Arc < dyn PhysicalExpr > > {
46+ Ok ( expr. rewrite ( self ) ? . data )
5547 }
5648}
5749
58- struct Simplifier < ' a > {
59- schema : & ' a Schema ,
60- }
61-
62- impl < ' a > TreeNodeRewriter for Simplifier < ' a > {
50+ impl < ' a > TreeNodeRewriter for PhysicalExprSimplifier < ' a > {
6351 type Node = Arc < dyn PhysicalExpr > ;
6452
6553 fn f_up ( & mut self , node : Self :: Node ) -> Result < Transformed < Self :: Node > > {
@@ -76,7 +64,7 @@ mod tests {
7664 use datafusion_common:: ScalarValue ;
7765 use datafusion_expr:: Operator ;
7866 use datafusion_physical_expr:: expressions:: {
79- binary, cast, col, lit, BinaryExpr , Literal ,
67+ binary, cast, col, lit, BinaryExpr , CastExpr , Literal , TryCastExpr ,
8068 } ;
8169
8270 fn test_schema ( ) -> Schema {
@@ -88,41 +76,9 @@ mod tests {
8876 }
8977
9078 #[ test]
91- fn test_physical_expr_simplifier_integration ( ) {
92- let schema = test_schema ( ) ;
93- let simplifier = PhysicalExprSimplifier :: new ( & schema) ;
94-
95- // Create: cast(c1 as INT64) = INT64(42)
96- let column_expr = col ( "c1" , & schema) . unwrap ( ) ;
97- let cast_expr = cast ( column_expr, & schema, DataType :: Int64 ) . unwrap ( ) ;
98- let literal_expr = lit ( ScalarValue :: Int64 ( Some ( 42 ) ) ) ;
99- let binary_expr = binary ( cast_expr, Operator :: Eq , literal_expr, & schema) . unwrap ( ) ;
100-
101- // Apply simplification
102- let result = simplifier. unwrap_casts ( binary_expr) . unwrap ( ) ;
103-
104- // Should be transformed to: c1 = INT32(42)
105- assert ! ( result. transformed) ;
106-
107- let optimized = result. data ;
108- let optimized_binary = optimized. as_any ( ) . downcast_ref :: < BinaryExpr > ( ) . unwrap ( ) ;
109-
110- // Verify the cast was removed
111- assert ! ( !unwrap_cast:: is_cast_expr( optimized_binary. left( ) ) ) ;
112-
113- // Verify the literal was converted to the correct type
114- let right_literal = optimized_binary
115- . right ( )
116- . as_any ( )
117- . downcast_ref :: < Literal > ( )
118- . unwrap ( ) ;
119- assert_eq ! ( right_literal. value( ) , & ScalarValue :: Int32 ( Some ( 42 ) ) ) ;
120- }
121-
122- #[ test]
123- fn test_simplify_method ( ) {
79+ fn test_simplify ( ) {
12480 let schema = test_schema ( ) ;
125- let simplifier = PhysicalExprSimplifier :: new ( & schema) ;
81+ let mut simplifier = PhysicalExprSimplifier :: new ( & schema) ;
12682
12783 // Create: cast(c2 as INT32) != INT32(99)
12884 let column_expr = col ( "c2" , & schema) . unwrap ( ) ;
@@ -137,7 +93,11 @@ mod tests {
13793 let optimized_binary = optimized. as_any ( ) . downcast_ref :: < BinaryExpr > ( ) . unwrap ( ) ;
13894
13995 // Should be optimized to: c2 != INT64(99) (c2 is INT64, literal cast to match)
140- assert ! ( !unwrap_cast:: is_cast_expr( optimized_binary. left( ) ) ) ;
96+ let left_expr = optimized_binary. left ( ) ;
97+ assert ! (
98+ left_expr. as_any( ) . downcast_ref:: <CastExpr >( ) . is_none( )
99+ && left_expr. as_any( ) . downcast_ref:: <TryCastExpr >( ) . is_none( )
100+ ) ;
141101 let right_literal = optimized_binary
142102 . right ( )
143103 . as_any ( )
@@ -149,7 +109,7 @@ mod tests {
149109 #[ test]
150110 fn test_nested_expression_simplification ( ) {
151111 let schema = test_schema ( ) ;
152- let simplifier = PhysicalExprSimplifier :: new ( & schema) ;
112+ let mut simplifier = PhysicalExprSimplifier :: new ( & schema) ;
153113
154114 // Create nested expression: (cast(c1 as INT64) > INT64(5)) OR (cast(c2 as INT32) <= INT32(10))
155115 let c1_expr = col ( "c1" , & schema) . unwrap ( ) ;
@@ -175,7 +135,14 @@ mod tests {
175135 . as_any ( )
176136 . downcast_ref :: < BinaryExpr > ( )
177137 . unwrap ( ) ;
178- assert ! ( !unwrap_cast:: is_cast_expr( left_binary. left( ) ) ) ;
138+ let left_left_expr = left_binary. left ( ) ;
139+ assert ! (
140+ left_left_expr. as_any( ) . downcast_ref:: <CastExpr >( ) . is_none( )
141+ && left_left_expr
142+ . as_any( )
143+ . downcast_ref:: <TryCastExpr >( )
144+ . is_none( )
145+ ) ;
179146 let left_literal = left_binary
180147 . right ( )
181148 . as_any ( )
@@ -189,7 +156,17 @@ mod tests {
189156 . as_any ( )
190157 . downcast_ref :: < BinaryExpr > ( )
191158 . unwrap ( ) ;
192- assert ! ( !unwrap_cast:: is_cast_expr( right_binary. left( ) ) ) ;
159+ let right_left_expr = right_binary. left ( ) ;
160+ assert ! (
161+ right_left_expr
162+ . as_any( )
163+ . downcast_ref:: <CastExpr >( )
164+ . is_none( )
165+ && right_left_expr
166+ . as_any( )
167+ . downcast_ref:: <TryCastExpr >( )
168+ . is_none( )
169+ ) ;
193170 let right_literal = right_binary
194171 . right ( )
195172 . as_any ( )
0 commit comments