|
20 | 20 | //! [`ExprSimplifier::with_guarantees()`]: crate::simplify_expressions::expr_simplifier::ExprSimplifier::with_guarantees |
21 | 21 | use datafusion_common::{tree_node::TreeNodeRewriter, DataFusionError, Result}; |
22 | 22 | use datafusion_expr::{expr::InList, lit, Between, BinaryExpr, Expr}; |
23 | | -use std::collections::HashMap; |
| 23 | +use std::{borrow::Cow, collections::HashMap}; |
24 | 24 |
|
25 | 25 | use datafusion_physical_expr::intervals::{Interval, IntervalBound, NullableInterval}; |
26 | 26 |
|
@@ -103,37 +103,44 @@ impl<'a> TreeNodeRewriter for GuaranteeRewriter<'a> { |
103 | 103 | } |
104 | 104 |
|
105 | 105 | Expr::BinaryExpr(BinaryExpr { left, op, right }) => { |
106 | | - // We only support comparisons for now |
107 | | - if !op.is_comparison_operator() { |
108 | | - return Ok(expr); |
109 | | - }; |
110 | | - |
111 | | - // Check if this is a comparison between a column and literal |
112 | | - let (col, op, value) = match (left.as_ref(), right.as_ref()) { |
113 | | - (Expr::Column(_), Expr::Literal(value)) => (left, *op, value), |
114 | | - (Expr::Literal(value), Expr::Column(_)) => { |
115 | | - // If we can swap the op, we can simplify the expression |
116 | | - if let Some(op) = op.swap() { |
117 | | - (right, op, value) |
| 106 | + // The left or right side of the expression might either have a guarantee |
| 107 | + // or be a literal. Either way, we can resolve each side to a NullableInterval. |
| 108 | + let left_interval = self |
| 109 | + .guarantees |
| 110 | + .get(left.as_ref()) |
| 111 | + .map(|interval| Cow::Borrowed(*interval)) |
| 112 | + .or_else(|| { |
| 113 | + if let Expr::Literal(value) = left.as_ref() { |
| 114 | + Some(Cow::Owned(value.clone().into())) |
118 | 115 | } else { |
119 | | - return Ok(expr); |
| 116 | + None |
| 117 | + } |
| 118 | + }); |
| 119 | + let right_interval = self |
| 120 | + .guarantees |
| 121 | + .get(right.as_ref()) |
| 122 | + .map(|interval| Cow::Borrowed(*interval)) |
| 123 | + .or_else(|| { |
| 124 | + if let Expr::Literal(value) = right.as_ref() { |
| 125 | + Some(Cow::Owned(value.clone().into())) |
| 126 | + } else { |
| 127 | + None |
| 128 | + } |
| 129 | + }); |
| 130 | + |
| 131 | + match (left_interval, right_interval) { |
| 132 | + (Some(left_interval), Some(right_interval)) => { |
| 133 | + let result = |
| 134 | + left_interval.apply_operator(op, right_interval.as_ref())?; |
| 135 | + if result.is_certainly_true() { |
| 136 | + Ok(lit(true)) |
| 137 | + } else if result.is_certainly_false() { |
| 138 | + Ok(lit(false)) |
| 139 | + } else { |
| 140 | + Ok(expr) |
120 | 141 | } |
121 | 142 | } |
122 | | - _ => return Ok(expr), |
123 | | - }; |
124 | | - |
125 | | - if let Some(col_interval) = self.guarantees.get(col.as_ref()) { |
126 | | - let result = |
127 | | - col_interval.apply_operator(&op, &value.clone().into())?; |
128 | | - if result.is_certainly_true() { |
129 | | - Ok(lit(true)) |
130 | | - } else if result.is_certainly_false() { |
131 | | - Ok(lit(false)) |
132 | | - } else { |
133 | | - Ok(expr) |
134 | | - } |
135 | | - } else { |
136 | | - Ok(expr) |
| 143 | + _ => Ok(expr), |
137 | 144 | } |
138 | 145 | } |
139 | 146 |
|
@@ -262,13 +269,21 @@ mod tests { |
262 | 269 | values: Interval::make(Some(1_i32), Some(3_i32), (true, false)), |
263 | 270 | }, |
264 | 271 | ), |
| 272 | + // s.y ∈ (1, 3] (not null) |
| 273 | + ( |
| 274 | + col("s").field("y"), |
| 275 | + NullableInterval::NotNull { |
| 276 | + values: Interval::make(Some(1_i32), Some(3_i32), (true, false)), |
| 277 | + }, |
| 278 | + ), |
265 | 279 | ]; |
266 | 280 |
|
267 | 281 | let mut rewriter = GuaranteeRewriter::new(guarantees.iter()); |
268 | 282 |
|
269 | 283 | // (original_expr, expected_simplification) |
270 | 284 | let simplified_cases = &[ |
271 | 285 | (col("x").lt_eq(lit(1)), false), |
| 286 | + (col("s").field("y").lt_eq(lit(1)), false), |
272 | 287 | (col("x").lt_eq(lit(3)), true), |
273 | 288 | (col("x").gt(lit(3)), false), |
274 | 289 | (col("x").gt(lit(1)), true), |
|
0 commit comments