Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,31 @@ object TypeCoercion {
})
}

/**
 * Variant of [[findTightestCommonType]] that, when `findTightestCommonTypeOfTwo` finds no
 * answer, additionally widens fixed-precision decimals against each other and against
 * integral types. The widened decimal keeps every integral and fractional digit of both
 * inputs; if that would require a precision above 38 the result is `None`.
 */
private def findTightestCommonTypeToDecimal(left: DataType, right: DataType): Option[DataType] = {
  findTightestCommonTypeOfTwo(left, right).orElse {
    (left, right) match {
      case (lhs: DecimalType, rhs: DecimalType) =>
        // Keep the widest fractional part and the widest integral part of the two operands.
        val fractionDigits = math.max(lhs.scale, rhs.scale)
        val integerDigits = math.max(lhs.precision - lhs.scale, rhs.precision - rhs.scale)
        val totalDigits = integerDigits + fractionDigits
        // DecimalType can't support precision > 38
        if (totalDigits <= 38) Some(DecimalType(totalDigits, fractionDigits)) else None

      // An integral operand is first turned into its decimal counterpart, then the pair is
      // resolved again as two decimals.
      case (integral: IntegralType, decimal: DecimalType) =>
        findTightestCommonTypeToDecimal(DecimalType.forType(integral), decimal)

      case (decimal: DecimalType, integral: IntegralType) =>
        findTightestCommonTypeToDecimal(decimal, DecimalType.forType(integral))

      case _ => None
    }
  }
}

/**
* Similar to [[findTightestCommonType]], but if no tightest common type can be found, falls
* back to [[findTightestCommonTypeToString]] to find one.
Expand All @@ -120,6 +145,18 @@ object TypeCoercion {
})
}

/**
 * Folds `findTightestCommonTypeToDecimal` over `types`, starting from [[NullType]], to obtain
 * a single type common to all of them. Once any step yields `None`, the result stays `None`.
 */
private def findTightestCommonTypeAndPromoteToDecimal(types: Seq[DataType]): Option[DataType] = {
  val seed: Option[DataType] = Some(NullType)
  types.foldLeft(seed) { (acc, next) =>
    // flatMap short-circuits: a previous failure (None) is carried through unchanged.
    acc.flatMap(current => findTightestCommonTypeToDecimal(current, next))
  }
}

/**
* Find the tightest common type of a set of types by continuously applying
* `findTightestCommonTypeOfTwo` on these types.
Expand Down Expand Up @@ -496,14 +533,14 @@ object TypeCoercion {

case g @ Greatest(children) if !haveSameType(children) =>
val types = children.map(_.dataType)
findTightestCommonType(types) match {
findTightestCommonTypeAndPromoteToDecimal(types) match {
case Some(finalDataType) => Greatest(children.map(Cast(_, finalDataType)))
case None => g
}

case l @ Least(children) if !haveSameType(children) =>
val types = children.map(_.dataType)
findTightestCommonType(types) match {
findTightestCommonTypeAndPromoteToDecimal(types) match {
case Some(finalDataType) => Least(children.map(Cast(_, finalDataType)))
case None => l
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,6 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
for (operator <- Seq[(Seq[Expression] => Expression)](Greatest, Least)) {
assertError(operator(Seq('booleanField)), "requires at least 2 arguments")
assertError(operator(Seq('intField, 'stringField)), "should all have the same type")
assertError(operator(Seq('intField, 'decimalField)), "should all have the same type")
assertError(operator(Seq('mapField, 'mapField)), "does not support ordering")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,15 @@ class TypeCoercionSuite extends PlanTest {
:: Cast(Literal(1), DecimalType(22, 0))
:: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType(22, 0))
:: Nil))
ruleTest(TypeCoercion.FunctionArgumentConversion,
operator(Literal(1L)
:: Literal(1)
:: Literal(new java.math.BigDecimal("1.5"))
:: Nil),
operator(Cast(Literal(1L), DecimalType(21, 1))
:: Cast(Literal(1), DecimalType(21, 1))
:: Cast(Literal(new java.math.BigDecimal("1.5")), DecimalType(21, 1))
:: Nil))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,33 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
)
}

test("conditional function: least - type cast") {
  // Declared as `def` (not `val`) so each query is still built at the point where
  // checkAnswer evaluates its argument.
  def leastViaDsl =
    testData2.select(least(lit(BigDecimal("-1")), lit(0), col("a"), col("b"))).limit(1)
  def leastViaSql = sql("SELECT least(a, 1.5) as l from testData2 order by l")

  // Decimal literal mixed with integer arguments: the result is expected as a decimal.
  checkAnswer(leastViaDsl, Row(BigDecimal("-1")))

  // Integer column compared against a decimal literal in SQL.
  val expectedRows =
    Seq("1.0", "1.0", "1.5", "1.5", "1.5", "1.5").map(text => Row(BigDecimal(text)))
  checkAnswer(leastViaSql, expectedRows)
}

test("conditional function: least - type cast failure") {
  // The test expects analysis to reject these arguments with a data type mismatch
  // instead of coercing them to a common type.
  val expectedFragment =
    "cannot resolve 'least(CAST(1E-21 AS DECIMAL(21,21)), 0L, `a`, `b`)'" +
      " due to data type mismatch"

  val error = intercept[AnalysisException] {
    testData2.select(
      least(lit(BigDecimal("0.000000000000000000001")), lit(0L), col("a"), col("b"))).limit(1)
  }

  assert(error.message.contains(expectedFragment))
}

test("conditional function: greatest") {
checkAnswer(
testData2.select(greatest(lit(2), lit(3), col("a"), col("b"))).limit(1),
Expand All @@ -233,6 +260,33 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
)
}

test("conditional function: greatest - type cast") {
  // Declared as `def` (not `val`) so each query is still built at the point where
  // checkAnswer evaluates its argument.
  def greatestViaDsl =
    testData2.select(greatest(lit(2), lit(BigDecimal("3")), col("a"), col("b"))).limit(1)
  def greatestViaSql = sql("SELECT greatest(a, 2.5) as g from testData2 order by g")

  // Integer arguments mixed with a decimal literal: the result is expected as a decimal.
  checkAnswer(greatestViaDsl, Row(BigDecimal("3")))

  // Integer column compared against a decimal literal in SQL.
  val expectedRows =
    Seq("2.5", "2.5", "2.5", "2.5", "3", "3").map(text => Row(BigDecimal(text)))
  checkAnswer(greatestViaSql, expectedRows)
}

test("conditional function: greatest - type cast failure") {
  // The test expects analysis to reject these arguments with a data type mismatch
  // instead of coercing them to a common type.
  val expectedFragment =
    "cannot resolve 'greatest(CAST(1E-21 AS DECIMAL(21,21)), 0L, `a`, `b`)'" +
      " due to data type mismatch"

  val error = intercept[AnalysisException] {
    testData2.select(
      greatest(lit(BigDecimal("0.000000000000000000001")), lit(0L), col("a"), col("b")))
  }

  assert(error.message.contains(expectedFragment))
}

test("pmod") {
val intData = Seq((7, 3), (-7, 3)).toDF("a", "b")
checkAnswer(
Expand Down