From 647a23ffdba3d4a9c84d0d37e44d5ce43671d63e Mon Sep 17 00:00:00 2001 From: Yijie Shen Date: Thu, 13 Aug 2015 19:53:58 +0800 Subject: [PATCH] constant folding of IntegralType on binaryComparison --- .../sql/catalyst/optimizer/Optimizer.scala | 96 +++++++++++++++++++ .../optimizer/ConstantFoldingSuite.scala | 34 +++++++ 2 files changed, 130 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 4ab5ac2c61e3c..f69503e059c42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -382,6 +382,102 @@ object ConstantFolding extends Rule[LogicalPlan] { case Literal(candidate, _) if candidate == v => true case _ => false } => Literal.create(true, BooleanType) + + case EqualTo(c @ Cast(a: Attribute, _), Literal(v, _)) if isUpCastingIntegral(c) && + (v.asInstanceOf[Number].longValue < minValue(a.dataType) || + v.asInstanceOf[Number].longValue > maxValue(a.dataType)) => + Literal.create(false, BooleanType) + case EqualTo(Literal(v, _), c @ Cast(a: Attribute, _)) if isUpCastingIntegral(c) && + (v.asInstanceOf[Number].longValue < minValue(a.dataType) || + v.asInstanceOf[Number].longValue > maxValue(a.dataType)) => + Literal.create(false, BooleanType) + + case EqualNullSafe(c @ Cast(a: Attribute, _), Literal(v, _)) if isUpCastingIntegral(c) && + (v.asInstanceOf[Number].longValue < minValue(a.dataType) || + v.asInstanceOf[Number].longValue > maxValue(a.dataType)) => + Literal.create(false, BooleanType) + case EqualNullSafe(Literal(v, _), c @ Cast(a: Attribute, _)) if isUpCastingIntegral(c) && + (v.asInstanceOf[Number].longValue < minValue(a.dataType) || + v.asInstanceOf[Number].longValue > maxValue(a.dataType)) => + Literal.create(false, BooleanType) + + case GreaterThan(c @ Cast(a: Attribute, _), Literal(v, _)) if isUpCastingIntegral(c) && + v.asInstanceOf[Number].longValue < minValue(a.dataType) => + Literal.create(true, BooleanType) + case GreaterThan(c @ Cast(a: Attribute, _), Literal(v, _)) if isUpCastingIntegral(c) && + v.asInstanceOf[Number].longValue >= maxValue(a.dataType) => + Literal.create(false, BooleanType) + case GreaterThan(Literal(v, _), c @ Cast(a: Attribute, _)) if isUpCastingIntegral(c) && + v.asInstanceOf[Number].longValue <= minValue(a.dataType) => + Literal.create(false, BooleanType) + case GreaterThan(Literal(v, _), c @ Cast(a: Attribute, _)) if isUpCastingIntegral(c) && + v.asInstanceOf[Number].longValue > maxValue(a.dataType) => + Literal.create(true, BooleanType) + + case LessThan(c @ Cast(a: Attribute, _), Literal(v, _)) if isUpCastingIntegral(c) && + v.asInstanceOf[Number].longValue <= minValue(a.dataType) => + Literal.create(false, BooleanType) + case LessThan(c @ Cast(a: Attribute, _), Literal(v, _)) if isUpCastingIntegral(c) && + v.asInstanceOf[Number].longValue > maxValue(a.dataType) => + Literal.create(true, BooleanType) + case LessThan(Literal(v, _), c @ Cast(a: Attribute, _)) if isUpCastingIntegral(c) && + v.asInstanceOf[Number].longValue < minValue(a.dataType) => + Literal.create(true, BooleanType) + case LessThan(Literal(v, _), c @ Cast(a: Attribute, _)) if isUpCastingIntegral(c) && + v.asInstanceOf[Number].longValue >= maxValue(a.dataType) => + Literal.create(false, BooleanType) + + case GreaterThanOrEqual(c @ Cast(a: Attribute, _), Literal(v, _)) + if isUpCastingIntegral(c) && v.asInstanceOf[Number].longValue <= minValue(a.dataType) => + Literal.create(true, BooleanType) + case GreaterThanOrEqual(c @ Cast(a: Attribute, _), Literal(v, _)) + if isUpCastingIntegral(c) && v.asInstanceOf[Number].longValue > maxValue(a.dataType) => + Literal.create(false, BooleanType) + case GreaterThanOrEqual(Literal(v, _), c @ Cast(a: Attribute, _)) + if isUpCastingIntegral(c) && v.asInstanceOf[Number].longValue < minValue(a.dataType) => + Literal.create(false, BooleanType) + case GreaterThanOrEqual(Literal(v, _), c @ Cast(a: Attribute, _)) + if isUpCastingIntegral(c) && v.asInstanceOf[Number].longValue >= maxValue(a.dataType) => + Literal.create(true, BooleanType) + + case LessThanOrEqual(c @ Cast(a: Attribute, _), Literal(v, _)) if isUpCastingIntegral(c) && + v.asInstanceOf[Number].longValue < minValue(a.dataType) => + Literal.create(false, BooleanType) + case LessThanOrEqual(c @ Cast(a: Attribute, _), Literal(v, _)) if isUpCastingIntegral(c) && + v.asInstanceOf[Number].longValue >= maxValue(a.dataType) => + Literal.create(true, BooleanType) + case LessThanOrEqual(Literal(v, _), c @ Cast(a: Attribute, _)) if isUpCastingIntegral(c) && + v.asInstanceOf[Number].longValue <= minValue(a.dataType) => + Literal.create(true, BooleanType) + case LessThanOrEqual(Literal(v, _), c @ Cast(a: Attribute, _)) if isUpCastingIntegral(c) && + v.asInstanceOf[Number].longValue > maxValue(a.dataType) => + Literal.create(false, BooleanType) + } + } + + private val integralPrecedence = Seq(ByteType, ShortType, IntegerType, LongType) + + private def isUpCastingIntegral(c: Cast): Boolean = { + (c.child.dataType, c.dataType) match { + case (from: IntegralType, to: IntegralType) + if integralPrecedence.indexOf(from) < integralPrecedence.indexOf(to) => true + case _ => false + } + } + + private def maxValue(dataType: DataType): Long = { + dataType match { + case ByteType => Byte.MaxValue.toLong + case ShortType => Short.MaxValue.toLong + case IntegerType => Int.MaxValue.toLong + } + } + + private def minValue(dataType: DataType): Long = { + dataType match { + case ByteType => Byte.MinValue.toLong + case ShortType => Short.MinValue.toLong + case IntegerType => Int.MinValue.toLong } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index ec3b2f1edfa05..ef4f31d84a621 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -280,4 +280,38 @@ class ConstantFoldingSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("binary comparison folding") { + val trueQuery = testRelation.select(Literal(true).as("r")) + val falseQuery = testRelation.select(Literal(false).as("r")) + def checkComparisonFolding(l: LogicalPlan, expected: Boolean): Unit = { + val optimized = Optimize.execute(l.analyze) + if (expected) { + comparePlans(optimized, trueQuery) + } else { + comparePlans(optimized, falseQuery) + } + } + + checkComparisonFolding( + testRelation.select(EqualTo('a, Int.MaxValue.toLong + 1L).as("r")), false) + checkComparisonFolding( + testRelation.select(EqualTo('a, Int.MinValue.toLong - 1L).as("r")), false) + checkComparisonFolding( + testRelation.select(LessThan('a, Int.MaxValue.toLong + 1L).as("r")), true) + checkComparisonFolding( + testRelation.select(LessThan('a, Int.MinValue.toLong - 1L).as("r")), false) + checkComparisonFolding( + testRelation.select(GreaterThan('a, Int.MaxValue.toLong + 1L).as("r")), false) + checkComparisonFolding( + testRelation.select(GreaterThan('a, Int.MinValue.toLong - 1L).as("r")), true) + checkComparisonFolding( + testRelation.select(LessThanOrEqual('a, Int.MaxValue.toLong).as("r")), true) + checkComparisonFolding( + testRelation.select(LessThanOrEqual('a, Int.MinValue.toLong - 1L).as("r")), false) + checkComparisonFolding( + testRelation.select(GreaterThanOrEqual('a, Int.MaxValue.toLong + 1L).as("r")), false) + checkComparisonFolding( + testRelation.select(GreaterThanOrEqual('a, Int.MinValue.toLong).as("r")), true) + } }