# Patch for datafusion/optimizer/src/simplify_expressions.rs
#
# - Removes a duplicated word ("find find") from the error message produced
#   when nullability cannot be computed during simplification.
# - Adds `is_zero`, which returns true for a non-null integer literal equal to
#   0 or a non-null float literal that compares equal to 0.0.
# - Adds Multiply rewrite rules: `A * null` / `null * A` --> null, and
#   `A * 0` / `0 * A` --> 0 when `info.nullable` reports the other operand as
#   not nullable.
# - Adds tests for the new Multiply rules.
#
# NOTE(review): source indentation in the hunks below was reconstructed from a
# whitespace-collapsed copy of this patch; verify context-line whitespace
# against the target file before applying.
diff --git a/datafusion/optimizer/src/simplify_expressions.rs b/datafusion/optimizer/src/simplify_expressions.rs
index 4aee2fd0378a..94a035fbd64e 100644
--- a/datafusion/optimizer/src/simplify_expressions.rs
+++ b/datafusion/optimizer/src/simplify_expressions.rs
@@ -72,7 +72,7 @@ impl<'a, 'b> SimplifyInfo for SimplifyContext<'a, 'b> {
                 // This means we weren't able to compute `Expr::nullable` with
                 // *any* input schemas, signalling a problem
                 DataFusionError::Internal(format!(
-                    "Could not find find columns in '{}' during simplify",
+                    "Could not find columns in '{}' during simplify",
                     expr
                 ))
             })
@@ -110,6 +110,22 @@ fn expr_contains(expr: &Expr, needle: &Expr, search_op: Operator) -> bool {
     }
 }
 
+fn is_zero(s: &Expr) -> bool {
+    match s {
+        Expr::Literal(ScalarValue::Int8(Some(0)))
+        | Expr::Literal(ScalarValue::Int16(Some(0)))
+        | Expr::Literal(ScalarValue::Int32(Some(0)))
+        | Expr::Literal(ScalarValue::Int64(Some(0)))
+        | Expr::Literal(ScalarValue::UInt8(Some(0)))
+        | Expr::Literal(ScalarValue::UInt16(Some(0)))
+        | Expr::Literal(ScalarValue::UInt32(Some(0)))
+        | Expr::Literal(ScalarValue::UInt64(Some(0))) => true,
+        Expr::Literal(ScalarValue::Float32(Some(v))) if *v == 0. => true,
+        Expr::Literal(ScalarValue::Float64(Some(v))) if *v == 0. => true,
+        _ => false,
+    }
+}
+
 fn is_one(s: &Expr) -> bool {
     match s {
         Expr::Literal(ScalarValue::Int8(Some(1)))
@@ -728,16 +744,44 @@ impl<'a, S: SimplifyInfo> ExprRewriter for Simplifier<'a, S> {
             //
             // Rules for Multiply
             //
+
+            // A * 1 --> A
             BinaryExpr {
                 left,
                 op: Multiply,
                 right,
             } if is_one(&right) => *left,
+            // 1 * A --> A
             BinaryExpr {
                 left,
                 op: Multiply,
                 right,
             } if is_one(&left) => *right,
+            // A * null --> null
+            BinaryExpr {
+                left: _,
+                op: Multiply,
+                right,
+            } if is_null(&right) => *right,
+            // null * A --> null
+            BinaryExpr {
+                left,
+                op: Multiply,
+                right: _,
+            } if is_null(&left) => *left,
+
+            // A * 0 --> 0 (if A is not null). NOTE(review): is_zero also matches float literals, and NaN * 0 / inf * 0 are NaN, not 0 -- confirm this rewrite is acceptable for floating-point A
+            BinaryExpr {
+                left,
+                op: Multiply,
+                right,
+            } if !info.nullable(&left)? && is_zero(&right) => *right,
+            // 0 * A --> 0 (if A is not null). NOTE(review): same floating-point caveat as the A * 0 rule (0 * NaN and 0 * inf are NaN) -- confirm
+            BinaryExpr {
+                left,
+                op: Multiply,
+                right,
+            } if !info.nullable(&right)? && is_zero(&left) => *left,
 
             //
             // Rules for Divide
@@ -971,6 +1015,43 @@ mod tests {
         assert_eq!(simplify(expr_b), expected);
     }
 
+    #[test]
+    fn test_simplify_multiply_by_null() {
+        let null = Expr::Literal(ScalarValue::Null);
+        // A * null --> null
+        {
+            let expr = binary_expr(col("c2"), Operator::Multiply, null.clone());
+            assert_eq!(simplify(expr), null);
+        }
+        // null * A --> null
+        {
+            let expr = binary_expr(null.clone(), Operator::Multiply, col("c2"));
+            assert_eq!(simplify(expr), null);
+        }
+    }
+
+    #[test]
+    fn test_simplify_multiply_by_zero() {
+        // cannot optimize A * 0 (0 * A) if A is nullable
+        {
+            let expr_a = binary_expr(col("c2"), Operator::Multiply, lit(0));
+            let expr_b = binary_expr(lit(0), Operator::Multiply, col("c2"));
+
+            assert_eq!(simplify(expr_a.clone()), expr_a);
+            assert_eq!(simplify(expr_b.clone()), expr_b);
+        }
+        // 0 * A --> 0 if A is not nullable
+        {
+            let expr = binary_expr(lit(0), Operator::Multiply, col("c2_non_null"));
+            assert_eq!(simplify(expr), lit(0));
+        }
+        // A * 0 --> 0 if A is not nullable
+        {
+            let expr = binary_expr(col("c2_non_null"), Operator::Multiply, lit(0));
+            assert_eq!(simplify(expr), lit(0));
+        }
+    }
+
     #[test]
     fn test_simplify_divide_by_one() {
         let expr = binary_expr(col("c2"), Operator::Divide, lit(1));