diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index ad258986f785b..65c3064531b35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -111,6 +111,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) ConstantPropagation, FoldablePropagation, OptimizeIn, + RewriteArithmeticFiltersOnIntegralColumn, ConstantFolding, ReorderAssociativeOperator, LikeSimplification, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntegralColumn.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntegralColumn.scala new file mode 100644 index 0000000000000..8cd3fa46829fc --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntegralColumn.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types._ + +/** + * Rewrite arithmetic filters on an integral-type (e.g., byte, short, int and long) + * column into an equivalent form that isolates the attribute on one side, so that + * the filter can be pushed down to data sources (e.g., Parquet and ORC). + * + * For example, this rule can optimize a query as follows: + * {{{ + * SELECT * FROM table WHERE i + 3 = 5 + * ==> SELECT * FROM table WHERE i = 5 - 3 + * }}} + * + * Then, the [[ConstantFolding]] rule will further optimize it as follows: + * {{{ + * SELECT * FROM table WHERE i = 2 + * }}} + * + * Note: + * 1. This rule supports `Add` and `Subtract` in arithmetic expressions. + * 2. This rule supports `=` and `!=` in comparators. For `>`, `>=`, `<`, `<=`, + * the rewrite may introduce inconsistencies, so they are left unchanged. + * 3. This rule supports integral types (`byte`, `short`, `int`, `long`) only. + * It doesn't support `float` or `double` because of precision issues. 
+ */
+object RewriteArithmeticFiltersOnIntegralColumn extends Rule[LogicalPlan] with PredicateHelper {
+  /** Rewrites integral `arithmetic = constant` equalities found in [[Filter]] conditions. */
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case f: Filter =>
+      f transformExpressionsUp {
+        case e @ EqualTo(left: BinaryArithmetic, right: Expression)
+            if right.foldable && isDataTypeSafe(left.dataType) =>
+          transformLeft(e, left, right)
+        case e @ EqualTo(left: Expression, right: BinaryArithmetic)
+            if left.foldable && isDataTypeSafe(right.dataType) =>
+          transformRight(e, left, right)
+      }
+  }
+
+  /**
+   * Handles `left = right` where `left` is the arithmetic side and `right` is foldable.
+   * Returns `bc` unchanged when no case applies or the new constant is not safe to compute.
+   */
+  private def transformLeft(
+      bc: EqualTo,
+      left: BinaryArithmetic,
+      right: Expression): Expression = left match {
+    case Add(ar: AttributeReference, e) if e.foldable && isOptSafe(Subtract(right, e)) =>
+      EqualTo(ar, Subtract(right, e))
+    case Add(e, ar: AttributeReference) if e.foldable && isOptSafe(Subtract(right, e)) =>
+      EqualTo(ar, Subtract(right, e))
+    case Subtract(ar: AttributeReference, e) if e.foldable && isOptSafe(Add(right, e)) =>
+      EqualTo(ar, Add(right, e))
+    case Subtract(e, ar: AttributeReference) if e.foldable && isOptSafe(Subtract(e, right)) =>
+      // `e - ar = right` is solved as `e - right = ar`: the attribute moves to the right side.
+      EqualTo(Subtract(e, right), ar)
+    case _ => bc
+  }
+
+  /**
+   * Handles `left = right` where `right` is the arithmetic side and `left` is foldable.
+   * Returns `bc` unchanged when no case applies or the new constant is not safe to compute.
+   */
+  private def transformRight(
+      bc: EqualTo,
+      left: Expression,
+      right: BinaryArithmetic): Expression = right match {
+    case Add(ar: AttributeReference, e) if e.foldable && isOptSafe(Subtract(left, e)) =>
+      EqualTo(Subtract(left, e), ar)
+    case Add(e, ar: AttributeReference) if e.foldable && isOptSafe(Subtract(left, e)) =>
+      EqualTo(Subtract(left, e), ar)
+    case Subtract(ar: AttributeReference, e) if e.foldable && isOptSafe(Add(left, e)) =>
+      EqualTo(Add(left, e), ar)
+    case Subtract(e, ar: AttributeReference) if e.foldable && isOptSafe(Subtract(e, left)) =>
+      // `left = e - ar` is solved as `ar = e - left`: the attribute moves to the left side.
+      EqualTo(ar, Subtract(e, left))
+    case _ => bc
+  }
+
+  /** Returns true for the integral types whose overflow `isOptSafe` knows how to check. */
+  private def isDataTypeSafe(dataType: DataType): Boolean = dataType match {
+    case ByteType | ShortType | IntegerType | LongType => true
+    case _ => false
+  }
+
+  /**
+   * Checks that the constant expression `e` can be evaluated without overflowing its type.
+   * Both operands are evaluated against an empty row, so callers must pass foldable operands.
+   * A NULL operand, or operands of different types, is reported as unsafe so that the values
+   * are never unboxed as a type they do not have.
+   */
+  private def isOptSafe(e: BinaryArithmetic): Boolean = {
+    val leftVal = e.left.eval(EmptyRow)
+    val rightVal = e.right.eval(EmptyRow)
+    if (leftVal == null || rightVal == null || e.left.dataType != e.right.dataType) {
+      false
+    } else {
+      // Dispatch on the operator and on the result type to pick the matching bounds.
+      (e, e.dataType) match {
+        case (_: Add, ByteType) => isAddSafe(leftVal, rightVal, Byte.MinValue, Byte.MaxValue)
+        case (_: Add, ShortType) => isAddSafe(leftVal, rightVal, Short.MinValue, Short.MaxValue)
+        case (_: Add, IntegerType) => isAddSafe(leftVal, rightVal, Int.MinValue, Int.MaxValue)
+        case (_: Add, LongType) => isAddSafe(leftVal, rightVal, Long.MinValue, Long.MaxValue)
+        case (_: Subtract, ByteType) =>
+          isSubtractSafe(leftVal, rightVal, Byte.MinValue, Byte.MaxValue)
+        case (_: Subtract, ShortType) =>
+          isSubtractSafe(leftVal, rightVal, Short.MinValue, Short.MaxValue)
+        case (_: Subtract, IntegerType) =>
+          isSubtractSafe(leftVal, rightVal, Int.MinValue, Int.MaxValue)
+        case (_: Subtract, LongType) =>
+          isSubtractSafe(leftVal, rightVal, Long.MinValue, Long.MaxValue)
+        case _ => false
+      }
+    }
+  }
+
+  /** Returns true when `left + right` stays within `[minValue, maxValue]`. */
+  private def isAddSafe[T](left: Any, right: Any, minValue: T, maxValue: T)(
+      implicit num: Numeric[T]): Boolean = {
+    import num._
+    val leftVal = left.asInstanceOf[T]
+    val rightVal = right.asInstanceOf[T]
+    // Compare against the bound shifted by `right`, so that the check itself cannot overflow.
+    if (rightVal > zero) leftVal <= maxValue - rightVal else leftVal >= minValue - rightVal
+  }
+
+  /** Returns true when `left - right` stays within `[minValue, maxValue]`. */
+  private def isSubtractSafe[T](left: Any, right: Any, minValue: T, maxValue: T)(
+      implicit num: Numeric[T]): Boolean = {
+    import num._
+    val leftVal = left.asInstanceOf[T]
+    val rightVal = right.asInstanceOf[T]
+    // Compare against the bound shifted by `right`, so that the check itself cannot overflow.
+    if (rightVal > zero) leftVal >= minValue + rightVal else leftVal <= maxValue + rightVal
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntegralColumnSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntegralColumnSuite.scala
new file mode 100644
index 0000000000000..dfef8990b522c
--- /dev/null
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntegralColumnSuite.scala @@ -0,0 +1,267 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +/** + * Unit tests for rewrite arithmetic filters on integral column optimizer. 
+ */
+class RewriteArithmeticFiltersOnIntegralColumnSuite extends PlanTest {
+
+  /** Applies the rule under test and ConstantFolding, for at most 10 iterations. */
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches = Batch("RewriteArithmeticFiltersOnIntegralColumn", FixedPoint(10),
+      RewriteArithmeticFiltersOnIntegralColumn, ConstantFolding) :: Nil
+  }
+
+  val testRelation = LocalRelation('a.int, 'b.long, 'c.byte, 'd.short)
+
+  private val columnA = 'a
+  private val columnB = 'b
+  private val columnC = 'c
+  private val columnD = 'd
+
+  test("test of int: a + 2 = 8") {
+    val query = testRelation
+      .where(Add(columnA, Literal(2)) === Literal(8))
+
+    val correctAnswer = testRelation
+      .where(columnA === Literal(6))
+      .analyze
+
+    comparePlans(Optimize.execute(query.analyze), correctAnswer)
+  }
+
+  test("test of int: 2 - a =!= 8") {
+    val query = testRelation
+      .where(Subtract(Literal(2), columnA) =!= Literal(8))
+
+    val correctAnswer = testRelation
+      .where(Literal(-6) =!= columnA)
+      .analyze
+
+    comparePlans(Optimize.execute(query.analyze), correctAnswer)
+  }
+
+  test("test of int with overflow risk: a - 10 = Int.MaxValue - 2") {
+    val query = testRelation
+      .where(Subtract(columnA, Literal(10)) === Literal(Int.MaxValue - 2))
+
+    val correctAnswer = testRelation
+      .where(Subtract(columnA, Literal(10)) === Literal(Int.MaxValue - 2))
+      .analyze
+
+    comparePlans(Optimize.execute(query.analyze), correctAnswer)
+  }
+
+  test("test of int with overflow risk: 10 - a = Int.MinValue") {
+    val query = testRelation
+      .where(Subtract(Literal(10), columnA) === Literal(Int.MinValue))
+
+    val correctAnswer = testRelation
+      .where(Subtract(Literal(10), columnA) === Literal(Int.MinValue))
+      .analyze
+
+    comparePlans(Optimize.execute(query.analyze), correctAnswer)
+  }
+
+  test("test of int with overflow risk: a + 10 = Int.MinValue + 2") {
+    val query = testRelation
+      .where(Add(columnA, Literal(10)) === Literal(Int.MinValue + 2))
+
+    val correctAnswer = testRelation
+      .where(Add(columnA, Literal(10)) === Literal(Int.MinValue + 2))
+      .analyze
+
+    comparePlans(Optimize.execute(query.analyze), correctAnswer)
+  }
+
+  test("test of long: b + 2L = 8L") {
+    val query = testRelation
+      .where(Add(columnB, Literal(2L)) === Literal(8L))
+
+    val correctAnswer = testRelation
+      .where(columnB === Literal(6L))
+      .analyze
+
+    comparePlans(Optimize.execute(query.analyze), correctAnswer)
+  }
+
+  test("test of long: 2L - b =!= 8L") {
+    val query = testRelation
+      .where(Subtract(Literal(2L), columnB) =!= Literal(8L))
+
+    val correctAnswer = testRelation
+      .where(Literal(-6L) =!= columnB)
+      .analyze
+
+    comparePlans(Optimize.execute(query.analyze), correctAnswer)
+  }
+
+  test("test of long with overflow risk: b - 10L = Long.MaxValue - 2") {
+    val query = testRelation
+      .where(Subtract(columnB, Literal(10L)) === Literal(Long.MaxValue - 2))
+
+    val correctAnswer = testRelation
+      .where(Subtract(columnB, Literal(10L)) === Literal(Long.MaxValue - 2))
+      .analyze
+
+    comparePlans(Optimize.execute(query.analyze), correctAnswer)
+  }
+
+  test("test of long with overflow risk: 10L - b = Long.MinValue") {
+    val query = testRelation
+      .where(Subtract(Literal(10L), columnB) === Literal(Long.MinValue))
+
+    val correctAnswer = testRelation
+      .where(Subtract(Literal(10L), columnB) === Literal(Long.MinValue))
+      .analyze
+
+    comparePlans(Optimize.execute(query.analyze), correctAnswer)
+  }
+
+  test("test of long with overflow risk: b + 10L = Long.MinValue + 2") {
+    val query = testRelation
+      .where(Add(columnB, Literal(10L)) === Literal(Long.MinValue + 2))
+
+    val correctAnswer = testRelation
+      .where(Add(columnB, Literal(10L)) === Literal(Long.MinValue + 2))
+      .analyze
+
+    comparePlans(Optimize.execute(query.analyze), correctAnswer)
+  }
+
+  test("test of byte: c + 2 = 8") {
+    val query = testRelation
+      .where(Add(columnC, Literal(2.toByte)) === Literal(8.toByte))
+
+    val correctAnswer = testRelation
+      .where(columnC === Literal(6.toByte))
+      .analyze
+
+    comparePlans(Optimize.execute(query.analyze), correctAnswer)
+  }
+
+  test("test of byte: 2 - c =!= 8") {
+    val query = testRelation
+      .where(Subtract(Literal(2.toByte), columnC) =!= Literal(8.toByte))
+
+    val correctAnswer = testRelation
+      .where(Literal((-6).toByte) =!= columnC)
+      .analyze
+
+    comparePlans(Optimize.execute(query.analyze), correctAnswer)
+  }
+
+  // Scala widens Byte and Short arithmetic to Int, so the constants of the byte/short
+  // overflow tests are narrowed with `.toByte`/`.toShort` after the computation. Otherwise
+  // the literal would be an Int and the comparison would not be between same-typed operands.
+  test("test of byte with overflow risk: c - 10 = Byte.MaxValue - 2") {
+    val query = testRelation
+      .where(Subtract(columnC, Literal(10.toByte)) === Literal((Byte.MaxValue - 2).toByte))
+
+    val correctAnswer = testRelation
+      .where(Subtract(columnC, Literal(10.toByte)) === Literal((Byte.MaxValue - 2).toByte))
+      .analyze
+
+    comparePlans(Optimize.execute(query.analyze), correctAnswer)
+  }
+
+  test("test of byte with overflow risk: 10 - c = Byte.MinValue") {
+    val query = testRelation
+      .where(Subtract(Literal(10.toByte), columnC) === Literal(Byte.MinValue))
+
+    val correctAnswer = testRelation
+      .where(Subtract(Literal(10.toByte), columnC) === Literal(Byte.MinValue))
+      .analyze
+
+    comparePlans(Optimize.execute(query.analyze), correctAnswer)
+  }
+
+  test("test of byte with overflow risk: c + 10 = Byte.MinValue + 2") {
+    val query = testRelation
+      .where(Add(columnC, Literal(10.toByte)) === Literal((Byte.MinValue + 2).toByte))
+
+    val correctAnswer = testRelation
+      .where(Add(columnC, Literal(10.toByte)) === Literal((Byte.MinValue + 2).toByte))
+      .analyze
+
+    comparePlans(Optimize.execute(query.analyze), correctAnswer)
+  }
+
+  test("test of short: d + 2 = 8") {
+    val query = testRelation
+      .where(Add(columnD, Literal(2.toShort)) === Literal(8.toShort))
+
+    val correctAnswer = testRelation
+      .where(columnD === Literal(6.toShort))
+      .analyze
+
+    comparePlans(Optimize.execute(query.analyze), correctAnswer)
+  }
+
+  test("test of short: 2 - d =!= 8") {
+    val query = testRelation
+      .where(Subtract(Literal(2.toShort), columnD) =!= Literal(8.toShort))
+
+    val correctAnswer = testRelation
+      .where(Literal((-6).toShort) =!= columnD)
+      .analyze
+
+    comparePlans(Optimize.execute(query.analyze), correctAnswer)
+  }
+
+  test("test of short with overflow risk: d - 10 = Short.MaxValue - 2") {
+    val query = testRelation
+      .where(Subtract(columnD, Literal(10.toShort)) === Literal((Short.MaxValue - 2).toShort))
+
+    val correctAnswer = testRelation
+      .where(Subtract(columnD, Literal(10.toShort)) === Literal((Short.MaxValue - 2).toShort))
+      .analyze
+
+    comparePlans(Optimize.execute(query.analyze), correctAnswer)
+  }
+
+  test("test of short with overflow risk: 10 - d = Short.MinValue") {
+    val query = testRelation
+      .where(Subtract(Literal(10.toShort), columnD) === Literal(Short.MinValue))
+
+    val correctAnswer = testRelation
+      .where(Subtract(Literal(10.toShort), columnD) === Literal(Short.MinValue))
+      .analyze
+
+    comparePlans(Optimize.execute(query.analyze), correctAnswer)
+  }
+
+  test("test of short with overflow risk: d + 10 = Short.MinValue + 2") {
+    val query = testRelation
+      .where(Add(columnD, Literal(10.toShort)) === Literal((Short.MinValue + 2).toShort))
+
+    val correctAnswer = testRelation
+      .where(Add(columnD, Literal(10.toShort)) === Literal((Short.MinValue + 2).toShort))
+      .analyze
+
+    comparePlans(Optimize.execute(query.analyze), correctAnswer)
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index a4dc537d31b7e..7899eaaa6450a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -345,7 +345,8 @@ class JDBCSuite extends QueryTest
     assert(df2.collect.toSet === Set(Row("fred", 1), Row("mary", 2)))
 
     assert(checkNotPushdown(sql("SELECT * FROM foobar WHERE (THEID + 1) < 2")).collect().size == 0)
-    assert(checkNotPushdown(sql("SELECT * FROM foobar WHERE (THEID + 2) != 4")).collect().size == 2)
+    // SPARK-27033: Add Optimize rule RewriteArithmeticFiltersOnIntegralColumn
+    assert(checkPushdown(sql("SELECT * FROM foobar WHERE (THEID + 2) != 4")).collect().size == 2)
   }
 
   test("SELECT
COUNT(1) WHERE (predicates)") {