Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -72,23 +72,22 @@ object TypeCoercion {
FloatType,
DoubleType)

/**
* Case 1 type widening (see the classdoc comment above for TypeCoercion).
*
* Find the tightest common type of two types that might be used in a binary expression.
* This handles all numeric types except fixed-precision decimals interacting with each other or
* with primitive types, because in that case the precision and scale of the result depends on
* the operation. Those rules are implemented in [[DecimalPrecision]].
*/
val findTightestCommonTypeOfTwo: (DataType, DataType) => Option[DataType] = {
def findWiderType(left: DataType, right: DataType): Option[DataType] = (left, right) match {
case (t1, t2) if t1 == t2 => Some(t1)

case (NullType, t1) => Some(t1)
case (t1, NullType) => Some(t1)

case (t1: IntegralType, t2: DecimalType) if t2.isWiderThan(t1) =>
Some(t2)
case (t1: DecimalType, t2: IntegralType) if t1.isWiderThan(t2) =>
Some(t1)
case (t1: DecimalType, t2: DecimalType) =>
Some(DecimalPrecision.widerDecimalType(t1, t2))

case (t: IntegralType, d: DecimalType) =>
Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
case (d: DecimalType, t: IntegralType) =>
Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))

case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: FractionalType) =>
Some(DoubleType)

// Promote numeric types to the highest of the two
case (t1: NumericType, t2: NumericType)
Expand All @@ -99,70 +98,24 @@ object TypeCoercion {
case _ => None
}

/**
 * Like [[findTightestCommonTypeOfTwo]], but with a fallback that can promote to StringType:
 * when no tightest common type is found and one side is StringType while the other side is an
 * atomic type other than BinaryType or BooleanType, the result is StringType.
 */
def findTightestCommonTypeToString(left: DataType, right: DataType): Option[DataType] = {
  // Fallback, evaluated only when the tightest-common-type lookup yields None
  // (the argument of `orElse` is by-name).
  def promoteToString: Option[DataType] = (left, right) match {
    case (StringType, other: AtomicType) if other != BinaryType && other != BooleanType =>
      Some(StringType)
    case (other: AtomicType, StringType) if other != BinaryType && other != BooleanType =>
      Some(StringType)
    case _ =>
      None
  }

  findTightestCommonTypeOfTwo(left, right).orElse(promoteToString)
}

/**
* Find the tightest common type of a set of types by continuously applying
* `findTightestCommonTypeOfTwo` on these types.
*/
private def findTightestCommonType(types: Seq[DataType]): Option[DataType] = {
def findWiderType(types: Seq[DataType]): Option[DataType] = {
types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
case Some(d) => findWiderType(d, c)
case None => None
case Some(d) => findTightestCommonTypeOfTwo(d, c)
})
}

/**
 * Case 2 type widening (see the classdoc comment above for TypeCoercion).
 *
 * Unlike [[findTightestCommonTypeOfTwo]], this allows some loss of precision when widening
 * decimal and double. Any pair that involves no decimal type is delegated to
 * [[findTightestCommonTypeToString]].
 */
private def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = {
  (t1, t2) match {
    // Two decimals: combine them via DecimalPrecision.widerDecimalType.
    case (lhs: DecimalType, rhs: DecimalType) =>
      Some(DecimalPrecision.widerDecimalType(lhs, rhs))

    // Integral with decimal (either order): convert the integral side with
    // DecimalType.forType, then widen the two decimals.
    case (integral: IntegralType, decimal: DecimalType) =>
      Some(DecimalPrecision.widerDecimalType(DecimalType.forType(integral), decimal))
    case (decimal: DecimalType, integral: IntegralType) =>
      Some(DecimalPrecision.widerDecimalType(DecimalType.forType(integral), decimal))

    // Remaining fractional/decimal combinations widen to double.
    // This must stay below the cases above, which take precedence.
    case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: FractionalType) =>
      Some(DoubleType)

    // No decimal involved: defer to the string-promoting tightest common type.
    case _ =>
      findTightestCommonTypeToString(t1, t2)
  }
}

