@@ -19,9 +19,7 @@ package org.apache.spark.sql.catalyst.analysis

import javax.annotation.Nullable

- import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
- import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types._
@@ -38,7 +36,7 @@ object HiveTypeCoercion {
val typeCoercionRules =
PropagateTypes ::
InConversion ::
- WidenTypes ::
+ WidenSetOperationTypes ::
PromoteStrings ::
DecimalPrecision ::
BooleanEquality ::
@@ -175,7 +173,7 @@ object HiveTypeCoercion {
*
* This rule is only applied to Union/Except/Intersect
*/
- object WidenTypes extends Rule[LogicalPlan] {
+ object WidenSetOperationTypes extends Rule[LogicalPlan] {

private[this] def widenOutputTypes(
planName: String,
@@ -203,9 +201,9 @@

def castOutput(plan: LogicalPlan): LogicalPlan = {
val casted = plan.output.zip(castedTypes).map {
- case (hs, Some(dt)) if dt != hs.dataType =>
-   Alias(Cast(hs, dt), hs.name)()
- case (hs, _) => hs
+ case (e, Some(dt)) if e.dataType != dt =>
+   Alias(Cast(e, dt), e.name)()
+ case (e, _) => e
}
Project(casted, plan)
}
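As a side note on the casting step: `castOutput` rewrites only the columns whose type differs from the widened target and leaves the rest untouched. A minimal standalone sketch of that idea, using plain Scala stand-ins (`Col` and `castTo` are illustrative toys, not Catalyst classes):

```scala
// Sketch of the "cast only the mismatched columns" idea behind castOutput.
// Col and castTo are toy stand-ins, not Spark Catalyst classes.
case class Col(name: String, dataType: String)

def castTo(c: Col, target: String): Col = Col(c.name, target)

def castOutput(output: Seq[Col], castedTypes: Seq[Option[String]]): Seq[Col] =
  output.zip(castedTypes).map {
    case (c, Some(t)) if c.dataType != t => castTo(c, t) // widen this column
    case (c, _) => c                                     // already the target type
  }

// Only the int column is rewritten:
// castOutput(Seq(Col("a", "int"), Col("b", "double")), Seq(Some("double"), Some("double")))
//   => Seq(Col("a", "double"), Col("b", "double"))
```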
@@ -355,20 +353,8 @@
DecimalType.bounded(range + scale, scale)
}

- /**
-  * An expression used to wrap the children when promote the precision of DecimalType to avoid
-  * promote multiple times.
-  */
- case class ChangePrecision(child: Expression) extends UnaryExpression {
-   override def dataType: DataType = child.dataType
-   override def eval(input: InternalRow): Any = child.eval(input)
-   override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx)
-   override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = ""
-   override def prettyName: String = "change_precision"
- }
-
- def changePrecision(e: Expression, dataType: DataType): Expression = {
-   ChangePrecision(Cast(e, dataType))
+ private def changePrecision(e: Expression, dataType: DataType): Expression = {
+   ChangeDecimalPrecision(Cast(e, dataType))
}

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
@@ -378,7 +364,7 @@
case e if !e.childrenResolved => e

// Skip nodes that are already promoted
- case e: BinaryArithmetic if e.left.isInstanceOf[ChangePrecision] => e
+ case e: BinaryArithmetic if e.left.isInstanceOf[ChangeDecimalPrecision] => e

case Add(e1 @ DecimalType.Expression(p1, s1), e2 @ DecimalType.Expression(p2, s2)) =>
val dt = DecimalType.bounded(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))
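For reference, the `Add` case above applies the usual SQL decimal-arithmetic rule: `Decimal(p1, s1) + Decimal(p2, s2)` yields `Decimal(max(s1, s2) + max(p1 - s1, p2 - s2) + 1, max(s1, s2))`, which `DecimalType.bounded` then clamps to the maximum supported precision (assumed here to be 38 digits, Spark's `MAX_PRECISION`). A standalone sketch of the calculation:

```scala
import scala.math.{max, min}

// Result type of Decimal(p1, s1) + Decimal(p2, s2), mirroring the Add case above.
// The cap of 38 digits is an assumption (DecimalType.MAX_PRECISION in Spark).
def boundedAddType(p1: Int, s1: Int, p2: Int, s2: Int): (Int, Int) = {
  val scale = max(s1, s2)                            // keep the finer scale
  val precision = max(p1 - s1, p2 - s2) + scale + 1  // widest integral part, +1 for carry
  (min(precision, 38), min(scale, 38))
}

// Example: Decimal(5, 2) + Decimal(10, 4)
//   integral digits = max(3, 6) = 6, scale = 4, carry = 1  =>  Decimal(11, 4)
// boundedAddType(5, 2, 10, 4) == (11, 4)
```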
@@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
import org.apache.spark.sql.types._

@@ -60,3 +61,15 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un
})
}
}

/**
* An expression used to wrap the children when promoting the precision of DecimalType, to
* avoid promoting it multiple times.
*/
case class ChangeDecimalPrecision(child: Expression) extends UnaryExpression {
override def dataType: DataType = child.dataType
override def eval(input: InternalRow): Any = child.eval(input)
override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx)
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = ""
override def prettyName: String = "change_decimal_precision"
}
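Note that `ChangeDecimalPrecision` does no work of its own at evaluation or codegen time; it simply delegates to its child. It exists only as a marker so the `DecimalPrecision` rule can tell that a `BinaryArithmetic` node has already been rewritten (the `e.left.isInstanceOf[ChangeDecimalPrecision]` guard above) and skip it on later passes. A toy sketch of that wrap-once pattern, with made-up types rather than Catalyst ones:

```scala
// Toy model of a pass-through marker node that tells a rewrite rule
// it has already handled a subtree (analogous to ChangeDecimalPrecision).
sealed trait Expr
case class Num(v: BigDecimal) extends Expr
case class AddExpr(left: Expr, right: Expr) extends Expr
case class Promoted(child: Expr) extends Expr // evaluation would just delegate to child

def promoteOnce(e: Expr): Expr = e match {
  case AddExpr(Promoted(_), _) => e                                  // already promoted: skip
  case AddExpr(l, r)           => AddExpr(Promoted(l), Promoted(r))  // wrap exactly once
  case other                   => other
}

// Applying the rule twice wraps each side only once:
// promoteOnce(promoteOnce(AddExpr(Num(1), Num(2))))
//   => AddExpr(Promoted(Num(1)), Promoted(Num(2)))
```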
@@ -78,6 +78,10 @@ case class DecimalType(precision: Int, scale: Int) extends FractionalType {

override def toString: String = s"DecimalType($precision,$scale)"

/**
* Returns whether this DecimalType is wider than `other`. If yes, it means `other`
* can be cast into `this` safely without losing any precision or range.
*/
private[sql] def isWiderThan(other: DataType): Boolean = other match {
case dt: DecimalType =>
(precision - scale) >= (dt.precision - dt.scale) && scale >= dt.scale
@@ -109,7 +113,7 @@ object DecimalType extends AbstractDataType {
@deprecated("Does not support unlimited precision, please specify the precision and scale", "1.5")
val Unlimited: DecimalType = SYSTEM_DEFAULT

- // The decimal types compatible with other numberic types
+ // The decimal types compatible with other numeric types
private[sql] val ByteDecimal = DecimalType(3, 0)
private[sql] val ShortDecimal = DecimalType(5, 0)
private[sql] val IntDecimal = DecimalType(10, 0)
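The new `isWiderThan` check compares the integral part (`precision - scale`) and the scale independently: one decimal is wider than another exactly when it has at least as many digits on both sides of the decimal point. For non-decimal numeric types, the comparison goes through the fixed decimal equivalents listed above (Byte → (3, 0), Short → (5, 0), Int → (10, 0); the Long mapping to (20, 0) used below is an assumption, since that line sits below the fold). A standalone sketch that reproduces a few expectations from the new test in the next file:

```scala
// Standalone sketch of the widening check, mirroring isWiderThan.
// Byte/Short/Int mappings come from the diff above; Long -> (20, 0) is assumed.
case class Dec(precision: Int, scale: Int)

def isWiderThan(a: Dec, b: Dec): Boolean =
  (a.precision - a.scale) >= (b.precision - b.scale) && a.scale >= b.scale

val shortDec = Dec(5, 0)
val longDec  = Dec(20, 0)

assert(!isWiderThan(Dec(2, 1), Dec(5, 2))) // d1 has fewer integral digits than d2
assert(isWiderThan(Dec(5, 2), Dec(2, 1)))  // d2 covers d1 on both sides of the point
assert(isWiderThan(Dec(15, 3), shortDec))  // Decimal(15, 3) can hold any ShortType value
assert(!isWiderThan(Dec(15, 3), longDec))  // but not every LongType value
```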
@@ -154,4 +154,30 @@ class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter {
checkType(Remainder(expr, u), DoubleType)
}
}

test("DecimalType.isWiderThan") {
val d0 = DecimalType(2, 0)
val d1 = DecimalType(2, 1)
val d2 = DecimalType(5, 2)
val d3 = DecimalType(15, 3)
val d4 = DecimalType(25, 4)

assert(d0.isWiderThan(d1) === false)
assert(d1.isWiderThan(d0) === false)
assert(d1.isWiderThan(d2) === false)
assert(d2.isWiderThan(d1) === true)
assert(d2.isWiderThan(d3) === false)
assert(d3.isWiderThan(d2) === true)
assert(d4.isWiderThan(d3) === true)

assert(d1.isWiderThan(ByteType) === false)
assert(d2.isWiderThan(ByteType) === true)
assert(d2.isWiderThan(ShortType) === false)
assert(d3.isWiderThan(ShortType) === true)
assert(d3.isWiderThan(IntegerType) === true)
assert(d3.isWiderThan(LongType) === false)
assert(d4.isWiderThan(LongType) === true)
assert(d4.isWiderThan(FloatType) === false)
assert(d4.isWiderThan(DoubleType) === false)
}
}
@@ -306,7 +306,7 @@ class HiveTypeCoercionSuite extends PlanTest {
)
}

test("WidenTypes for union except and intersect") {
test("WidenSetOperationTypes for union except and intersect") {
def checkOutput(logical: LogicalPlan, expectTypes: Seq[DataType]): Unit = {
logical.output.zip(expectTypes).foreach { case (attr, dt) =>
assert(attr.dataType === dt)
@@ -324,7 +324,7 @@
AttributeReference("f", FloatType)(),
AttributeReference("l", LongType)())

- val wt = HiveTypeCoercion.WidenTypes
+ val wt = HiveTypeCoercion.WidenSetOperationTypes
val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType)

val r1 = wt(Union(left, right)).asInstanceOf[Union]
@@ -345,7 +345,7 @@
}
}

- val dp = HiveTypeCoercion.WidenTypes
+ val dp = HiveTypeCoercion.WidenSetOperationTypes

val left1 = LocalRelation(
AttributeReference("l", DecimalType(10, 8))())