Skip to content

Commit 8247e7e

Browse files
committed
cleanup
1 parent 8266def commit 8247e7e

File tree

3 files changed

+57
-80
lines changed

3 files changed

+57
-80
lines changed

datafusion/expr-common/src/casts.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -989,7 +989,7 @@ mod tests {
989989
}
990990

991991
#[test]
992-
fn test_string_view_comprehensive() {
992+
fn test_string_view() {
993993
// Test Utf8View to other string types
994994
expect_cast(
995995
ScalarValue::Utf8View(Some("test".to_string())),
@@ -1124,7 +1124,7 @@ mod tests {
11241124
}
11251125

11261126
#[test]
1127-
fn test_type_support_functions_comprehensive() {
1127+
fn test_type_support_functions() {
11281128
// Test numeric type support
11291129
assert!(is_supported_numeric_type(&DataType::Int8));
11301130
assert!(is_supported_numeric_type(&DataType::UInt64));
@@ -1182,7 +1182,7 @@ mod tests {
11821182
}
11831183

11841184
#[test]
1185-
fn test_error_conditions_comprehensive() {
1185+
fn test_error_conditions() {
11861186
// Test unsupported source type
11871187
expect_cast(
11881188
ScalarValue::Float32(Some(1.5)),

datafusion/physical-optimizer/src/simplify_expressions/mod.rs

Lines changed: 33 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -39,27 +39,15 @@ impl<'a> PhysicalExprSimplifier<'a> {
3939
}
4040

4141
/// Simplify a physical expression
42-
pub fn simplify(&self, expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> {
43-
let mut simplifier = Simplifier {
44-
schema: self.schema,
45-
};
46-
Ok(expr.rewrite(&mut simplifier)?.data)
47-
}
48-
49-
/// Apply unwrap cast optimization to physical expressions
50-
pub fn unwrap_casts(
51-
&self,
42+
pub fn simplify(
43+
&mut self,
5244
expr: Arc<dyn PhysicalExpr>,
53-
) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
54-
unwrap_cast::unwrap_cast_in_comparison(expr, self.schema)
45+
) -> Result<Arc<dyn PhysicalExpr>> {
46+
Ok(expr.rewrite(self)?.data)
5547
}
5648
}
5749

58-
struct Simplifier<'a> {
59-
schema: &'a Schema,
60-
}
61-
62-
impl<'a> TreeNodeRewriter for Simplifier<'a> {
50+
impl<'a> TreeNodeRewriter for PhysicalExprSimplifier<'a> {
6351
type Node = Arc<dyn PhysicalExpr>;
6452

6553
fn f_up(&mut self, node: Self::Node) -> Result<Transformed<Self::Node>> {
@@ -76,7 +64,7 @@ mod tests {
7664
use datafusion_common::ScalarValue;
7765
use datafusion_expr::Operator;
7866
use datafusion_physical_expr::expressions::{
79-
binary, cast, col, lit, BinaryExpr, Literal,
67+
binary, cast, col, lit, BinaryExpr, CastExpr, Literal, TryCastExpr,
8068
};
8169

8270
fn test_schema() -> Schema {
@@ -88,41 +76,9 @@ mod tests {
8876
}
8977

9078
#[test]
91-
fn test_physical_expr_simplifier_integration() {
92-
let schema = test_schema();
93-
let simplifier = PhysicalExprSimplifier::new(&schema);
94-
95-
// Create: cast(c1 as INT64) = INT64(42)
96-
let column_expr = col("c1", &schema).unwrap();
97-
let cast_expr = cast(column_expr, &schema, DataType::Int64).unwrap();
98-
let literal_expr = lit(ScalarValue::Int64(Some(42)));
99-
let binary_expr = binary(cast_expr, Operator::Eq, literal_expr, &schema).unwrap();
100-
101-
// Apply simplification
102-
let result = simplifier.unwrap_casts(binary_expr).unwrap();
103-
104-
// Should be transformed to: c1 = INT32(42)
105-
assert!(result.transformed);
106-
107-
let optimized = result.data;
108-
let optimized_binary = optimized.as_any().downcast_ref::<BinaryExpr>().unwrap();
109-
110-
// Verify the cast was removed
111-
assert!(!unwrap_cast::is_cast_expr(optimized_binary.left()));
112-
113-
// Verify the literal was converted to the correct type
114-
let right_literal = optimized_binary
115-
.right()
116-
.as_any()
117-
.downcast_ref::<Literal>()
118-
.unwrap();
119-
assert_eq!(right_literal.value(), &ScalarValue::Int32(Some(42)));
120-
}
121-
122-
#[test]
123-
fn test_simplify_method() {
79+
fn test_simplify() {
12480
let schema = test_schema();
125-
let simplifier = PhysicalExprSimplifier::new(&schema);
81+
let mut simplifier = PhysicalExprSimplifier::new(&schema);
12682

12783
// Create: cast(c2 as INT32) != INT32(99)
12884
let column_expr = col("c2", &schema).unwrap();
@@ -137,7 +93,11 @@ mod tests {
13793
let optimized_binary = optimized.as_any().downcast_ref::<BinaryExpr>().unwrap();
13894

13995
// Should be optimized to: c2 != INT64(99) (c2 is INT64, literal cast to match)
140-
assert!(!unwrap_cast::is_cast_expr(optimized_binary.left()));
96+
let left_expr = optimized_binary.left();
97+
assert!(
98+
left_expr.as_any().downcast_ref::<CastExpr>().is_none()
99+
&& left_expr.as_any().downcast_ref::<TryCastExpr>().is_none()
100+
);
141101
let right_literal = optimized_binary
142102
.right()
143103
.as_any()
@@ -149,7 +109,7 @@ mod tests {
149109
#[test]
150110
fn test_nested_expression_simplification() {
151111
let schema = test_schema();
152-
let simplifier = PhysicalExprSimplifier::new(&schema);
112+
let mut simplifier = PhysicalExprSimplifier::new(&schema);
153113

154114
// Create nested expression: (cast(c1 as INT64) > INT64(5)) OR (cast(c2 as INT32) <= INT32(10))
155115
let c1_expr = col("c1", &schema).unwrap();
@@ -175,7 +135,14 @@ mod tests {
175135
.as_any()
176136
.downcast_ref::<BinaryExpr>()
177137
.unwrap();
178-
assert!(!unwrap_cast::is_cast_expr(left_binary.left()));
138+
let left_left_expr = left_binary.left();
139+
assert!(
140+
left_left_expr.as_any().downcast_ref::<CastExpr>().is_none()
141+
&& left_left_expr
142+
.as_any()
143+
.downcast_ref::<TryCastExpr>()
144+
.is_none()
145+
);
179146
let left_literal = left_binary
180147
.right()
181148
.as_any()
@@ -189,7 +156,17 @@ mod tests {
189156
.as_any()
190157
.downcast_ref::<BinaryExpr>()
191158
.unwrap();
192-
assert!(!unwrap_cast::is_cast_expr(right_binary.left()));
159+
let right_left_expr = right_binary.left();
160+
assert!(
161+
right_left_expr
162+
.as_any()
163+
.downcast_ref::<CastExpr>()
164+
.is_none()
165+
&& right_left_expr
166+
.as_any()
167+
.downcast_ref::<TryCastExpr>()
168+
.is_none()
169+
);
193170
let right_literal = right_binary
194171
.right()
195172
.as_any()

datafusion/physical-optimizer/src/simplify_expressions/unwrap_cast.rs

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ fn try_unwrap_cast_binary(
109109
return Ok(Some(datafusion_physical_expr::expressions::binary(
110110
unwrapped,
111111
swapped_op,
112-
binary.left().clone(),
112+
Arc::clone(binary.left()),
113113
schema,
114114
)?));
115115
}
@@ -175,26 +175,6 @@ fn swap_operator(op: Operator) -> Operator {
175175
}
176176
}
177177

178-
/// Check if an expression is a cast expression
179-
pub(super) fn is_cast_expr(expr: &Arc<dyn PhysicalExpr>) -> bool {
180-
expr.as_any().downcast_ref::<CastExpr>().is_some()
181-
|| expr.as_any().downcast_ref::<TryCastExpr>().is_some()
182-
}
183-
184-
/// Check if a binary expression is suitable for cast unwrapping
185-
pub(super) fn is_binary_expr_with_cast_and_literal(binary: &BinaryExpr) -> bool {
186-
// Check if left is cast and right is literal
187-
let left_cast_right_literal = is_cast_expr(binary.left())
188-
&& binary.right().as_any().downcast_ref::<Literal>().is_some();
189-
190-
// Check if left is literal and right is cast
191-
let left_literal_right_cast =
192-
binary.left().as_any().downcast_ref::<Literal>().is_some()
193-
&& is_cast_expr(binary.right());
194-
195-
left_cast_right_literal || left_literal_right_cast
196-
}
197-
198178
#[cfg(test)]
199179
mod tests {
200180
use super::*;
@@ -203,6 +183,26 @@ mod tests {
203183
use datafusion_expr::Operator;
204184
use datafusion_physical_expr::expressions::{binary, cast, col, lit};
205185

186+
/// Check if an expression is a cast expression
187+
fn is_cast_expr(expr: &Arc<dyn PhysicalExpr>) -> bool {
188+
expr.as_any().downcast_ref::<CastExpr>().is_some()
189+
|| expr.as_any().downcast_ref::<TryCastExpr>().is_some()
190+
}
191+
192+
/// Check if a binary expression is suitable for cast unwrapping
193+
fn is_binary_expr_with_cast_and_literal(binary: &BinaryExpr) -> bool {
194+
// Check if left is cast and right is literal
195+
let left_cast_right_literal = is_cast_expr(binary.left())
196+
&& binary.right().as_any().downcast_ref::<Literal>().is_some();
197+
198+
// Check if left is literal and right is cast
199+
let left_literal_right_cast =
200+
binary.left().as_any().downcast_ref::<Literal>().is_some()
201+
&& is_cast_expr(binary.right());
202+
203+
left_cast_right_literal || left_literal_right_cast
204+
}
205+
206206
fn test_schema() -> Schema {
207207
Schema::new(vec![
208208
Field::new("c1", DataType::Int32, false),

0 commit comments

Comments
 (0)