private def findWiderCommonType(types: Seq[DataType]) = {
types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
case Some(d) => findWiderTypeForTwo(d, c)
case None => None
val widerTypeToString: (DataType, DataType) => Option[DataType] = (left, right) => {
findWiderType(left, right).orElse((left, right) match {
case (StringType, t2: AtomicType) if t2 != BinaryType && t2 != BooleanType => Some(StringType)
case (t1: AtomicType, StringType) if t1 != BinaryType && t1 != BooleanType => Some(StringType)
case _ => None
})
}

/**
* Similar to [[findWiderCommonType]], but can't promote to string. This is also similar to
* [[findTightestCommonType]], but can handle decimal types. If the wider decimal type exceeds
* system limitation, this rule will truncate the decimal type before returning it.
*/
private def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = {
def widerTypeToString(types: Seq[DataType]): Option[DataType] = {
types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
case Some(d) => findTightestCommonTypeOfTwo(d, c).orElse((d, c) match {
case (t1: DecimalType, t2: DecimalType) =>
Some(DecimalPrecision.widerDecimalType(t1, t2))
case (t: IntegralType, d: DecimalType) =>
Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
case (d: DecimalType, t: IntegralType) =>
Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d))
case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: FractionalType) =>
Some(DoubleType)
case _ => None
})
case Some(d) => widerTypeToString(d, c)
case None => None
})
}
Expand Down Expand Up @@ -272,7 +225,7 @@ object TypeCoercion {
if (attrIndex >= children.head.output.length) return castedTypes.toSeq

// For the attrIndex-th attribute, find the widest type
findWiderCommonType(children.map(_.output(attrIndex).dataType)) match {
widerTypeToString(children.map(_.output(attrIndex).dataType)) match {
// If unable to find an appropriate widen type for this column, return an empty Seq
case None => Seq.empty[DataType]
// Otherwise, record the result in the queue and find the type for the next column
Expand Down Expand Up @@ -374,7 +327,7 @@ object TypeCoercion {
case e if !e.childrenResolved => e

case i @ In(a, b) if b.exists(_.dataType != a.dataType) =>
findWiderCommonType(i.children.map(_.dataType)) match {
widerTypeToString(i.children.map(_.dataType)) match {
case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType)))
case None => i
}
Expand Down Expand Up @@ -450,7 +403,7 @@ object TypeCoercion {

case a @ CreateArray(children) if !haveSameType(children) =>
val types = children.map(_.dataType)
findWiderCommonType(types) match {
widerTypeToString(types) match {
case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType)))
case None => a
}
Expand All @@ -461,7 +414,7 @@ object TypeCoercion {
m.keys
} else {
val types = m.keys.map(_.dataType)
findWiderCommonType(types) match {
widerTypeToString(types) match {
case Some(finalDataType) => m.keys.map(Cast(_, finalDataType))
case None => m.keys
}
Expand All @@ -471,7 +424,7 @@ object TypeCoercion {
m.values
} else {
val types = m.values.map(_.dataType)
findWiderCommonType(types) match {
widerTypeToString(types) match {
case Some(finalDataType) => m.values.map(Cast(_, finalDataType))
case None => m.values
}
Expand Down Expand Up @@ -499,7 +452,7 @@ object TypeCoercion {
// compatible with every child column.
case c @ Coalesce(es) if !haveSameType(es) =>
val types = es.map(_.dataType)
findWiderCommonType(types) match {
widerTypeToString(types) match {
case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
case None => c
}
Expand All @@ -509,14 +462,14 @@ object TypeCoercion {
// string.
case g @ Greatest(children) if !haveSameType(children) =>
val types = children.map(_.dataType)
findWiderTypeWithoutStringPromotion(types) match {
findWiderType(types) match {
case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType)))
case None => g
}

case l @ Least(children) if !haveSameType(children) =>
val types = children.map(_.dataType)
findWiderTypeWithoutStringPromotion(types) match {
findWiderType(types) match {
case Some(finalDataType) => Least(children.map(Cast(_, finalDataType)))
case None => l
}
Expand Down Expand Up @@ -556,7 +509,7 @@ object TypeCoercion {
object CaseWhenCoercion extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual =>
val maybeCommonType = findWiderCommonType(c.valueTypes)
val maybeCommonType = widerTypeToString(c.valueTypes)
maybeCommonType.map { commonType =>
var changed = false
val newBranches = c.branches.map { case (condition, value) =>
Expand Down Expand Up @@ -588,7 +541,7 @@ object TypeCoercion {
case e if !e.childrenResolved => e
// Find tightest common type for If, if the true value and false value have different types.
case i @ If(pred, left, right) if left.dataType != right.dataType =>
findWiderTypeForTwo(left.dataType, right.dataType).map { widestType =>
widerTypeToString(left.dataType, right.dataType).map { widestType =>
val newLeft = if (left.dataType == widestType) left else Cast(left, widestType)
val newRight = if (right.dataType == widestType) right else Cast(right, widestType)
If(pred, newLeft, newRight)
Expand Down Expand Up @@ -630,7 +583,7 @@ object TypeCoercion {
case e if !e.childrenResolved => e

case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { commonType =>
findWiderType(left.dataType, right.dataType).map { commonType =>
if (b.inputType.acceptsType(commonType)) {
// If the expression accepts the tightest common type, cast to that.
val newLeft = if (left.dataType == commonType) left else Cast(left, commonType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ case class IfNull(left: Expression, right: Expression) extends RuntimeReplaceabl

override def replaceForTypeCoercion(): Expression = {
if (left.dataType != right.dataType) {
TypeCoercion.findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { dtype =>
TypeCoercion.findWiderType(left.dataType, right.dataType).map { dtype =>
copy(left = Cast(left, dtype), right = Cast(right, dtype))
}.getOrElse(this)
} else {
Expand All @@ -116,7 +116,7 @@ case class NullIf(left: Expression, right: Expression) extends RuntimeReplaceabl

override def replaceForTypeCoercion(): Expression = {
if (left.dataType != right.dataType) {
TypeCoercion.findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { dtype =>
TypeCoercion.findWiderType(left.dataType, right.dataType).map { dtype =>
copy(left = Cast(left, dtype), right = Cast(right, dtype))
}.getOrElse(this)
} else {
Expand All @@ -134,7 +134,7 @@ case class Nvl(left: Expression, right: Expression) extends RuntimeReplaceable {

override def replaceForTypeCoercion(): Expression = {
if (left.dataType != right.dataType) {
TypeCoercion.findTightestCommonTypeToString(left.dataType, right.dataType).map { dtype =>
TypeCoercion.widerTypeToString(left.dataType, right.dataType).map { dtype =>
copy(left = Cast(left, dtype), right = Cast(right, dtype))
}.getOrElse(this)
} else {
Expand All @@ -154,7 +154,7 @@ case class Nvl2(expr1: Expression, expr2: Expression, expr3: Expression)

override def replaceForTypeCoercion(): Expression = {
if (expr2.dataType != expr3.dataType) {
TypeCoercion.findTightestCommonTypeOfTwo(expr2.dataType, expr3.dataType).map { dtype =>
TypeCoercion.findWiderType(expr2.dataType, expr3.dataType).map { dtype =>
copy(expr2 = Cast(expr2, dtype), expr3 = Cast(expr3, dtype))
}.getOrElse(this)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,11 @@ class TypeCoercionSuite extends PlanTest {

test("tightest common bound for types") {
def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) {
var found = TypeCoercion.findTightestCommonTypeOfTwo(t1, t2)
var found = TypeCoercion.findWiderType(t1, t2)
assert(found == tightestCommon,
s"Expected $tightestCommon as tightest common type for $t1 and $t2, found $found")
// Test both directions to make sure the widening is symmetric.
found = TypeCoercion.findTightestCommonTypeOfTwo(t2, t1)
found = TypeCoercion.findWiderType(t2, t1)
assert(found == tightestCommon,
s"Expected $tightestCommon as tightest common type for $t2 and $t1, found $found")
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ private[sql] object InferSchema {
* Returns the most general data type for two given data types.
*/
def compatibleType(t1: DataType, t2: DataType): DataType = {
TypeCoercion.findTightestCommonTypeOfTwo(t1, t2).getOrElse {
TypeCoercion.findWiderType(t1, t2).getOrElse {
// t1 or t2 is a StructType, ArrayType, or an unexpected type.
(t1, t2) match {
// Double supports a larger range than fixed decimal, DecimalType.Maximum should be enough
Expand Down