Skip to content

Commit

Permalink
add rules (#3627)
Browse files Browse the repository at this point in the history
Signed-off-by: remzi <13716567376yh@gmail.com>

Signed-off-by: remzi <13716567376yh@gmail.com>
  • Loading branch information
HaoYang670 authored Sep 28, 2022
1 parent d1e3bd7 commit 41b59cf
Showing 1 changed file with 82 additions and 1 deletion.
83 changes: 82 additions & 1 deletion datafusion/optimizer/src/simplify_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ impl<'a, 'b> SimplifyInfo for SimplifyContext<'a, 'b> {
// This means we weren't able to compute `Expr::nullable` with
// *any* input schemas, signalling a problem
DataFusionError::Internal(format!(
"Could not find find columns in '{}' during simplify",
"Could not find columns in '{}' during simplify",
expr
))
})
Expand Down Expand Up @@ -110,6 +110,22 @@ fn expr_contains(expr: &Expr, needle: &Expr, search_op: Operator) -> bool {
}
}

fn is_zero(s: &Expr) -> bool {
match s {
Expr::Literal(ScalarValue::Int8(Some(0)))
| Expr::Literal(ScalarValue::Int16(Some(0)))
| Expr::Literal(ScalarValue::Int32(Some(0)))
| Expr::Literal(ScalarValue::Int64(Some(0)))
| Expr::Literal(ScalarValue::UInt8(Some(0)))
| Expr::Literal(ScalarValue::UInt16(Some(0)))
| Expr::Literal(ScalarValue::UInt32(Some(0)))
| Expr::Literal(ScalarValue::UInt64(Some(0))) => true,
Expr::Literal(ScalarValue::Float32(Some(v))) if *v == 0. => true,
Expr::Literal(ScalarValue::Float64(Some(v))) if *v == 0. => true,
_ => false,
}
}

fn is_one(s: &Expr) -> bool {
match s {
Expr::Literal(ScalarValue::Int8(Some(1)))
Expand Down Expand Up @@ -728,16 +744,44 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> {
//
// Rules for Multiply
//

// A * 1 --> A
BinaryExpr {
left,
op: Multiply,
right,
} if is_one(&right) => *left,
// 1 * A --> A
BinaryExpr {
left,
op: Multiply,
right,
} if is_one(&left) => *right,
// A * null --> null
BinaryExpr {
left: _,
op: Multiply,
right,
} if is_null(&right) => *right,
// null * A --> null
BinaryExpr {
left,
op: Multiply,
right: _,
} if is_null(&left) => *left,

// A * 0 --> 0 (if A is not null)
BinaryExpr {
left,
op: Multiply,
right,
} if !info.nullable(&left)? && is_zero(&right) => *right,
// 0 * A --> 0 (if A is not null)
BinaryExpr {
left,
op: Multiply,
right,
} if !info.nullable(&right)? && is_zero(&left) => *left,

//
// Rules for Divide
Expand Down Expand Up @@ -971,6 +1015,43 @@ mod tests {
assert_eq!(simplify(expr_b), expected);
}

#[test]
fn test_simplify_multiply_by_null() {
let null = Expr::Literal(ScalarValue::Null);
// A * null --> null
{
let expr = binary_expr(col("c2"), Operator::Multiply, null.clone());
assert_eq!(simplify(expr), null);
}
// null * A --> null
{
let expr = binary_expr(null.clone(), Operator::Multiply, col("c2"));
assert_eq!(simplify(expr), null);
}
}

#[test]
fn test_simplify_multiply_by_zero() {
// cannot optimize A * null (null * A) if A is nullable
{
let expr_a = binary_expr(col("c2"), Operator::Multiply, lit(0));
let expr_b = binary_expr(lit(0), Operator::Multiply, col("c2"));

assert_eq!(simplify(expr_a.clone()), expr_a);
assert_eq!(simplify(expr_b.clone()), expr_b);
}
// 0 * A --> 0 if A is not nullable
{
let expr = binary_expr(lit(0), Operator::Multiply, col("c2_non_null"));
assert_eq!(simplify(expr), lit(0));
}
// A * 0 --> 0 if A is not nullable
{
let expr = binary_expr(col("c2_non_null"), Operator::Multiply, lit(0));
assert_eq!(simplify(expr), lit(0));
}
}

#[test]
fn test_simplify_divide_by_one() {
let expr = binary_expr(col("c2"), Operator::Divide, lit(1));
Expand Down

0 comments on commit 41b59cf

Please sign in to comment.