From 45b6f58060af6a7e2ae781bbdaff883c28f60769 Mon Sep 17 00:00:00 2001 From: "wangguangxin.cn" Date: Fri, 1 Mar 2019 13:38:31 +0800 Subject: [PATCH 01/10] Add Optimize rule TransformBinaryComparison --- .../sql/catalyst/optimizer/Optimizer.scala | 3 +- .../sql/catalyst/optimizer/expressions.scala | 116 +++++++++ .../TransformBinaryComparisonSuite.scala | 223 ++++++++++++++++++ 3 files changed, 341 insertions(+), 1 deletion(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransformBinaryComparisonSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index ad258986f785b..73939308cab92 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -128,7 +128,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) RemoveRedundantAliases, RemoveNoopOperators, SimplifyExtractValueOps, - CombineConcats) ++ + CombineConcats, + TransformBinaryComparison) ++ extendedOperatorOptimizationRules val operatorOptimizationBatch: Seq[Batch] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 39709529c00d3..4243d491d45c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -735,3 +735,119 @@ object CombineConcats extends Rule[LogicalPlan] { flattenConcats(concat) } } + + +/** + * Transform binary comparison(such as =, >, <, >=, <=) in conditions to its equivalent form, + * leaving attributes alone in one side, so that we can push it down to parquet or others. 
+ * For example, this rule can optimize + * {{{ + * SELECT * FROM table WHERE i + 3 = 5 + * }}} + * to + * {{{ + * SELECT * FROM table WHERE i = 5 - 3 + * }}} + * when i is Int or Long, and then other rules will further optimize it to + * {{{ + * SELECT * FROM table WHERE i = 2 + * }}} + */ +object TransformBinaryComparison extends Rule[LogicalPlan] with PredicateHelper { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case q: LogicalPlan => q transformExpressionsUp { + case e @ BinaryComparison(left: BinaryArithmetic, right: Literal) + if isDataTypeSafe(left.dataType) => + transformLeft(e, left, right) + case e @ BinaryComparison(left: Literal, right: BinaryArithmetic) + if isDataTypeSafe(right.dataType) => + transformRight(e, left, right) + } + } + + private def transformLeft(bc: BinaryComparison, left: BinaryArithmetic, right: Literal) + : Expression = { + left match { + case Add(ar: AttributeReference, lit: Literal) if isOptSafe(Subtract(right, lit)) => + bc.makeCopy(Array(ar, Subtract(right, lit))) + case Add(lit: Literal, ar: AttributeReference) if isOptSafe(Subtract(right, lit)) => + bc.makeCopy(Array(ar, Subtract(right, lit))) + case Subtract(ar: AttributeReference, lit: Literal) if isOptSafe(Add(right, lit)) => + bc.makeCopy(Array(ar, Add(right, lit))) + case Subtract(lit: Literal, ar: AttributeReference) if isOptSafe(Subtract(lit, right)) => + bc.makeCopy(Array(Subtract(lit, right), ar)) + case _ => bc + } + } + + private def transformRight(bc: BinaryComparison, left: Literal, right: BinaryArithmetic) + : Expression = { + right match { + case Add(ar: AttributeReference, lit: Literal) if isOptSafe(Subtract(left, lit)) => + bc.makeCopy(Array(Subtract(left, lit), ar)) + case Add(lit: Literal, ar: AttributeReference) if isOptSafe(Subtract(left, lit)) => + bc.makeCopy(Array(Subtract(left, lit), ar)) + case Subtract(ar: AttributeReference, lit: Literal) if isOptSafe(Add(left, lit)) => + bc.makeCopy(Array(Add(left, lit), ar)) + case 
Subtract(lit: Literal, ar: AttributeReference) if isOptSafe(Subtract(lit, left)) => + bc.makeCopy(Array(ar, Subtract(lit, left))) + case _ => bc + } + } + + private def isDataTypeSafe(dataType: DataType): Boolean = dataType match { + case IntegerType | LongType => true + case _ => false + } + + private def isOptSafe(e: BinaryArithmetic): Boolean = { + val leftVal = e.left.eval(EmptyRow) + val rightVal = e.right.eval(EmptyRow) + + e match { + case Add(_: Literal, _: Literal) => + e.dataType match { + case IntegerType => + isAddSafe(leftVal, rightVal, Int.MinValue, Int.MaxValue) + case LongType => + isAddSafe(leftVal, rightVal, Long.MinValue, Long.MaxValue) + case _ => false + } + + case Subtract(_: Literal, _: Literal) => + e.dataType match { + case IntegerType => + isSubtractSafe(leftVal, rightVal, Int.MinValue, Int.MaxValue) + case LongType => + isSubtractSafe(leftVal, rightVal, Long.MinValue, Long.MaxValue) + case _ => false + } + + case _ => false + } + } + + private def isAddSafe[T](left: Any, right: Any, minValue: T, maxValue: T) + (implicit num: Numeric[T]): Boolean = { + import num._ + val leftVal = left.asInstanceOf[T] + val rightVal = right.asInstanceOf[T] + if (rightVal > zero) { + leftVal <= maxValue - rightVal + } else { + leftVal >= minValue - rightVal + } + } + + private def isSubtractSafe[T](left: Any, right: Any, minValue: T, maxValue: T) + (implicit num: Numeric[T]): Boolean = { + import num._ + val leftVal = left.asInstanceOf[T] + val rightVal = right.asInstanceOf[T] + if (rightVal > zero) { + leftVal >= minValue + rightVal + } else { + leftVal <= maxValue + rightVal + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransformBinaryComparisonSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransformBinaryComparisonSuite.scala new file mode 100644 index 0000000000000..913d80456669f --- /dev/null +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransformBinaryComparisonSuite.scala @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +/** + * Unit tests for transform binary comparision in expressions. 
+ */ +class TransformBinaryComparisonSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("TransformBinaryComparison", FixedPoint(10), + ConstantFolding, + TransformBinaryComparison) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.long) + + private val columnA = 'a + private val columnB = 'b + + test("test of int: a + 2 = 8") { + val query = testRelation + .where(Add(columnA, Literal(2)) === Literal(8)) + + val correctAnswer = testRelation + .where(columnA === Literal(6)).analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of int: a + 2 >= 8") { + val query = testRelation + .where(Add(columnA, Literal(2)) >= Literal(8)) + + val correctAnswer = testRelation + .where(columnA >= Literal(6)).analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of int: a + 2 <= 8") { + val query = testRelation + .where(Add(columnA, Literal(2)) <= Literal(8)) + + val correctAnswer = testRelation + .where(columnA <= Literal(6)).analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of int: a - 2 <= 8") { + val query = testRelation + .where(Subtract(columnA, Literal(2)) <= Literal(8)) + + val correctAnswer = testRelation + .where(columnA <= Literal(10)).analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of int: 2 - a <= 8") { + val query = testRelation + .where(Subtract(Literal(2), columnA) <= Literal(8)) + + val correctAnswer = testRelation + .where(Literal(-6) <= columnA).analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of int: 2 - a >= 8") { + val query = testRelation + .where(Subtract(Literal(2), columnA) >= Literal(8)) + + val correctAnswer = testRelation + .where(Literal(-6) >= columnA).analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of int with overflow risk: a - 10 >= Int.MaxValue 
- 2") { + val query = testRelation + .where(Subtract(columnA, Literal(10)) >= Literal(Int.MaxValue - 2)) + + val correctAnswer = testRelation + .where(Subtract(columnA, Literal(10)) >= Literal(Int.MaxValue - 2)).analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of int with overflow risk: 10 - a >= Int.MinValue") { + val query = testRelation + .where(Subtract(Literal(10), columnA) >= Literal(Int.MinValue)) + + val correctAnswer = testRelation + .where(Subtract(Literal(10), columnA) >= Literal(Int.MinValue)).analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of int with overflow risk: a + 10 <= Int.MinValue + 2") { + val query = testRelation + .where(Add(columnA, Literal(10)) <= Literal(Int.MinValue + 2)) + + val correctAnswer = testRelation + .where(Add(columnA, Literal(10)) <= Literal(Int.MinValue + 2)).analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of long: b + 2L = 8L") { + val query = testRelation + .where(Add(columnB, Literal(2L)) === Literal(8L)) + + val correctAnswer = testRelation + .where(columnB === Literal(6L)).analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of long: b + 2L >= 8L") { + val query = testRelation + .where(Add(columnB, Literal(2L)) >= Literal(8L)) + + val correctAnswer = testRelation + .where(columnB >= Literal(6L)).analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of long: b + 2L <= 8L") { + val query = testRelation + .where(Add(columnB, Literal(2L)) <= Literal(8L)) + + val correctAnswer = testRelation + .where(columnB <= Literal(6L)).analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of long: b - 2L <= 8L") { + val query = testRelation + .where(Subtract(columnB, Literal(2L)) <= Literal(8L)) + + val correctAnswer = testRelation + .where(columnB <= Literal(10L)).analyze + + 
comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of long: 2L - b <= 8L") { + val query = testRelation + .where(Subtract(Literal(2L), columnB) <= Literal(8)) + + val correctAnswer = testRelation + .where(Literal(-6L) <= columnB).analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of long: 2L - b >= 8L") { + val query = testRelation + .where(Subtract(Literal(2L), columnB) >= Literal(8)) + + val correctAnswer = testRelation + .where(Literal(-6L) >= columnB).analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of long with overflow risk: b - 10L >= Long.MaxValue - 2") { + val query = testRelation + .where(Subtract(columnB, Literal(10L)) >= Literal(Long.MaxValue - 2)) + + val correctAnswer = testRelation + .where(Subtract(columnB, Literal(10L)) >= Literal(Long.MaxValue - 2)).analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of long with overflow risk: 10L - b >= Long.MinValue") { + val query = testRelation + .where(Subtract(Literal(10L), columnB) >= Literal(Long.MinValue)) + + val correctAnswer = testRelation + .where(Subtract(Literal(10L), columnB) >= Literal(Long.MinValue)).analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of long with overflow risk: bL + 10 <= Long.MinValue + 2") { + val query = testRelation + .where(Add(columnB, Literal(10)) <= Literal(Long.MinValue + 2)) + + val correctAnswer = testRelation + .where(Add(columnB, Literal(10L)) <= Literal(Long.MinValue + 2)).analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } +} From aaf2b8a4bb3f723b62cad6ff5009da601c74825e Mon Sep 17 00:00:00 2001 From: "wangguangxin.cn" Date: Mon, 4 Mar 2019 10:43:07 +0800 Subject: [PATCH 02/10] Rename --- .../sql/catalyst/optimizer/Optimizer.scala | 2 +- ...teArithmeticFiltersOnIntOrLongColumn.scala | 144 ++++++++++++++++++ 
.../sql/catalyst/optimizer/expressions.scala | 116 -------------- ...hmeticFiltersOnIntOrLongColumnSuite.scala} | 64 +++++--- 4 files changed, 187 insertions(+), 139 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntOrLongColumn.scala rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/{TransformBinaryComparisonSuite.scala => RewriteArithmeticFiltersOnIntOrLongColumnSuite.scala} (85%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 73939308cab92..610725ed31767 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -129,7 +129,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) RemoveNoopOperators, SimplifyExtractValueOps, CombineConcats, - TransformBinaryComparison) ++ + RewriteArithmeticFiltersOnIntOrLongColumn) ++ extendedOperatorOptimizationRules val operatorOptimizationBatch: Seq[Batch] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntOrLongColumn.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntOrLongColumn.scala new file mode 100644 index 0000000000000..bc08fc01c8e04 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntOrLongColumn.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.types.{DataType, IntegerType, LongType} + +/** + * Rewrite arithmetic filters on int or long column to its equivalent form, + * leaving attribute alone in one side, so that we can push it down to + * parquet or other file format. 
+ * For example, this rule can optimize + * {{{ + * SELECT * FROM table WHERE i + 3 = 5 + * }}} + * to + * {{{ + * SELECT * FROM table WHERE i = 5 - 3 + * }}} + * when i is Int or Long, and then other rules will further optimize it to + * {{{ + * SELECT * FROM table WHERE i = 2 + * }}} + */ +object RewriteArithmeticFiltersOnIntOrLongColumn extends Rule[LogicalPlan] with PredicateHelper { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case q: LogicalPlan => + q transformExpressionsUp { + case e @ BinaryComparison(left: BinaryArithmetic, right: Literal) + if isDataTypeSafe(left.dataType) => + transformLeft(e, left, right) + case e @ BinaryComparison(left: Literal, right: BinaryArithmetic) + if isDataTypeSafe(right.dataType) => + transformRight(e, left, right) + } + } + + private def transformLeft( + bc: BinaryComparison, + left: BinaryArithmetic, + right: Literal): Expression = { + left match { + case Add(ar: AttributeReference, lit: Literal) if isOptSafe(Subtract(right, lit)) => + bc.makeCopy(Array(ar, Subtract(right, lit))) + case Add(lit: Literal, ar: AttributeReference) if isOptSafe(Subtract(right, lit)) => + bc.makeCopy(Array(ar, Subtract(right, lit))) + case Subtract(ar: AttributeReference, lit: Literal) if isOptSafe(Add(right, lit)) => + bc.makeCopy(Array(ar, Add(right, lit))) + case Subtract(lit: Literal, ar: AttributeReference) if isOptSafe(Subtract(lit, right)) => + bc.makeCopy(Array(Subtract(lit, right), ar)) + case _ => bc + } + } + + private def transformRight( + bc: BinaryComparison, + left: Literal, + right: BinaryArithmetic): Expression = { + right match { + case Add(ar: AttributeReference, lit: Literal) if isOptSafe(Subtract(left, lit)) => + bc.makeCopy(Array(Subtract(left, lit), ar)) + case Add(lit: Literal, ar: AttributeReference) if isOptSafe(Subtract(left, lit)) => + bc.makeCopy(Array(Subtract(left, lit), ar)) + case Subtract(ar: AttributeReference, lit: Literal) if isOptSafe(Add(left, lit)) => + bc.makeCopy(Array(Add(left, 
lit), ar)) + case Subtract(lit: Literal, ar: AttributeReference) if isOptSafe(Subtract(lit, left)) => + bc.makeCopy(Array(ar, Subtract(lit, left))) + case _ => bc + } + } + + private def isDataTypeSafe(dataType: DataType): Boolean = dataType match { + case IntegerType | LongType => true + case _ => false + } + + private def isOptSafe(e: BinaryArithmetic): Boolean = { + val leftVal = e.left.eval(EmptyRow) + val rightVal = e.right.eval(EmptyRow) + + e match { + case Add(_: Literal, _: Literal) => + e.dataType match { + case IntegerType => + isAddSafe(leftVal, rightVal, Int.MinValue, Int.MaxValue) + case LongType => + isAddSafe(leftVal, rightVal, Long.MinValue, Long.MaxValue) + case _ => false + } + + case Subtract(_: Literal, _: Literal) => + e.dataType match { + case IntegerType => + isSubtractSafe(leftVal, rightVal, Int.MinValue, Int.MaxValue) + case LongType => + isSubtractSafe(leftVal, rightVal, Long.MinValue, Long.MaxValue) + case _ => false + } + + case _ => false + } + } + + private def isAddSafe[T](left: Any, right: Any, minValue: T, maxValue: T)( + implicit num: Numeric[T]): Boolean = { + import num._ + val leftVal = left.asInstanceOf[T] + val rightVal = right.asInstanceOf[T] + if (rightVal > zero) { + leftVal <= maxValue - rightVal + } else { + leftVal >= minValue - rightVal + } + } + + private def isSubtractSafe[T](left: Any, right: Any, minValue: T, maxValue: T)( + implicit num: Numeric[T]): Boolean = { + import num._ + val leftVal = left.asInstanceOf[T] + val rightVal = right.asInstanceOf[T] + if (rightVal > zero) { + leftVal >= minValue + rightVal + } else { + leftVal <= maxValue + rightVal + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 4243d491d45c3..39709529c00d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -735,119 +735,3 @@ object CombineConcats extends Rule[LogicalPlan] { flattenConcats(concat) } } - - -/** - * Transform binary comparison(such as =, >, <, >=, <=) in conditions to its equivalent form, - * leaving attributes alone in one side, so that we can push it down to parquet or others. - * For example, this rule can optimize - * {{{ - * SELECT * FROM table WHERE i + 3 = 5 - * }}} - * to - * {{{ - * SELECT * FROM table WHERE i = 5 - 3 - * }}} - * when i is Int or Long, and then other rules will further optimize it to - * {{{ - * SELECT * FROM table WHERE i = 2 - * }}} - */ -object TransformBinaryComparison extends Rule[LogicalPlan] with PredicateHelper { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case q: LogicalPlan => q transformExpressionsUp { - case e @ BinaryComparison(left: BinaryArithmetic, right: Literal) - if isDataTypeSafe(left.dataType) => - transformLeft(e, left, right) - case e @ BinaryComparison(left: Literal, right: BinaryArithmetic) - if isDataTypeSafe(right.dataType) => - transformRight(e, left, right) - } - } - - private def transformLeft(bc: BinaryComparison, left: BinaryArithmetic, right: Literal) - : Expression = { - left match { - case Add(ar: AttributeReference, lit: Literal) if isOptSafe(Subtract(right, lit)) => - bc.makeCopy(Array(ar, Subtract(right, lit))) - case Add(lit: Literal, ar: AttributeReference) if isOptSafe(Subtract(right, lit)) => - bc.makeCopy(Array(ar, Subtract(right, lit))) - case Subtract(ar: AttributeReference, lit: Literal) if isOptSafe(Add(right, lit)) => - bc.makeCopy(Array(ar, Add(right, lit))) - case Subtract(lit: Literal, ar: AttributeReference) if isOptSafe(Subtract(lit, right)) => - bc.makeCopy(Array(Subtract(lit, right), ar)) - case _ => bc - } - } - - private def transformRight(bc: BinaryComparison, left: Literal, right: BinaryArithmetic) - : Expression = { - right match { - case Add(ar: 
AttributeReference, lit: Literal) if isOptSafe(Subtract(left, lit)) => - bc.makeCopy(Array(Subtract(left, lit), ar)) - case Add(lit: Literal, ar: AttributeReference) if isOptSafe(Subtract(left, lit)) => - bc.makeCopy(Array(Subtract(left, lit), ar)) - case Subtract(ar: AttributeReference, lit: Literal) if isOptSafe(Add(left, lit)) => - bc.makeCopy(Array(Add(left, lit), ar)) - case Subtract(lit: Literal, ar: AttributeReference) if isOptSafe(Subtract(lit, left)) => - bc.makeCopy(Array(ar, Subtract(lit, left))) - case _ => bc - } - } - - private def isDataTypeSafe(dataType: DataType): Boolean = dataType match { - case IntegerType | LongType => true - case _ => false - } - - private def isOptSafe(e: BinaryArithmetic): Boolean = { - val leftVal = e.left.eval(EmptyRow) - val rightVal = e.right.eval(EmptyRow) - - e match { - case Add(_: Literal, _: Literal) => - e.dataType match { - case IntegerType => - isAddSafe(leftVal, rightVal, Int.MinValue, Int.MaxValue) - case LongType => - isAddSafe(leftVal, rightVal, Long.MinValue, Long.MaxValue) - case _ => false - } - - case Subtract(_: Literal, _: Literal) => - e.dataType match { - case IntegerType => - isSubtractSafe(leftVal, rightVal, Int.MinValue, Int.MaxValue) - case LongType => - isSubtractSafe(leftVal, rightVal, Long.MinValue, Long.MaxValue) - case _ => false - } - - case _ => false - } - } - - private def isAddSafe[T](left: Any, right: Any, minValue: T, maxValue: T) - (implicit num: Numeric[T]): Boolean = { - import num._ - val leftVal = left.asInstanceOf[T] - val rightVal = right.asInstanceOf[T] - if (rightVal > zero) { - leftVal <= maxValue - rightVal - } else { - leftVal >= minValue - rightVal - } - } - - private def isSubtractSafe[T](left: Any, right: Any, minValue: T, maxValue: T) - (implicit num: Numeric[T]): Boolean = { - import num._ - val leftVal = left.asInstanceOf[T] - val rightVal = right.asInstanceOf[T] - if (rightVal > zero) { - leftVal >= minValue + rightVal - } else { - leftVal <= maxValue + rightVal - } 
- } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransformBinaryComparisonSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntOrLongColumnSuite.scala similarity index 85% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransformBinaryComparisonSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntOrLongColumnSuite.scala index 913d80456669f..1b651afff6101 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TransformBinaryComparisonSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntOrLongColumnSuite.scala @@ -25,15 +25,17 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor /** - * Unit tests for transform binary comparision in expressions. + * Unit tests for rewrite arithmetic filters on int or long column optimizer. 
*/ -class TransformBinaryComparisonSuite extends PlanTest { +class RewriteArithmeticFiltersOnIntOrLongColumnSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = - Batch("TransformBinaryComparison", FixedPoint(10), + Batch( + "RewriteArithmeticFiltersOnIntOrLongColumn", + FixedPoint(10), ConstantFolding, - TransformBinaryComparison) :: Nil + RewriteArithmeticFiltersOnIntOrLongColumn) :: Nil } val testRelation = LocalRelation('a.int, 'b.long) @@ -46,7 +48,8 @@ class TransformBinaryComparisonSuite extends PlanTest { .where(Add(columnA, Literal(2)) === Literal(8)) val correctAnswer = testRelation - .where(columnA === Literal(6)).analyze + .where(columnA === Literal(6)) + .analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } @@ -56,7 +59,8 @@ class TransformBinaryComparisonSuite extends PlanTest { .where(Add(columnA, Literal(2)) >= Literal(8)) val correctAnswer = testRelation - .where(columnA >= Literal(6)).analyze + .where(columnA >= Literal(6)) + .analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } @@ -66,7 +70,8 @@ class TransformBinaryComparisonSuite extends PlanTest { .where(Add(columnA, Literal(2)) <= Literal(8)) val correctAnswer = testRelation - .where(columnA <= Literal(6)).analyze + .where(columnA <= Literal(6)) + .analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } @@ -76,7 +81,8 @@ class TransformBinaryComparisonSuite extends PlanTest { .where(Subtract(columnA, Literal(2)) <= Literal(8)) val correctAnswer = testRelation - .where(columnA <= Literal(10)).analyze + .where(columnA <= Literal(10)) + .analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } @@ -86,7 +92,8 @@ class TransformBinaryComparisonSuite extends PlanTest { .where(Subtract(Literal(2), columnA) <= Literal(8)) val correctAnswer = testRelation - .where(Literal(-6) <= columnA).analyze + .where(Literal(-6) <= columnA) + .analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } 
@@ -96,7 +103,8 @@ class TransformBinaryComparisonSuite extends PlanTest { .where(Subtract(Literal(2), columnA) >= Literal(8)) val correctAnswer = testRelation - .where(Literal(-6) >= columnA).analyze + .where(Literal(-6) >= columnA) + .analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } @@ -106,7 +114,8 @@ class TransformBinaryComparisonSuite extends PlanTest { .where(Subtract(columnA, Literal(10)) >= Literal(Int.MaxValue - 2)) val correctAnswer = testRelation - .where(Subtract(columnA, Literal(10)) >= Literal(Int.MaxValue - 2)).analyze + .where(Subtract(columnA, Literal(10)) >= Literal(Int.MaxValue - 2)) + .analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } @@ -116,7 +125,8 @@ class TransformBinaryComparisonSuite extends PlanTest { .where(Subtract(Literal(10), columnA) >= Literal(Int.MinValue)) val correctAnswer = testRelation - .where(Subtract(Literal(10), columnA) >= Literal(Int.MinValue)).analyze + .where(Subtract(Literal(10), columnA) >= Literal(Int.MinValue)) + .analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } @@ -126,7 +136,8 @@ class TransformBinaryComparisonSuite extends PlanTest { .where(Add(columnA, Literal(10)) <= Literal(Int.MinValue + 2)) val correctAnswer = testRelation - .where(Add(columnA, Literal(10)) <= Literal(Int.MinValue + 2)).analyze + .where(Add(columnA, Literal(10)) <= Literal(Int.MinValue + 2)) + .analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } @@ -136,7 +147,8 @@ class TransformBinaryComparisonSuite extends PlanTest { .where(Add(columnB, Literal(2L)) === Literal(8L)) val correctAnswer = testRelation - .where(columnB === Literal(6L)).analyze + .where(columnB === Literal(6L)) + .analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } @@ -146,7 +158,8 @@ class TransformBinaryComparisonSuite extends PlanTest { .where(Add(columnB, Literal(2L)) >= Literal(8L)) val correctAnswer = testRelation - .where(columnB >= Literal(6L)).analyze + 
.where(columnB >= Literal(6L)) + .analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } @@ -156,7 +169,8 @@ class TransformBinaryComparisonSuite extends PlanTest { .where(Add(columnB, Literal(2L)) <= Literal(8L)) val correctAnswer = testRelation - .where(columnB <= Literal(6L)).analyze + .where(columnB <= Literal(6L)) + .analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } @@ -166,7 +180,8 @@ class TransformBinaryComparisonSuite extends PlanTest { .where(Subtract(columnB, Literal(2L)) <= Literal(8L)) val correctAnswer = testRelation - .where(columnB <= Literal(10L)).analyze + .where(columnB <= Literal(10L)) + .analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } @@ -176,7 +191,8 @@ class TransformBinaryComparisonSuite extends PlanTest { .where(Subtract(Literal(2L), columnB) <= Literal(8)) val correctAnswer = testRelation - .where(Literal(-6L) <= columnB).analyze + .where(Literal(-6L) <= columnB) + .analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } @@ -186,7 +202,8 @@ class TransformBinaryComparisonSuite extends PlanTest { .where(Subtract(Literal(2L), columnB) >= Literal(8)) val correctAnswer = testRelation - .where(Literal(-6L) >= columnB).analyze + .where(Literal(-6L) >= columnB) + .analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } @@ -196,7 +213,8 @@ class TransformBinaryComparisonSuite extends PlanTest { .where(Subtract(columnB, Literal(10L)) >= Literal(Long.MaxValue - 2)) val correctAnswer = testRelation - .where(Subtract(columnB, Literal(10L)) >= Literal(Long.MaxValue - 2)).analyze + .where(Subtract(columnB, Literal(10L)) >= Literal(Long.MaxValue - 2)) + .analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } @@ -206,7 +224,8 @@ class TransformBinaryComparisonSuite extends PlanTest { .where(Subtract(Literal(10L), columnB) >= Literal(Long.MinValue)) val correctAnswer = testRelation - .where(Subtract(Literal(10L), columnB) >= 
Literal(Long.MinValue)).analyze + .where(Subtract(Literal(10L), columnB) >= Literal(Long.MinValue)) + .analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } @@ -216,7 +235,8 @@ class TransformBinaryComparisonSuite extends PlanTest { .where(Add(columnB, Literal(10)) <= Literal(Long.MinValue + 2)) val correctAnswer = testRelation - .where(Add(columnB, Literal(10L)) <= Literal(Long.MinValue + 2)).analyze + .where(Add(columnB, Literal(10L)) <= Literal(Long.MinValue + 2)) + .analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } From 82ff2a10335b5f1d06621eb386d33af517dd8640 Mon Sep 17 00:00:00 2001 From: "wangguangxin.cn" Date: Mon, 4 Mar 2019 19:05:59 +0800 Subject: [PATCH 03/10] Change the order in optimizer --- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 4 ++-- .../RewriteArithmeticFiltersOnIntOrLongColumnSuite.scala | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 610725ed31767..d2007c97cec43 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -111,6 +111,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) ConstantPropagation, FoldablePropagation, OptimizeIn, + RewriteArithmeticFiltersOnIntOrLongColumn, ConstantFolding, ReorderAssociativeOperator, LikeSimplification, @@ -128,8 +129,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) RemoveRedundantAliases, RemoveNoopOperators, SimplifyExtractValueOps, - CombineConcats, - RewriteArithmeticFiltersOnIntOrLongColumn) ++ + CombineConcats) ++ extendedOperatorOptimizationRules val operatorOptimizationBatch: Seq[Batch] = { diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntOrLongColumnSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntOrLongColumnSuite.scala index 1b651afff6101..2c7c829a3467e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntOrLongColumnSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntOrLongColumnSuite.scala @@ -34,8 +34,8 @@ class RewriteArithmeticFiltersOnIntOrLongColumnSuite extends PlanTest { Batch( "RewriteArithmeticFiltersOnIntOrLongColumn", FixedPoint(10), - ConstantFolding, - RewriteArithmeticFiltersOnIntOrLongColumn) :: Nil + RewriteArithmeticFiltersOnIntOrLongColumn, + ConstantFolding) :: Nil } val testRelation = LocalRelation('a.int, 'b.long) From 597d6d7a14c10fd94973c355747cdec346d4beb2 Mon Sep 17 00:00:00 2001 From: "wangguangxin.cn" Date: Tue, 5 Mar 2019 08:07:35 +0800 Subject: [PATCH 04/10] Fix ut failure --- .../src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 5 +++-- .../spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index a4dc537d31b7e..d401de14b1d14 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -344,8 +344,9 @@ class JDBCSuite extends QueryTest "WHERE (THEID > 0 AND TRIM(NAME) = 'mary') OR (NAME = 'fred')") assert(df2.collect.toSet === Set(Row("fred", 1), Row("mary", 2))) - assert(checkNotPushdown(sql("SELECT * FROM foobar WHERE (THEID + 1) < 2")).collect().size == 0) - assert(checkNotPushdown(sql("SELECT * FROM foobar WHERE (THEID + 2) != 4")).collect().size == 2) + // SPARK-27033: Add 
Optimize rule RewriteArithmeticFiltersOnIntOrLongColumn + assert(checkPushdown(sql("SELECT * FROM foobar WHERE (THEID + 1) < 2")).collect().size == 0) + assert(checkPushdown(sql("SELECT * FROM foobar WHERE (THEID + 2) != 4")).collect().size == 2) } test("SELECT COUNT(1) WHERE (predicates)") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala index 1e525c46a9cfb..adb06813f8092 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala @@ -65,11 +65,11 @@ class OptimizeHiveMetadataOnlyQuerySuite extends QueryTest with TestHiveSingleto // verify the matching partitions val partitions = spark.internalCreateDataFrame(Distinct(Filter(($"x" < 5).expr, - Project(Seq(($"part" + 1).as("x").expr.asInstanceOf[NamedExpression]), + Project(Seq(($"part" * 1).as("x").expr.asInstanceOf[NamedExpression]), spark.table("metadata_only").logicalPlan.asInstanceOf[SubqueryAlias].child))) .queryExecution.toRdd, StructType(Seq(StructField("x", IntegerType)))) - checkAnswer(partitions, Seq(1, 2, 3, 4).toDF("x")) + checkAnswer(partitions, Seq(0, 1, 2, 3, 4).toDF("x")) // verify that the partition predicate was not pushed down to the metastore assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount - startCount == 11) From 03c522e455938bcc7af2bf10fa0ad4b1ea785e70 Mon Sep 17 00:00:00 2001 From: "wangguangxin.cn" Date: Tue, 5 Mar 2019 18:07:12 +0800 Subject: [PATCH 05/10] Filter only --- .../RewriteArithmeticFiltersOnIntOrLongColumn.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntOrLongColumn.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntOrLongColumn.scala index bc08fc01c8e04..8d95f5f189a57 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntOrLongColumn.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntOrLongColumn.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types.{DataType, IntegerType, LongType} @@ -41,8 +41,8 @@ import org.apache.spark.sql.types.{DataType, IntegerType, LongType} */ object RewriteArithmeticFiltersOnIntOrLongColumn extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case q: LogicalPlan => - q transformExpressionsUp { + case f: Filter => + f transformExpressionsUp { case e @ BinaryComparison(left: BinaryArithmetic, right: Literal) if isDataTypeSafe(left.dataType) => transformLeft(e, left, right) From 0f61953314ec9d4d87a8fd33c01baa262103608d Mon Sep 17 00:00:00 2001 From: "wangguangxin.cn" Date: Wed, 6 Mar 2019 19:42:51 +0800 Subject: [PATCH 06/10] Update doc and add one more test case --- .../RewriteArithmeticFiltersOnIntOrLongColumn.scala | 3 +++ ...writeArithmeticFiltersOnIntOrLongColumnSuite.scala | 11 +++++++++++ 2 files changed, 14 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntOrLongColumn.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntOrLongColumn.scala index 8d95f5f189a57..d227bb0f6a6eb 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntOrLongColumn.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntOrLongColumn.scala @@ -38,6 +38,9 @@ import org.apache.spark.sql.types.{DataType, IntegerType, LongType} * {{{ * SELECT * FROM table WHERE i = 2 * }}} + * The arithmetic operation supports `Add` and `Subtract`. The comparision supports + * '=', '>=', '<=', '>', '<', '!='. It only supports type of `INT` and `LONG`, + * it doesn't support `FLOAT` or `DOUBLE` for precision issues. */ object RewriteArithmeticFiltersOnIntOrLongColumn extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntOrLongColumnSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntOrLongColumnSuite.scala index 2c7c829a3467e..f456a08af1ffb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntOrLongColumnSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntOrLongColumnSuite.scala @@ -240,4 +240,15 @@ class RewriteArithmeticFiltersOnIntOrLongColumnSuite extends PlanTest { comparePlans(Optimize.execute(query.analyze), correctAnswer) } + + test("test of int: 2 - a != 8") { + val query = testRelation + .where(Subtract(Literal(2), columnA) != Literal(8)) + + val correctAnswer = testRelation + .where(Literal(-6) != columnA) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } } From 5d0a5e8f688c8c35861181b76fb8b17ee63b6650 Mon Sep 17 00:00:00 2001 From: "wangguangxin.cn" Date: Wed, 6 Mar 2019 23:21:49 +0800 Subject: [PATCH 07/10] Change literal to foldable --- ...teArithmeticFiltersOnIntOrLongColumn.scala | 48 
+++++++++---------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntOrLongColumn.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntOrLongColumn.scala index d227bb0f6a6eb..f2473d013af48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntOrLongColumn.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntOrLongColumn.scala @@ -46,11 +46,11 @@ object RewriteArithmeticFiltersOnIntOrLongColumn extends Rule[LogicalPlan] with def apply(plan: LogicalPlan): LogicalPlan = plan transform { case f: Filter => f transformExpressionsUp { - case e @ BinaryComparison(left: BinaryArithmetic, right: Literal) - if isDataTypeSafe(left.dataType) => + case e @ BinaryComparison(left: BinaryArithmetic, right: Expression) + if right.foldable && isDataTypeSafe(left.dataType) => transformLeft(e, left, right) - case e @ BinaryComparison(left: Literal, right: BinaryArithmetic) - if isDataTypeSafe(right.dataType) => + case e @ BinaryComparison(left: Expression, right: BinaryArithmetic) + if left.foldable && isDataTypeSafe(right.dataType) => transformRight(e, left, right) } } @@ -58,33 +58,33 @@ object RewriteArithmeticFiltersOnIntOrLongColumn extends Rule[LogicalPlan] with private def transformLeft( bc: BinaryComparison, left: BinaryArithmetic, - right: Literal): Expression = { + right: Expression): Expression = { left match { - case Add(ar: AttributeReference, lit: Literal) if isOptSafe(Subtract(right, lit)) => - bc.makeCopy(Array(ar, Subtract(right, lit))) - case Add(lit: Literal, ar: AttributeReference) if isOptSafe(Subtract(right, lit)) => - bc.makeCopy(Array(ar, Subtract(right, lit))) - case Subtract(ar: AttributeReference, lit: Literal) if isOptSafe(Add(right, lit)) => - bc.makeCopy(Array(ar, Add(right, lit))) - 
case Subtract(lit: Literal, ar: AttributeReference) if isOptSafe(Subtract(lit, right)) => - bc.makeCopy(Array(Subtract(lit, right), ar)) + case Add(ar: AttributeReference, e) if e.foldable && isOptSafe(Subtract(right, e)) => + bc.makeCopy(Array(ar, Subtract(right, e))) + case Add(e, ar: AttributeReference) if e.foldable && isOptSafe(Subtract(right, e)) => + bc.makeCopy(Array(ar, Subtract(right, e))) + case Subtract(ar: AttributeReference, e) if e.foldable && isOptSafe(Add(right, e)) => + bc.makeCopy(Array(ar, Add(right, e))) + case Subtract(e, ar: AttributeReference) if e.foldable && isOptSafe(Subtract(e, right)) => + bc.makeCopy(Array(Subtract(e, right), ar)) case _ => bc } } private def transformRight( bc: BinaryComparison, - left: Literal, + left: Expression, right: BinaryArithmetic): Expression = { right match { - case Add(ar: AttributeReference, lit: Literal) if isOptSafe(Subtract(left, lit)) => - bc.makeCopy(Array(Subtract(left, lit), ar)) - case Add(lit: Literal, ar: AttributeReference) if isOptSafe(Subtract(left, lit)) => - bc.makeCopy(Array(Subtract(left, lit), ar)) - case Subtract(ar: AttributeReference, lit: Literal) if isOptSafe(Add(left, lit)) => - bc.makeCopy(Array(Add(left, lit), ar)) - case Subtract(lit: Literal, ar: AttributeReference) if isOptSafe(Subtract(lit, left)) => - bc.makeCopy(Array(ar, Subtract(lit, left))) + case Add(ar: AttributeReference, e) if e.foldable && isOptSafe(Subtract(left, e)) => + bc.makeCopy(Array(Subtract(left, e), ar)) + case Add(e, ar: AttributeReference) if e.foldable && isOptSafe(Subtract(left, e)) => + bc.makeCopy(Array(Subtract(left, e), ar)) + case Subtract(ar: AttributeReference, e) if e.foldable && isOptSafe(Add(left, e)) => + bc.makeCopy(Array(Add(left, e), ar)) + case Subtract(e, ar: AttributeReference) if e.foldable && isOptSafe(Subtract(e, left)) => + bc.makeCopy(Array(ar, Subtract(e, left))) case _ => bc } } @@ -99,7 +99,7 @@ object RewriteArithmeticFiltersOnIntOrLongColumn extends Rule[LogicalPlan] with val 
rightVal = e.right.eval(EmptyRow) e match { - case Add(_: Literal, _: Literal) => + case Add(_, _) => e.dataType match { case IntegerType => isAddSafe(leftVal, rightVal, Int.MinValue, Int.MaxValue) @@ -108,7 +108,7 @@ object RewriteArithmeticFiltersOnIntOrLongColumn extends Rule[LogicalPlan] with case _ => false } - case Subtract(_: Literal, _: Literal) => + case Subtract(_, _) => e.dataType match { case IntegerType => isSubtractSafe(leftVal, rightVal, Int.MinValue, Int.MaxValue) From 3927decc5182e2007d303457800cf54a8b1c69ed Mon Sep 17 00:00:00 2001 From: "wangguangxin.cn" Date: Fri, 8 Mar 2019 00:03:33 +0800 Subject: [PATCH 08/10] Add supports for ShortType and ByteType --- .../sql/catalyst/optimizer/Optimizer.scala | 2 +- ...teArithmeticFiltersOnIntegralColumn.scala} | 42 +- ...thmeticFiltersOnIntOrLongColumnSuite.scala | 254 --------- ...ithmeticFiltersOnIntegralColumnSuite.scala | 487 ++++++++++++++++++ .../OptimizeHiveMetadataOnlyQuerySuite.scala | 11 +- 5 files changed, 521 insertions(+), 275 deletions(-) rename sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/{RewriteArithmeticFiltersOnIntOrLongColumn.scala => RewriteArithmeticFiltersOnIntegralColumn.scala} (77%) delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntOrLongColumnSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntegralColumnSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d2007c97cec43..65c3064531b35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -111,7 +111,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) ConstantPropagation, 
FoldablePropagation, OptimizeIn, - RewriteArithmeticFiltersOnIntOrLongColumn, + RewriteArithmeticFiltersOnIntegralColumn, ConstantFolding, ReorderAssociativeOperator, LikeSimplification, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntOrLongColumn.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntegralColumn.scala similarity index 77% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntOrLongColumn.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntegralColumn.scala index f2473d013af48..f575d5d01e544 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntOrLongColumn.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntegralColumn.scala @@ -20,29 +20,31 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.types.{DataType, IntegerType, LongType} +import org.apache.spark.sql.types._ /** - * Rewrite arithmetic filters on int or long column to its equivalent form, - * leaving attribute alone in one side, so that we can push it down to - * parquet or other file format. - * For example, this rule can optimize + * Rewrite arithmetic filters on an integral-type (e.g., byte, short, int and long) + * column to its equivalent form, leaving attribute alone in a left side, so that + * we can push it down to datasources (e.g., Parquet and ORC). 
+ * + * For example, this rule can optimize a query as follows: * {{{ * SELECT * FROM table WHERE i + 3 = 5 + * ==> SELECT * FROM table WHERE i = 5 - 3 * }}} - * to - * {{{ - * SELECT * FROM table WHERE i = 5 - 3 - * }}} - * when i is Int or Long, and then other rules will further optimize it to + * + * Then, the [[ConstantFolding]] rule will further optimize it as follows: * {{{ * SELECT * FROM table WHERE i = 2 * }}} - * The arithmetic operation supports `Add` and `Subtract`. The comparision supports - * '=', '>=', '<=', '>', '<', '!='. It only supports type of `INT` and `LONG`, - * it doesn't support `FLOAT` or `DOUBLE` for precision issues. - */ -object RewriteArithmeticFiltersOnIntOrLongColumn extends Rule[LogicalPlan] with PredicateHelper { + * + * Note: + * 1. This rule supports `Add` and `Subtract` in arithmetic expressions. + * 2. This rule supports `=`, `>=`, `<=`, `>`, `<`, and `!=` in comparators. + * 3. This rule supports integral-type (`byte`, `short`, `int`, `long`) only. + * It doesn't support `float` or `double` because of precision issues. 
+ */ +object RewriteArithmeticFiltersOnIntegralColumn extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case f: Filter => f transformExpressionsUp { @@ -90,7 +92,7 @@ object RewriteArithmeticFiltersOnIntOrLongColumn extends Rule[LogicalPlan] with } private def isDataTypeSafe(dataType: DataType): Boolean = dataType match { - case IntegerType | LongType => true + case ByteType | ShortType | IntegerType | LongType => true case _ => false } @@ -101,6 +103,10 @@ object RewriteArithmeticFiltersOnIntOrLongColumn extends Rule[LogicalPlan] with e match { case Add(_, _) => e.dataType match { + case ByteType => + isAddSafe(leftVal, rightVal, Byte.MinValue, Byte.MaxValue) + case ShortType => + isAddSafe(leftVal, rightVal, Short.MinValue, Short.MaxValue) case IntegerType => isAddSafe(leftVal, rightVal, Int.MinValue, Int.MaxValue) case LongType => @@ -110,6 +116,10 @@ object RewriteArithmeticFiltersOnIntOrLongColumn extends Rule[LogicalPlan] with case Subtract(_, _) => e.dataType match { + case ByteType => + isSubtractSafe(leftVal, rightVal, Byte.MinValue, Byte.MaxValue) + case ShortType => + isSubtractSafe(leftVal, rightVal, Short.MinValue, Short.MaxValue) case IntegerType => isSubtractSafe(leftVal, rightVal, Int.MinValue, Int.MaxValue) case LongType => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntOrLongColumnSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntOrLongColumnSuite.scala deleted file mode 100644 index f456a08af1ffb..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntOrLongColumnSuite.scala +++ /dev/null @@ -1,254 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.optimizer - -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} -import org.apache.spark.sql.catalyst.rules.RuleExecutor - -/** - * Unit tests for rewrite arithmetic filters on int or long column optimizer. 
- */ -class RewriteArithmeticFiltersOnIntOrLongColumnSuite extends PlanTest { - - object Optimize extends RuleExecutor[LogicalPlan] { - val batches = - Batch( - "RewriteArithmeticFiltersOnIntOrLongColumn", - FixedPoint(10), - RewriteArithmeticFiltersOnIntOrLongColumn, - ConstantFolding) :: Nil - } - - val testRelation = LocalRelation('a.int, 'b.long) - - private val columnA = 'a - private val columnB = 'b - - test("test of int: a + 2 = 8") { - val query = testRelation - .where(Add(columnA, Literal(2)) === Literal(8)) - - val correctAnswer = testRelation - .where(columnA === Literal(6)) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - - test("test of int: a + 2 >= 8") { - val query = testRelation - .where(Add(columnA, Literal(2)) >= Literal(8)) - - val correctAnswer = testRelation - .where(columnA >= Literal(6)) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - - test("test of int: a + 2 <= 8") { - val query = testRelation - .where(Add(columnA, Literal(2)) <= Literal(8)) - - val correctAnswer = testRelation - .where(columnA <= Literal(6)) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - - test("test of int: a - 2 <= 8") { - val query = testRelation - .where(Subtract(columnA, Literal(2)) <= Literal(8)) - - val correctAnswer = testRelation - .where(columnA <= Literal(10)) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - - test("test of int: 2 - a <= 8") { - val query = testRelation - .where(Subtract(Literal(2), columnA) <= Literal(8)) - - val correctAnswer = testRelation - .where(Literal(-6) <= columnA) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - - test("test of int: 2 - a >= 8") { - val query = testRelation - .where(Subtract(Literal(2), columnA) >= Literal(8)) - - val correctAnswer = testRelation - .where(Literal(-6) >= columnA) - .analyze - - comparePlans(Optimize.execute(query.analyze), 
correctAnswer) - } - - test("test of int with overflow risk: a - 10 >= Int.MaxValue - 2") { - val query = testRelation - .where(Subtract(columnA, Literal(10)) >= Literal(Int.MaxValue - 2)) - - val correctAnswer = testRelation - .where(Subtract(columnA, Literal(10)) >= Literal(Int.MaxValue - 2)) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - - test("test of int with overflow risk: 10 - a >= Int.MinValue") { - val query = testRelation - .where(Subtract(Literal(10), columnA) >= Literal(Int.MinValue)) - - val correctAnswer = testRelation - .where(Subtract(Literal(10), columnA) >= Literal(Int.MinValue)) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - - test("test of int with overflow risk: a + 10 <= Int.MinValue + 2") { - val query = testRelation - .where(Add(columnA, Literal(10)) <= Literal(Int.MinValue + 2)) - - val correctAnswer = testRelation - .where(Add(columnA, Literal(10)) <= Literal(Int.MinValue + 2)) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - - test("test of long: b + 2L = 8L") { - val query = testRelation - .where(Add(columnB, Literal(2L)) === Literal(8L)) - - val correctAnswer = testRelation - .where(columnB === Literal(6L)) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - - test("test of long: b + 2L >= 8L") { - val query = testRelation - .where(Add(columnB, Literal(2L)) >= Literal(8L)) - - val correctAnswer = testRelation - .where(columnB >= Literal(6L)) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - - test("test of long: b + 2L <= 8L") { - val query = testRelation - .where(Add(columnB, Literal(2L)) <= Literal(8L)) - - val correctAnswer = testRelation - .where(columnB <= Literal(6L)) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - - test("test of long: b - 2L <= 8L") { - val query = testRelation - .where(Subtract(columnB, Literal(2L)) <= Literal(8L)) 
- - val correctAnswer = testRelation - .where(columnB <= Literal(10L)) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - - test("test of long: 2L - b <= 8L") { - val query = testRelation - .where(Subtract(Literal(2L), columnB) <= Literal(8)) - - val correctAnswer = testRelation - .where(Literal(-6L) <= columnB) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - - test("test of long: 2L - b >= 8L") { - val query = testRelation - .where(Subtract(Literal(2L), columnB) >= Literal(8)) - - val correctAnswer = testRelation - .where(Literal(-6L) >= columnB) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - - test("test of long with overflow risk: b - 10L >= Long.MaxValue - 2") { - val query = testRelation - .where(Subtract(columnB, Literal(10L)) >= Literal(Long.MaxValue - 2)) - - val correctAnswer = testRelation - .where(Subtract(columnB, Literal(10L)) >= Literal(Long.MaxValue - 2)) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - - test("test of long with overflow risk: 10L - b >= Long.MinValue") { - val query = testRelation - .where(Subtract(Literal(10L), columnB) >= Literal(Long.MinValue)) - - val correctAnswer = testRelation - .where(Subtract(Literal(10L), columnB) >= Literal(Long.MinValue)) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - - test("test of long with overflow risk: bL + 10 <= Long.MinValue + 2") { - val query = testRelation - .where(Add(columnB, Literal(10)) <= Literal(Long.MinValue + 2)) - - val correctAnswer = testRelation - .where(Add(columnB, Literal(10L)) <= Literal(Long.MinValue + 2)) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - - test("test of int: 2 - a != 8") { - val query = testRelation - .where(Subtract(Literal(2), columnA) != Literal(8)) - - val correctAnswer = testRelation - .where(Literal(-6) != columnA) - .analyze - - 
comparePlans(Optimize.execute(query.analyze), correctAnswer) - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntegralColumnSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntegralColumnSuite.scala new file mode 100644 index 0000000000000..9c44c6c1d82ef --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntegralColumnSuite.scala @@ -0,0 +1,487 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor + +/** + * Unit tests for rewrite arithmetic filters on integral column optimizer. 
+ */ +class RewriteArithmeticFiltersOnIntegralColumnSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch( + "RewriteArithmeticFiltersOnIntegralColumn", + FixedPoint(10), + RewriteArithmeticFiltersOnIntegralColumn, + ConstantFolding) :: Nil + } + + val testRelation = LocalRelation('a.int, 'b.long, 'c.byte, 'd.short) + + private val columnA = 'a + private val columnB = 'b + private val columnC = 'c + private val columnD = 'd + + test("test of int: a + 2 = 8") { + val query = testRelation + .where(Add(columnA, Literal(2)) === Literal(8)) + + val correctAnswer = testRelation + .where(columnA === Literal(6)) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of int: a + 2 >= 8") { + val query = testRelation + .where(Add(columnA, Literal(2)) >= Literal(8)) + + val correctAnswer = testRelation + .where(columnA >= Literal(6)) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of int: a + 2 <= 8") { + val query = testRelation + .where(Add(columnA, Literal(2)) <= Literal(8)) + + val correctAnswer = testRelation + .where(columnA <= Literal(6)) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of int: a - 2 <= 8") { + val query = testRelation + .where(Subtract(columnA, Literal(2)) <= Literal(8)) + + val correctAnswer = testRelation + .where(columnA <= Literal(10)) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of int: 2 - a <= 8") { + val query = testRelation + .where(Subtract(Literal(2), columnA) <= Literal(8)) + + val correctAnswer = testRelation + .where(Literal(-6) <= columnA) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of int: 2 - a >= 8") { + val query = testRelation + .where(Subtract(Literal(2), columnA) >= Literal(8)) + + val correctAnswer = testRelation + .where(Literal(-6) >= columnA) + 
.analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of int: 2 - a =!= 8") { + val query = testRelation + .where(Subtract(Literal(2), columnA) =!= Literal(8)) + + val correctAnswer = testRelation + .where(Literal(-6) =!= columnA) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of int with overflow risk: a - 10 >= Int.MaxValue - 2") { + val query = testRelation + .where(Subtract(columnA, Literal(10)) >= Literal(Int.MaxValue - 2)) + + val correctAnswer = testRelation + .where(Subtract(columnA, Literal(10)) >= Literal(Int.MaxValue - 2)) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of int with overflow risk: 10 - a >= Int.MinValue") { + val query = testRelation + .where(Subtract(Literal(10), columnA) >= Literal(Int.MinValue)) + + val correctAnswer = testRelation + .where(Subtract(Literal(10), columnA) >= Literal(Int.MinValue)) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of int with overflow risk: a + 10 <= Int.MinValue + 2") { + val query = testRelation + .where(Add(columnA, Literal(10)) <= Literal(Int.MinValue + 2)) + + val correctAnswer = testRelation + .where(Add(columnA, Literal(10)) <= Literal(Int.MinValue + 2)) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of long: b + 2L = 8L") { + val query = testRelation + .where(Add(columnB, Literal(2L)) === Literal(8L)) + + val correctAnswer = testRelation + .where(columnB === Literal(6L)) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of long: b + 2L >= 8L") { + val query = testRelation + .where(Add(columnB, Literal(2L)) >= Literal(8L)) + + val correctAnswer = testRelation + .where(columnB >= Literal(6L)) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of long: b + 2L <= 8L") { + val query = 
testRelation + .where(Add(columnB, Literal(2L)) <= Literal(8L)) + + val correctAnswer = testRelation + .where(columnB <= Literal(6L)) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of long: b - 2L <= 8L") { + val query = testRelation + .where(Subtract(columnB, Literal(2L)) <= Literal(8L)) + + val correctAnswer = testRelation + .where(columnB <= Literal(10L)) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of long: 2L - b <= 8L") { + val query = testRelation + .where(Subtract(Literal(2L), columnB) <= Literal(8)) + + val correctAnswer = testRelation + .where(Literal(-6L) <= columnB) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of long: 2L - b >= 8L") { + val query = testRelation + .where(Subtract(Literal(2L), columnB) >= Literal(8)) + + val correctAnswer = testRelation + .where(Literal(-6L) >= columnB) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of long: 2L - b =!= 8L") { + val query = testRelation + .where(Subtract(Literal(2L), columnB) =!= Literal(8L)) + + val correctAnswer = testRelation + .where(Literal(-6L) =!= columnB) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of long with overflow risk: b - 10L >= Long.MaxValue - 2") { + val query = testRelation + .where(Subtract(columnB, Literal(10L)) >= Literal(Long.MaxValue - 2)) + + val correctAnswer = testRelation + .where(Subtract(columnB, Literal(10L)) >= Literal(Long.MaxValue - 2)) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of long with overflow risk: 10L - b >= Long.MinValue") { + val query = testRelation + .where(Subtract(Literal(10L), columnB) >= Literal(Long.MinValue)) + + val correctAnswer = testRelation + .where(Subtract(Literal(10L), columnB) >= Literal(Long.MinValue)) + .analyze + + 
comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of long with overflow risk: bL + 10 <= Long.MinValue + 2") { + val query = testRelation + .where(Add(columnB, Literal(10L)) <= Literal(Long.MinValue + 2)) + + val correctAnswer = testRelation + .where(Add(columnB, Literal(10L)) <= Literal(Long.MinValue + 2)) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of byte: c + 2 = 8") { + val query = testRelation + .where(Add(columnC, Literal(2.toByte)) === Literal(8.toByte)) + + val correctAnswer = testRelation + .where(columnC === Literal(6.toByte)) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of byte: c + 2 >= 8") { + val query = testRelation + .where(Add(columnC, Literal(2.toByte)) >= Literal(8.toByte)) + + val correctAnswer = testRelation + .where(columnC >= Literal(6.toByte)) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of byte: c + 2 <= 8") { + val query = testRelation + .where(Add(columnC, Literal(2.toByte)) <= Literal(8.toByte)) + + val correctAnswer = testRelation + .where(columnC <= Literal(6.toByte)) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of byte: c - 2 <= 8") { + val query = testRelation + .where(Subtract(columnC, Literal(2.toByte)) <= Literal(8.toByte)) + + val correctAnswer = testRelation + .where(columnC <= Literal(10.toByte)) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of byte: 2 - c <= 8") { + val query = testRelation + .where(Subtract(Literal(2.toByte), columnC) <= Literal(8.toByte)) + + val correctAnswer = testRelation + .where(Literal(-6.toByte) <= columnC) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of byte: 2 - c >= 8") { + val query = testRelation + .where(Subtract(Literal(2.toByte), columnC) >= Literal(8.toByte)) + + val 
correctAnswer = testRelation + .where(Literal(-6.toByte) >= columnC) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of byte: 2 - c =!= 8") { + val query = testRelation + .where(Subtract(Literal(2.toByte), columnC) =!= Literal(8.toByte)) + + val correctAnswer = testRelation + .where(Literal(-6.toByte) =!= columnC) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of byte with overflow risk: c - 10 >= Byte.MaxValue - 2") { + val query = testRelation + .where(Subtract(columnC, Literal(10.toByte)) >= Literal(Byte.MaxValue - 2.toByte)) + + val correctAnswer = testRelation + .where(Subtract(columnC, Literal(10.toByte)) >= Literal(Byte.MaxValue - 2.toByte)) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of byte with overflow risk: 10 - c >= Byte.MinValue") { + val query = testRelation + .where(Subtract(Literal(10.toByte), columnC) >= Literal(Byte.MinValue)) + + val correctAnswer = testRelation + .where(Subtract(Literal(10.toByte), columnC) >= Literal(Byte.MinValue)) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of byte with overflow risk: c + 10 <= Byte.MinValue + 2") { + val query = testRelation + .where(Add(columnC, Literal(10.toByte)) <= Literal(Byte.MinValue + 2.toByte)) + + val correctAnswer = testRelation + .where(Add(columnC, Literal(10.toByte)) <= Literal(Byte.MinValue + 2.toByte)) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of short: d + 2 = 8") { + val query = testRelation + .where(Add(columnD, Literal(2.toShort)) === Literal(8.toShort)) + + val correctAnswer = testRelation + .where(columnD === Literal(6.toShort)) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of short: d + 2 >= 8") { + val query = testRelation + .where(Add(columnD, Literal(2.toShort)) >= Literal(8.toShort)) + + 
val correctAnswer = testRelation + .where(columnD >= Literal(6.toShort)) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of short: d + 2 <= 8") { + val query = testRelation + .where(Add(columnD, Literal(2.toShort)) <= Literal(8.toShort)) + + val correctAnswer = testRelation + .where(columnD <= Literal(6.toShort)) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of short: d - 2 <= 8") { + val query = testRelation + .where(Subtract(columnD, Literal(2.toShort)) <= Literal(8.toShort)) + + val correctAnswer = testRelation + .where(columnD <= Literal(10.toShort)) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of short: 2 - d <= 8") { + val query = testRelation + .where(Subtract(Literal(2.toShort), columnD) <= Literal(8.toShort)) + + val correctAnswer = testRelation + .where(Literal(-6.toShort) <= columnD) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of short: 2 - d >= 8") { + val query = testRelation + .where(Subtract(Literal(2.toShort), columnD) >= Literal(8.toShort)) + + val correctAnswer = testRelation + .where(Literal(-6.toShort) >= columnD) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of short: 2 - d =!= 8") { + val query = testRelation + .where(Subtract(Literal(2.toShort), columnD) =!= Literal(8.toShort)) + + val correctAnswer = testRelation + .where(Literal(-6.toShort) =!= columnD) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of short with overflow risk: d - 10 >= Short.MaxValue - 2") { + val query = testRelation + .where(Subtract(columnD, Literal(10.toShort)) >= Literal(Short.MaxValue - 2.toShort)) + + val correctAnswer = testRelation + .where(Subtract(columnD, Literal(10.toShort)) >= Literal(Short.MaxValue - 2.toShort)) + .analyze + + 
comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of short with overflow risk: 10 - d >= Short.MinValue") { + val query = testRelation + .where(Subtract(Literal(10.toShort), columnD) >= Literal(Short.MinValue)) + + val correctAnswer = testRelation + .where(Subtract(Literal(10.toShort), columnD) >= Literal(Short.MinValue)) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } + + test("test of short with overflow risk: d + 10 <= Short.MinValue + 2") { + val query = testRelation + .where(Add(columnD, Literal(10.toShort)) <= Literal(Short.MinValue + 2.toShort)) + + val correctAnswer = testRelation + .where(Add(columnD, Literal(10.toShort)) <= Literal(Short.MinValue + 2.toShort)) + .analyze + + comparePlans(Optimize.execute(query.analyze), correctAnswer) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala index adb06813f8092..198041cafea89 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.expressions.NamedExpression import org.apache.spark.sql.catalyst.plans.logical.{Distinct, Filter, Project, SubqueryAlias} import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.internal.SQLConf.OPTIMIZER_METADATA_ONLY +import org.apache.spark.sql.internal.SQLConf.{OPTIMIZER_EXCLUDED_RULES, OPTIMIZER_METADATA_ONLY} import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.{IntegerType, StructField, StructType} @@ -60,16 +60,19 @@ class OptimizeHiveMetadataOnlyQuerySuite extends QueryTest with TestHiveSingleto } test("SPARK-23877: filter on projected expression") { - 
withSQLConf(OPTIMIZER_METADATA_ONLY.key -> "true") { + // exclude `RewriteArithmeticFiltersOnIntegralColumn` here because + // it will optimize part + 1 < 5 to part < 4 and then pushed to metastore + withSQLConf(OPTIMIZER_METADATA_ONLY.key -> "true", + OPTIMIZER_EXCLUDED_RULES.key -> "RewriteArithmeticFiltersOnIntegralColumn") { val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount // verify the matching partitions val partitions = spark.internalCreateDataFrame(Distinct(Filter(($"x" < 5).expr, - Project(Seq(($"part" * 1).as("x").expr.asInstanceOf[NamedExpression]), + Project(Seq(($"part" + 1).as("x").expr.asInstanceOf[NamedExpression]), spark.table("metadata_only").logicalPlan.asInstanceOf[SubqueryAlias].child))) .queryExecution.toRdd, StructType(Seq(StructField("x", IntegerType)))) - checkAnswer(partitions, Seq(0, 1, 2, 3, 4).toDF("x")) + checkAnswer(partitions, Seq(1, 2, 3, 4).toDF("x")) // verify that the partition predicate was not pushed down to the metastore assert(HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount - startCount == 11) From 2c0277748f2d20ee7fc3bf0a85279b4c72cd24e4 Mon Sep 17 00:00:00 2001 From: "wangguangxin.cn" Date: Fri, 8 Mar 2019 15:45:55 +0800 Subject: [PATCH 09/10] use spark.sql.optimizer.excludedRules --- .../spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala index 198041cafea89..e1a444910687b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala @@ -22,6 +22,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.QueryTest import 
org.apache.spark.sql.catalyst.expressions.NamedExpression +import org.apache.spark.sql.catalyst.optimizer.RewriteArithmeticFiltersOnIntegralColumn import org.apache.spark.sql.catalyst.plans.logical.{Distinct, Filter, Project, SubqueryAlias} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf.{OPTIMIZER_EXCLUDED_RULES, OPTIMIZER_METADATA_ONLY} @@ -63,7 +64,7 @@ class OptimizeHiveMetadataOnlyQuerySuite extends QueryTest with TestHiveSingleto // exclude `RewriteArithmeticFiltersOnIntegralColumn` here because // it will optimize part + 1 < 5 to part < 4 and then pushed to metastore withSQLConf(OPTIMIZER_METADATA_ONLY.key -> "true", - OPTIMIZER_EXCLUDED_RULES.key -> "RewriteArithmeticFiltersOnIntegralColumn") { + OPTIMIZER_EXCLUDED_RULES.key -> RewriteArithmeticFiltersOnIntegralColumn.ruleName) { val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount // verify the matching partitions From 43892acab8a3508cbb55ce2012ff5a6487fab181 Mon Sep 17 00:00:00 2001 From: "wangguangxin.cn" Date: Thu, 14 Mar 2019 23:20:25 +0800 Subject: [PATCH 10/10] only rewrite EqualTo --- ...iteArithmeticFiltersOnIntegralColumn.scala | 11 +- ...ithmeticFiltersOnIntegralColumnSuite.scala | 292 +++--------------- .../org/apache/spark/sql/jdbc/JDBCSuite.scala | 2 +- .../OptimizeHiveMetadataOnlyQuerySuite.scala | 8 +- 4 files changed, 45 insertions(+), 268 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntegralColumn.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntegralColumn.scala index f575d5d01e544..8cd3fa46829fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntegralColumn.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntegralColumn.scala @@ -40,7 +40,8 @@ import 
org.apache.spark.sql.types._ * * Note: * 1. This rule supports `Add` and `Subtract` in arithmetic expressions. - * 2. This rule supports `=`, `>=`, `<=`, `>`, `<`, and `!=` in comparators. + * 2. This rule supports `=` and `!=` in comparators. For `>`, `>=`, `<`, `<=`, + * it may bring inconsistencies after the rewrite. * 3. This rule supports integral-type (`byte`, `short`, `int`, `long`) only. * It doesn't support `float` or `double` because of precision issues. */ @@ -48,17 +49,17 @@ object RewriteArithmeticFiltersOnIntegralColumn extends Rule[LogicalPlan] with P def apply(plan: LogicalPlan): LogicalPlan = plan transform { case f: Filter => f transformExpressionsUp { - case e @ BinaryComparison(left: BinaryArithmetic, right: Expression) + case e @ EqualTo(left: BinaryArithmetic, right: Expression) if right.foldable && isDataTypeSafe(left.dataType) => transformLeft(e, left, right) - case e @ BinaryComparison(left: Expression, right: BinaryArithmetic) + case e @ EqualTo(left: Expression, right: BinaryArithmetic) if left.foldable && isDataTypeSafe(right.dataType) => transformRight(e, left, right) } } private def transformLeft( - bc: BinaryComparison, + bc: EqualTo, left: BinaryArithmetic, right: Expression): Expression = { left match { @@ -75,7 +76,7 @@ object RewriteArithmeticFiltersOnIntegralColumn extends Rule[LogicalPlan] with P } private def transformRight( - bc: BinaryComparison, + bc: EqualTo, left: Expression, right: BinaryArithmetic): Expression = { right match { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntegralColumnSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntegralColumnSuite.scala index 9c44c6c1d82ef..dfef8990b522c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntegralColumnSuite.scala +++
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteArithmeticFiltersOnIntegralColumnSuite.scala @@ -56,61 +56,6 @@ class RewriteArithmeticFiltersOnIntegralColumnSuite extends PlanTest { comparePlans(Optimize.execute(query.analyze), correctAnswer) } - test("test of int: a + 2 >= 8") { - val query = testRelation - .where(Add(columnA, Literal(2)) >= Literal(8)) - - val correctAnswer = testRelation - .where(columnA >= Literal(6)) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - - test("test of int: a + 2 <= 8") { - val query = testRelation - .where(Add(columnA, Literal(2)) <= Literal(8)) - - val correctAnswer = testRelation - .where(columnA <= Literal(6)) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - - test("test of int: a - 2 <= 8") { - val query = testRelation - .where(Subtract(columnA, Literal(2)) <= Literal(8)) - - val correctAnswer = testRelation - .where(columnA <= Literal(10)) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - - test("test of int: 2 - a <= 8") { - val query = testRelation - .where(Subtract(Literal(2), columnA) <= Literal(8)) - - val correctAnswer = testRelation - .where(Literal(-6) <= columnA) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - - test("test of int: 2 - a >= 8") { - val query = testRelation - .where(Subtract(Literal(2), columnA) >= Literal(8)) - - val correctAnswer = testRelation - .where(Literal(-6) >= columnA) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - test("test of int: 2 - a =!= 8") { val query = testRelation .where(Subtract(Literal(2), columnA) =!= Literal(8)) @@ -122,34 +67,34 @@ class RewriteArithmeticFiltersOnIntegralColumnSuite extends PlanTest { comparePlans(Optimize.execute(query.analyze), correctAnswer) } - test("test of int with overflow risk: a - 10 >= Int.MaxValue - 2") { + test("test of int with overflow risk: a - 10 = 
Int.MaxValue - 2") { val query = testRelation - .where(Subtract(columnA, Literal(10)) >= Literal(Int.MaxValue - 2)) + .where(Subtract(columnA, Literal(10)) === Literal(Int.MaxValue - 2)) val correctAnswer = testRelation - .where(Subtract(columnA, Literal(10)) >= Literal(Int.MaxValue - 2)) + .where(Subtract(columnA, Literal(10)) === Literal(Int.MaxValue - 2)) .analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } - test("test of int with overflow risk: 10 - a >= Int.MinValue") { + test("test of int with overflow risk: 10 - a = Int.MinValue") { val query = testRelation - .where(Subtract(Literal(10), columnA) >= Literal(Int.MinValue)) + .where(Subtract(Literal(10), columnA) === Literal(Int.MinValue)) val correctAnswer = testRelation - .where(Subtract(Literal(10), columnA) >= Literal(Int.MinValue)) + .where(Subtract(Literal(10), columnA) === Literal(Int.MinValue)) .analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } - test("test of int with overflow risk: a + 10 <= Int.MinValue + 2") { + test("test of int with overflow risk: a + 10 = Int.MinValue + 2") { val query = testRelation - .where(Add(columnA, Literal(10)) <= Literal(Int.MinValue + 2)) + .where(Add(columnA, Literal(10)) === Literal(Int.MinValue + 2)) val correctAnswer = testRelation - .where(Add(columnA, Literal(10)) <= Literal(Int.MinValue + 2)) + .where(Add(columnA, Literal(10)) === Literal(Int.MinValue + 2)) .analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) @@ -166,61 +111,6 @@ class RewriteArithmeticFiltersOnIntegralColumnSuite extends PlanTest { comparePlans(Optimize.execute(query.analyze), correctAnswer) } - test("test of long: b + 2L >= 8L") { - val query = testRelation - .where(Add(columnB, Literal(2L)) >= Literal(8L)) - - val correctAnswer = testRelation - .where(columnB >= Literal(6L)) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - - test("test of long: b + 2L <= 8L") { - val query = testRelation - 
.where(Add(columnB, Literal(2L)) <= Literal(8L)) - - val correctAnswer = testRelation - .where(columnB <= Literal(6L)) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - - test("test of long: b - 2L <= 8L") { - val query = testRelation - .where(Subtract(columnB, Literal(2L)) <= Literal(8L)) - - val correctAnswer = testRelation - .where(columnB <= Literal(10L)) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - - test("test of long: 2L - b <= 8L") { - val query = testRelation - .where(Subtract(Literal(2L), columnB) <= Literal(8)) - - val correctAnswer = testRelation - .where(Literal(-6L) <= columnB) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - - test("test of long: 2L - b >= 8L") { - val query = testRelation - .where(Subtract(Literal(2L), columnB) >= Literal(8)) - - val correctAnswer = testRelation - .where(Literal(-6L) >= columnB) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - test("test of long: 2L - b =!= 8L") { val query = testRelation .where(Subtract(Literal(2L), columnB) =!= Literal(8L)) @@ -232,34 +122,34 @@ class RewriteArithmeticFiltersOnIntegralColumnSuite extends PlanTest { comparePlans(Optimize.execute(query.analyze), correctAnswer) } - test("test of long with overflow risk: b - 10L >= Long.MaxValue - 2") { + test("test of long with overflow risk: b - 10L = Long.MaxValue - 2") { val query = testRelation - .where(Subtract(columnB, Literal(10L)) >= Literal(Long.MaxValue - 2)) + .where(Subtract(columnB, Literal(10L)) === Literal(Long.MaxValue - 2)) val correctAnswer = testRelation - .where(Subtract(columnB, Literal(10L)) >= Literal(Long.MaxValue - 2)) + .where(Subtract(columnB, Literal(10L)) === Literal(Long.MaxValue - 2)) .analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } - test("test of long with overflow risk: 10L - b >= Long.MinValue") { + test("test of long with overflow risk: 10L - b = 
Long.MinValue") { val query = testRelation - .where(Subtract(Literal(10L), columnB) >= Literal(Long.MinValue)) + .where(Subtract(Literal(10L), columnB) === Literal(Long.MinValue)) val correctAnswer = testRelation - .where(Subtract(Literal(10L), columnB) >= Literal(Long.MinValue)) + .where(Subtract(Literal(10L), columnB) === Literal(Long.MinValue)) .analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } - test("test of long with overflow risk: bL + 10 <= Long.MinValue + 2") { + test("test of long with overflow risk: bL + 10 = Long.MinValue + 2") { val query = testRelation - .where(Add(columnB, Literal(10L)) <= Literal(Long.MinValue + 2)) + .where(Add(columnB, Literal(10L)) === Literal(Long.MinValue + 2)) val correctAnswer = testRelation - .where(Add(columnB, Literal(10L)) <= Literal(Long.MinValue + 2)) + .where(Add(columnB, Literal(10L)) === Literal(Long.MinValue + 2)) .analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) @@ -276,61 +166,6 @@ class RewriteArithmeticFiltersOnIntegralColumnSuite extends PlanTest { comparePlans(Optimize.execute(query.analyze), correctAnswer) } - test("test of byte: c + 2 >= 8") { - val query = testRelation - .where(Add(columnC, Literal(2.toByte)) >= Literal(8.toByte)) - - val correctAnswer = testRelation - .where(columnC >= Literal(6.toByte)) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - - test("test of byte: c + 2 <= 8") { - val query = testRelation - .where(Add(columnC, Literal(2.toByte)) <= Literal(8.toByte)) - - val correctAnswer = testRelation - .where(columnC <= Literal(6.toByte)) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - - test("test of byte: c - 2 <= 8") { - val query = testRelation - .where(Subtract(columnC, Literal(2.toByte)) <= Literal(8.toByte)) - - val correctAnswer = testRelation - .where(columnC <= Literal(10.toByte)) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - - test("test of 
byte: 2 - c <= 8") { - val query = testRelation - .where(Subtract(Literal(2.toByte), columnC) <= Literal(8.toByte)) - - val correctAnswer = testRelation - .where(Literal(-6.toByte) <= columnC) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - - test("test of byte: 2 - c >= 8") { - val query = testRelation - .where(Subtract(Literal(2.toByte), columnC) >= Literal(8.toByte)) - - val correctAnswer = testRelation - .where(Literal(-6.toByte) >= columnC) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - test("test of byte: 2 - c =!= 8") { val query = testRelation .where(Subtract(Literal(2.toByte), columnC) =!= Literal(8.toByte)) @@ -342,34 +177,34 @@ class RewriteArithmeticFiltersOnIntegralColumnSuite extends PlanTest { comparePlans(Optimize.execute(query.analyze), correctAnswer) } - test("test of byte with overflow risk: c - 10 >= Byte.MaxValue - 2") { + test("test of byte with overflow risk: c - 10 = Byte.MaxValue - 2") { val query = testRelation - .where(Subtract(columnC, Literal(10.toByte)) >= Literal(Byte.MaxValue - 2.toByte)) + .where(Subtract(columnC, Literal(10.toByte)) === Literal(Byte.MaxValue - 2.toByte)) val correctAnswer = testRelation - .where(Subtract(columnC, Literal(10.toByte)) >= Literal(Byte.MaxValue - 2.toByte)) + .where(Subtract(columnC, Literal(10.toByte)) === Literal(Byte.MaxValue - 2.toByte)) .analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } - test("test of byte with overflow risk: 10 - c >= Byte.MinValue") { + test("test of byte with overflow risk: 10 - c = Byte.MinValue") { val query = testRelation - .where(Subtract(Literal(10.toByte), columnC) >= Literal(Byte.MinValue)) + .where(Subtract(Literal(10.toByte), columnC) === Literal(Byte.MinValue)) val correctAnswer = testRelation - .where(Subtract(Literal(10.toByte), columnC) >= Literal(Byte.MinValue)) + .where(Subtract(Literal(10.toByte), columnC) === Literal(Byte.MinValue)) .analyze 
comparePlans(Optimize.execute(query.analyze), correctAnswer) } - test("test of byte with overflow risk: c + 10 <= Byte.MinValue + 2") { + test("test of byte with overflow risk: c + 10 = Byte.MinValue + 2") { val query = testRelation - .where(Add(columnC, Literal(10.toByte)) <= Literal(Byte.MinValue + 2.toByte)) + .where(Add(columnC, Literal(10.toByte)) === Literal(Byte.MinValue + 2.toByte)) val correctAnswer = testRelation - .where(Add(columnC, Literal(10.toByte)) <= Literal(Byte.MinValue + 2.toByte)) + .where(Add(columnC, Literal(10.toByte)) === Literal(Byte.MinValue + 2.toByte)) .analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) @@ -386,61 +221,6 @@ class RewriteArithmeticFiltersOnIntegralColumnSuite extends PlanTest { comparePlans(Optimize.execute(query.analyze), correctAnswer) } - test("test of short: d + 2 >= 8") { - val query = testRelation - .where(Add(columnD, Literal(2.toShort)) >= Literal(8.toShort)) - - val correctAnswer = testRelation - .where(columnD >= Literal(6.toShort)) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - - test("test of short: d + 2 <= 8") { - val query = testRelation - .where(Add(columnD, Literal(2.toShort)) <= Literal(8.toShort)) - - val correctAnswer = testRelation - .where(columnD <= Literal(6.toShort)) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - - test("test of short: d - 2 <= 8") { - val query = testRelation - .where(Subtract(columnD, Literal(2.toShort)) <= Literal(8.toShort)) - - val correctAnswer = testRelation - .where(columnD <= Literal(10.toShort)) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - - test("test of short: 2 - d <= 8") { - val query = testRelation - .where(Subtract(Literal(2.toShort), columnD) <= Literal(8.toShort)) - - val correctAnswer = testRelation - .where(Literal(-6.toShort) <= columnD) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - - test("test of 
short: 2 - d >= 8") { - val query = testRelation - .where(Subtract(Literal(2.toShort), columnD) >= Literal(8.toShort)) - - val correctAnswer = testRelation - .where(Literal(-6.toShort) >= columnD) - .analyze - - comparePlans(Optimize.execute(query.analyze), correctAnswer) - } - test("test of short: 2 - d =!= 8") { val query = testRelation .where(Subtract(Literal(2.toShort), columnD) =!= Literal(8.toShort)) @@ -452,34 +232,34 @@ class RewriteArithmeticFiltersOnIntegralColumnSuite extends PlanTest { comparePlans(Optimize.execute(query.analyze), correctAnswer) } - test("test of short with overflow risk: d - 10 >= Short.MaxValue - 2") { + test("test of short with overflow risk: d - 10 = Short.MaxValue - 2") { val query = testRelation - .where(Subtract(columnD, Literal(10.toShort)) >= Literal(Short.MaxValue - 2.toShort)) + .where(Subtract(columnD, Literal(10.toShort)) === Literal(Short.MaxValue - 2.toShort)) val correctAnswer = testRelation - .where(Subtract(columnD, Literal(10.toShort)) >= Literal(Short.MaxValue - 2.toShort)) + .where(Subtract(columnD, Literal(10.toShort)) === Literal(Short.MaxValue - 2.toShort)) .analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } - test("test of short with overflow risk: 10 - d >= Short.MinValue") { + test("test of short with overflow risk: 10 - d = Short.MinValue") { val query = testRelation - .where(Subtract(Literal(10.toShort), columnD) >= Literal(Short.MinValue)) + .where(Subtract(Literal(10.toShort), columnD) === Literal(Short.MinValue)) val correctAnswer = testRelation - .where(Subtract(Literal(10.toShort), columnD) >= Literal(Short.MinValue)) + .where(Subtract(Literal(10.toShort), columnD) === Literal(Short.MinValue)) .analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) } - test("test of short with overflow risk: d + 10 <= Short.MinValue + 2") { + test("test of short with overflow risk: d + 10 = Short.MinValue + 2") { val query = testRelation - .where(Add(columnD, Literal(10.toShort)) <= 
Literal(Short.MinValue + 2.toShort)) + .where(Add(columnD, Literal(10.toShort)) === Literal(Short.MinValue + 2.toShort)) val correctAnswer = testRelation - .where(Add(columnD, Literal(10.toShort)) <= Literal(Short.MinValue + 2.toShort)) + .where(Add(columnD, Literal(10.toShort)) === Literal(Short.MinValue + 2.toShort)) .analyze comparePlans(Optimize.execute(query.analyze), correctAnswer) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index d401de14b1d14..7899eaaa6450a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -344,8 +344,8 @@ class JDBCSuite extends QueryTest "WHERE (THEID > 0 AND TRIM(NAME) = 'mary') OR (NAME = 'fred')") assert(df2.collect.toSet === Set(Row("fred", 1), Row("mary", 2))) + assert(checkNotPushdown(sql("SELECT * FROM foobar WHERE (THEID + 1) < 2")).collect().size == 0) // SPARK-27033: Add Optimize rule RewriteArithmeticFiltersOnIntOrLongColumn - assert(checkPushdown(sql("SELECT * FROM foobar WHERE (THEID + 1) < 2")).collect().size == 0) assert(checkPushdown(sql("SELECT * FROM foobar WHERE (THEID + 2) != 4")).collect().size == 2) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala index e1a444910687b..1e525c46a9cfb 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/OptimizeHiveMetadataOnlyQuerySuite.scala @@ -22,10 +22,9 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.QueryTest import org.apache.spark.sql.catalyst.expressions.NamedExpression -import 
org.apache.spark.sql.catalyst.optimizer.RewriteArithmeticFiltersOnIntegralColumn import org.apache.spark.sql.catalyst.plans.logical.{Distinct, Filter, Project, SubqueryAlias} import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.internal.SQLConf.{OPTIMIZER_EXCLUDED_RULES, OPTIMIZER_METADATA_ONLY} +import org.apache.spark.sql.internal.SQLConf.OPTIMIZER_METADATA_ONLY import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.{IntegerType, StructField, StructType} @@ -61,10 +60,7 @@ class OptimizeHiveMetadataOnlyQuerySuite extends QueryTest with TestHiveSingleto } test("SPARK-23877: filter on projected expression") { - // exclude `RewriteArithmeticFiltersOnIntegralColumn` here because - // it will optimize part + 1 < 5 to part < 4 and then pushed to metastore - withSQLConf(OPTIMIZER_METADATA_ONLY.key -> "true", - OPTIMIZER_EXCLUDED_RULES.key -> RewriteArithmeticFiltersOnIntegralColumn.ruleName) { + withSQLConf(OPTIMIZER_METADATA_ONLY.key -> "true") { val startCount = HiveCatalogMetrics.METRIC_PARTITIONS_FETCHED.getCount // verify the matching partitions