Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,21 @@ trait HiveTypeCoercion {
BooleanCasts ::
StringToIntegralCasts ::
FunctionArgumentConversion ::
CastNulls ::
CaseWhenCoercion ::
Division ::
Nil

/**
 * Mixin that gives coercion rules a shared way to look up a common type for two data types.
 */
trait TypeWidening {
  /**
   * Looks for a promotion sequence in `HiveTypeCoercion.allPromotions` that contains both
   * `t1` and `t2` and, if one exists, returns whichever of the two appears last in it.
   *
   * NOTE(review): this presumes each promotion sequence is ordered from narrowest to widest,
   * so that the last match is the widest common type -- confirm against `allPromotions`.
   *
   * @param t1 the first type to reconcile
   * @param t2 the second type to reconcile
   * @return the common type, or `None` when no promotion sequence contains both types
   */
  def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
    val sharedPromotion = HiveTypeCoercion.allPromotions.find { promotion =>
      promotion.contains(t1) && promotion.contains(t2)
    }

    sharedPromotion.map { promotion =>
      // Both types are known to be present here, so the filtered sequence is never empty.
      val candidates = promotion.filter(candidate => candidate == t1 || candidate == t2)
      candidates.last
    }
  }
}

/**
* Applies any changes to [[AttributeReference]] data types that are made by other rules to
* instances higher in the query tree.
Expand Down Expand Up @@ -133,16 +144,7 @@ trait HiveTypeCoercion {
* - LongType to FloatType
* - LongType to DoubleType
*/
object WidenTypes extends Rule[LogicalPlan] {

def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
// Try and find a promotion rule that contains both types in question.
val applicableConversion =
HiveTypeCoercion.allPromotions.find(p => p.contains(t1) && p.contains(t2))

// If found return the widest common type, otherwise None
applicableConversion.map(_.filter(t => t == t1 || t == t2).last)
}
object WidenTypes extends Rule[LogicalPlan] with TypeWidening {

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
Expand Down Expand Up @@ -336,28 +338,34 @@ trait HiveTypeCoercion {
}

/**
* Ensures that NullType gets casted to some other types under certain circumstances.
* Coerces the type of different branches of a CASE WHEN statement to a common type.
*/
object CastNulls extends Rule[LogicalPlan] {
object CaseWhenCoercion extends Rule[LogicalPlan] with TypeWidening {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case cw @ CaseWhen(branches) =>
case cw @ CaseWhen(branches) if !cw.resolved && !branches.exists(!_.resolved) =>
val valueTypes = branches.sliding(2, 2).map {
case Seq(_, value) if value.resolved => Some(value.dataType)
case Seq(elseVal) if elseVal.resolved => Some(elseVal.dataType)
case _ => None
case Seq(_, value) => value.dataType
case Seq(elseVal) => elseVal.dataType
}.toSeq
if (valueTypes.distinct.size == 2 && valueTypes.exists(_ == Some(NullType))) {
val otherType = valueTypes.filterNot(_ == Some(NullType))(0).get

logDebug(s"Input values for null casting ${valueTypes.mkString(",")}")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want to leave this in?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's essentially free and could be useful to turn on if we ever have problems with this rule again.


if (valueTypes.distinct.size > 1) {
val commonType = valueTypes.reduce { (v1, v2) =>
findTightestCommonType(v1, v2)
.getOrElse(sys.error(
s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2"))
}
val transformedBranches = branches.sliding(2, 2).map {
case Seq(cond, value) if value.resolved && value.dataType == NullType =>
Seq(cond, Cast(value, otherType))
case Seq(elseVal) if elseVal.resolved && elseVal.dataType == NullType =>
Seq(Cast(elseVal, otherType))
case Seq(cond, value) if value.dataType != commonType =>
Seq(cond, Cast(value, commonType))
case Seq(elseVal) if elseVal.dataType != commonType =>
Seq(Cast(elseVal, commonType))
case s => s
}.reduce(_ ++ _)
CaseWhen(transformedBranches)
} else {
// It is possible to have more types due to the possibility of short-circuiting.
// Types match up. Hopefully some other rule fixes whatever is wrong with resolution.
cw
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
1
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ case class TestData(a: Int, b: String)
*/
class HiveQuerySuite extends HiveComparisonTest {

createQueryTest("null case",
"SELECT case when(true) then 1 else null end FROM src LIMIT 1")

createQueryTest("single case",
"""SELECT case when true then 1 else 2 end FROM src LIMIT 1""")

Expand Down