diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 021952e7166f..e6ac0399b674 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -72,23 +72,22 @@ object TypeCoercion { FloatType, DoubleType) - /** - * Case 1 type widening (see the classdoc comment above for TypeCoercion). - * - * Find the tightest common type of two types that might be used in a binary expression. - * This handles all numeric types except fixed-precision decimals interacting with each other or - * with primitive types, because in that case the precision and scale of the result depends on - * the operation. Those rules are implemented in [[DecimalPrecision]]. - */ - val findTightestCommonTypeOfTwo: (DataType, DataType) => Option[DataType] = { + def findWiderType(left: DataType, right: DataType): Option[DataType] = (left, right) match { case (t1, t2) if t1 == t2 => Some(t1) + case (NullType, t1) => Some(t1) case (t1, NullType) => Some(t1) - case (t1: IntegralType, t2: DecimalType) if t2.isWiderThan(t1) => - Some(t2) - case (t1: DecimalType, t2: IntegralType) if t1.isWiderThan(t2) => - Some(t1) + case (t1: DecimalType, t2: DecimalType) => + Some(DecimalPrecision.widerDecimalType(t1, t2)) + + case (t: IntegralType, d: DecimalType) => + Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) + case (d: DecimalType, t: IntegralType) => + Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) + + case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: FractionalType) => + Some(DoubleType) // Promote numeric types to the highest of the two case (t1: NumericType, t2: NumericType) @@ -99,70 +98,24 @@ object TypeCoercion { case _ => None } - /** Similar to [[findTightestCommonType]], but can promote all the way to StringType. */ - def findTightestCommonTypeToString(left: DataType, right: DataType): Option[DataType] = { - findTightestCommonTypeOfTwo(left, right).orElse((left, right) match { - case (StringType, t2: AtomicType) if t2 != BinaryType && t2 != BooleanType => Some(StringType) - case (t1: AtomicType, StringType) if t1 != BinaryType && t1 != BooleanType => Some(StringType) - case _ => None - }) - } - - /** - * Find the tightest common type of a set of types by continuously applying - * `findTightestCommonTypeOfTwo` on these types. - */ - private def findTightestCommonType(types: Seq[DataType]): Option[DataType] = { + def findWiderType(types: Seq[DataType]): Option[DataType] = { types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { + case Some(d) => findWiderType(d, c) case None => None - case Some(d) => findTightestCommonTypeOfTwo(d, c) }) } - /** - * Case 2 type widening (see the classdoc comment above for TypeCoercion). - * - * i.e. the main difference with [[findTightestCommonTypeOfTwo]] is that here we allow some - * loss of precision when widening decimal and double. - */ - private def findWiderTypeForTwo(t1: DataType, t2: DataType): Option[DataType] = (t1, t2) match { - case (t1: DecimalType, t2: DecimalType) => - Some(DecimalPrecision.widerDecimalType(t1, t2)) - case (t: IntegralType, d: DecimalType) => - Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) - case (d: DecimalType, t: IntegralType) => - Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) - case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: FractionalType) => - Some(DoubleType) - case _ => - findTightestCommonTypeToString(t1, t2) - } - - private def findWiderCommonType(types: Seq[DataType]) = { - types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { - case Some(d) => findWiderTypeForTwo(d, c) - case None => None + val widerTypeToString: (DataType, DataType) => Option[DataType] = (left, right) => { + findWiderType(left, right).orElse((left, right) match { + case (StringType, t2: AtomicType) if t2 != BinaryType && t2 != BooleanType => Some(StringType) + case (t1: AtomicType, StringType) if t1 != BinaryType && t1 != BooleanType => Some(StringType) + case _ => None }) } - /** - * Similar to [[findWiderCommonType]], but can't promote to string. This is also similar to - * [[findTightestCommonType]], but can handle decimal types. If the wider decimal type exceeds - * system limitation, this rule will truncate the decimal type before return it. - */ - private def findWiderTypeWithoutStringPromotion(types: Seq[DataType]): Option[DataType] = { + def widerTypeToString(types: Seq[DataType]): Option[DataType] = { types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { - case Some(d) => findTightestCommonTypeOfTwo(d, c).orElse((d, c) match { - case (t1: DecimalType, t2: DecimalType) => - Some(DecimalPrecision.widerDecimalType(t1, t2)) - case (t: IntegralType, d: DecimalType) => - Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) - case (d: DecimalType, t: IntegralType) => - Some(DecimalPrecision.widerDecimalType(DecimalType.forType(t), d)) - case (_: FractionalType, _: DecimalType) | (_: DecimalType, _: FractionalType) => - Some(DoubleType) - case _ => None - }) + case Some(d) => widerTypeToString(d, c) case None => None }) } @@ -272,7 +225,7 @@ object TypeCoercion { if (attrIndex >= children.head.output.length) return castedTypes.toSeq // For the attrIndex-th attribute, find the widest type - findWiderCommonType(children.map(_.output(attrIndex).dataType)) match { + widerTypeToString(children.map(_.output(attrIndex).dataType)) match { // If unable to find an appropriate widen type for this column, return an empty Seq case None => Seq.empty[DataType] // Otherwise, record the result in the queue and find the type for the next column @@ -374,7 +327,7 @@ object TypeCoercion { case e if !e.childrenResolved => e case i @ In(a, b) if b.exists(_.dataType != a.dataType) => - findWiderCommonType(i.children.map(_.dataType)) match { + widerTypeToString(i.children.map(_.dataType)) match { case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType))) case None => i } @@ -450,7 +403,7 @@ object TypeCoercion { case a @ CreateArray(children) if !haveSameType(children) => val types = children.map(_.dataType) - findWiderCommonType(types) match { + widerTypeToString(types) match { case Some(finalDataType) => CreateArray(children.map(Cast(_, finalDataType))) case None => a } @@ -461,7 +414,7 @@ object TypeCoercion { m.keys } else { val types = m.keys.map(_.dataType) - findWiderCommonType(types) match { + widerTypeToString(types) match { case Some(finalDataType) => m.keys.map(Cast(_, finalDataType)) case None => m.keys } @@ -471,7 +424,7 @@ object TypeCoercion { m.values } else { val types = m.values.map(_.dataType) - findWiderCommonType(types) match { + widerTypeToString(types) match { case Some(finalDataType) => m.values.map(Cast(_, finalDataType)) case None => m.values } @@ -499,7 +452,7 @@ object TypeCoercion { // compatible with every child column. case c @ Coalesce(es) if !haveSameType(es) => val types = es.map(_.dataType) - findWiderCommonType(types) match { + widerTypeToString(types) match { case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType))) case None => c } @@ -509,14 +462,14 @@ object TypeCoercion { // string.g case g @ Greatest(children) if !haveSameType(children) => val types = children.map(_.dataType) - findWiderTypeWithoutStringPromotion(types) match { + findWiderType(types) match { case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType))) case None => g } case l @ Least(children) if !haveSameType(children) => val types = children.map(_.dataType) - findWiderTypeWithoutStringPromotion(types) match { + findWiderType(types) match { case Some(finalDataType) => Least(children.map(Cast(_, finalDataType))) case None => l } @@ -556,7 +509,7 @@ object TypeCoercion { object CaseWhenCoercion extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { case c: CaseWhen if c.childrenResolved && !c.valueTypesEqual => - val maybeCommonType = findWiderCommonType(c.valueTypes) + val maybeCommonType = widerTypeToString(c.valueTypes) maybeCommonType.map { commonType => var changed = false val newBranches = c.branches.map { case (condition, value) => @@ -588,7 +541,7 @@ object TypeCoercion { case e if !e.childrenResolved => e // Find tightest common type for If, if the true value and false value have different types. case i @ If(pred, left, right) if left.dataType != right.dataType => - findWiderTypeForTwo(left.dataType, right.dataType).map { widestType => + widerTypeToString(left.dataType, right.dataType).map { widestType => val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) val newRight = if (right.dataType == widestType) right else Cast(right, widestType) If(pred, newLeft, newRight) @@ -630,7 +583,7 @@ object TypeCoercion { case e if !e.childrenResolved => e case b @ BinaryOperator(left, right) if left.dataType != right.dataType => - findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { commonType => + findWiderType(left.dataType, right.dataType).map { commonType => if (b.inputType.acceptsType(commonType)) { // If the expression accepts the tightest common type, cast to that. val newLeft = if (left.dataType == commonType) left else Cast(left, commonType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 1c18265e0fed..4fed271ba1c4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -96,7 +96,7 @@ case class IfNull(left: Expression, right: Expression) extends RuntimeReplaceabl override def replaceForTypeCoercion(): Expression = { if (left.dataType != right.dataType) { - TypeCoercion.findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { dtype => + TypeCoercion.findWiderType(left.dataType, right.dataType).map { dtype => copy(left = Cast(left, dtype), right = Cast(right, dtype)) }.getOrElse(this) } else { @@ -116,7 +116,7 @@ case class NullIf(left: Expression, right: Expression) extends RuntimeReplaceabl override def replaceForTypeCoercion(): Expression = { if (left.dataType != right.dataType) { - TypeCoercion.findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { dtype => + TypeCoercion.findWiderType(left.dataType, right.dataType).map { dtype => copy(left = Cast(left, dtype), right = Cast(right, dtype)) }.getOrElse(this) } else { @@ -134,7 +134,7 @@ case class Nvl(left: Expression, right: Expression) extends RuntimeReplaceable { override def replaceForTypeCoercion(): Expression = { if (left.dataType != right.dataType) { - TypeCoercion.findTightestCommonTypeToString(left.dataType, right.dataType).map { dtype => + TypeCoercion.widerTypeToString(left.dataType, right.dataType).map { dtype => copy(left = Cast(left, dtype), right = Cast(right, dtype)) }.getOrElse(this) } else { @@ -154,7 +154,7 @@ case class Nvl2(expr1: Expression, expr2: Expression, expr3: Expression) override def replaceForTypeCoercion(): Expression = { if (expr2.dataType != expr3.dataType) { - TypeCoercion.findTightestCommonTypeOfTwo(expr2.dataType, expr3.dataType).map { dtype => + TypeCoercion.findWiderType(expr2.dataType, expr3.dataType).map { dtype => copy(expr2 = Cast(expr2, dtype), expr3 = Cast(expr3, dtype)) }.getOrElse(this) } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index a13c45fe2ffe..c4bd40b76dda 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -131,11 +131,11 @@ class TypeCoercionSuite extends PlanTest { test("tightest common bound for types") { def widenTest(t1: DataType, t2: DataType, tightestCommon: Option[DataType]) { - var found = TypeCoercion.findTightestCommonTypeOfTwo(t1, t2) + var found = TypeCoercion.findWiderType(t1, t2) assert(found == tightestCommon, s"Expected $tightestCommon as tightest common type for $t1 and $t2, found $found") // Test both directions to make sure the widening is symmetric. - found = TypeCoercion.findTightestCommonTypeOfTwo(t2, t1) + found = TypeCoercion.findWiderType(t2, t1) assert(found == tightestCommon, s"Expected $tightestCommon as tightest common type for $t2 and $t1, found $found") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala index 91c58d059d28..6389a15249bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala @@ -252,7 +252,7 @@ private[sql] object InferSchema { * Returns the most general data type for two given data types. */ def compatibleType(t1: DataType, t2: DataType): DataType = { - TypeCoercion.findTightestCommonTypeOfTwo(t1, t2).getOrElse { + TypeCoercion.findWiderType(t1, t2).getOrElse { // t1 or t2 is a StructType, ArrayType, or an unexpected type. (t1, t2) match { // Double support larger range than fixed decimal, DecimalType.Maximum should be enough