Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -130,28 +130,28 @@ trait HiveTypeCoercion {
* the appropriate numeric equivalent.
*/
object ConvertNaNs extends Rule[LogicalPlan] {
private val stringNaN = Literal("NaN")
private val StringNaN = Literal("NaN")

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressions {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e

/* Double Conversions */
case b: BinaryExpression if b.left == stringNaN && b.right.dataType == DoubleType =>
b.makeCopy(Array(b.right, Literal(Double.NaN)))
case b: BinaryExpression if b.left.dataType == DoubleType && b.right == stringNaN =>
b.makeCopy(Array(Literal(Double.NaN), b.left))
case b: BinaryExpression if b.left == stringNaN && b.right == stringNaN =>
b.makeCopy(Array(Literal(Double.NaN), b.left))
case b @ BinaryExpression(StringNaN, right @ DoubleType()) =>
b.makeCopy(Array(Literal(Double.NaN), right))
case b @ BinaryExpression(left @ DoubleType(), StringNaN) =>
b.makeCopy(Array(left, Literal(Double.NaN)))

/* Float Conversions */
case b: BinaryExpression if b.left == stringNaN && b.right.dataType == FloatType =>
b.makeCopy(Array(b.right, Literal(Float.NaN)))
case b: BinaryExpression if b.left.dataType == FloatType && b.right == stringNaN =>
b.makeCopy(Array(Literal(Float.NaN), b.left))
case b: BinaryExpression if b.left == stringNaN && b.right == stringNaN =>
b.makeCopy(Array(Literal(Float.NaN), b.left))
case b @ BinaryExpression(StringNaN, right @ FloatType()) =>
b.makeCopy(Array(Literal(Float.NaN), right))
case b @ BinaryExpression(left @ FloatType(), StringNaN) =>
b.makeCopy(Array(left, Literal(Float.NaN)))

/* Use float NaN by default to avoid unnecessary type widening */
case b @ BinaryExpression(left @ StringNaN, StringNaN) =>
b.makeCopy(Array(left, Literal(Float.NaN)))
}
}
}
Expand Down Expand Up @@ -184,21 +184,25 @@ trait HiveTypeCoercion {
case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
val castedInput = left.output.zip(right.output).map {
// When a string is found on one side, make the other side a string too.
case (l, r) if l.dataType == StringType && r.dataType != StringType =>
(l, Alias(Cast(r, StringType), r.name)())
case (l, r) if l.dataType != StringType && r.dataType == StringType =>
(Alias(Cast(l, StringType), l.name)(), r)

case (l, r) if l.dataType != r.dataType =>
logDebug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}")
findTightestCommonTypeOfTwo(l.dataType, r.dataType).map { widestType =>
case (lhs, rhs) if lhs.dataType == StringType && rhs.dataType != StringType =>
(lhs, Alias(Cast(rhs, StringType), rhs.name)())
case (lhs, rhs) if lhs.dataType != StringType && rhs.dataType == StringType =>
(Alias(Cast(lhs, StringType), lhs.name)(), rhs)

case (lhs, rhs) if lhs.dataType != rhs.dataType =>
logDebug(s"Resolving mismatched union input ${lhs.dataType}, ${rhs.dataType}")
findTightestCommonTypeOfTwo(lhs.dataType, rhs.dataType).map { widestType =>
val newLeft =
if (l.dataType == widestType) l else Alias(Cast(l, widestType), l.name)()
if (lhs.dataType == widestType) lhs else Alias(Cast(lhs, widestType), lhs.name)()
val newRight =
if (r.dataType == widestType) r else Alias(Cast(r, widestType), r.name)()
if (rhs.dataType == widestType) rhs else Alias(Cast(rhs, widestType), rhs.name)()

(newLeft, newRight)
}.getOrElse((l, r)) // If there is no applicable conversion, leave expression unchanged.
}.getOrElse {
// If there is no applicable conversion, leave expression unchanged.
(lhs, rhs)
}

case other => other
}

Expand Down Expand Up @@ -227,12 +231,10 @@ trait HiveTypeCoercion {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e

case b: BinaryExpression if b.left.dataType != b.right.dataType =>
findTightestCommonTypeOfTwo(b.left.dataType, b.right.dataType).map { widestType =>
val newLeft =
if (b.left.dataType == widestType) b.left else Cast(b.left, widestType)
val newRight =
if (b.right.dataType == widestType) b.right else Cast(b.right, widestType)
case b @ BinaryExpression(left, right) if left.dataType != right.dataType =>
findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { widestType =>
val newLeft = if (left.dataType == widestType) left else Cast(left, widestType)
val newRight = if (right.dataType == widestType) right else Cast(right, widestType)
b.makeCopy(Array(newLeft, newRight))
}.getOrElse(b) // If there is no applicable conversion, leave expression unchanged.
}
Expand All @@ -247,57 +249,42 @@ trait HiveTypeCoercion {
// Skip nodes who's children have not been resolved yet.
case e if !e.childrenResolved => e

case a: BinaryArithmetic if a.left.dataType == StringType =>
a.makeCopy(Array(Cast(a.left, DoubleType), a.right))
case a: BinaryArithmetic if a.right.dataType == StringType =>
a.makeCopy(Array(a.left, Cast(a.right, DoubleType)))
case a @ BinaryArithmetic(left @ StringType(), r) =>
a.makeCopy(Array(Cast(left, DoubleType), r))
case a @ BinaryArithmetic(left, right @ StringType()) =>
a.makeCopy(Array(left, Cast(right, DoubleType)))

// we should cast all timestamp/date/string compare into string compare
case p: BinaryComparison if p.left.dataType == StringType &&
p.right.dataType == DateType =>
p.makeCopy(Array(p.left, Cast(p.right, StringType)))
case p: BinaryComparison if p.left.dataType == DateType &&
p.right.dataType == StringType =>
p.makeCopy(Array(Cast(p.left, StringType), p.right))
case p: BinaryComparison if p.left.dataType == StringType &&
p.right.dataType == TimestampType =>
p.makeCopy(Array(Cast(p.left, TimestampType), p.right))
case p: BinaryComparison if p.left.dataType == TimestampType &&
p.right.dataType == StringType =>
p.makeCopy(Array(p.left, Cast(p.right, TimestampType)))
case p: BinaryComparison if p.left.dataType == TimestampType &&
p.right.dataType == DateType =>
p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType)))
case p: BinaryComparison if p.left.dataType == DateType &&
p.right.dataType == TimestampType =>
p.makeCopy(Array(Cast(p.left, StringType), Cast(p.right, StringType)))

case p: BinaryComparison if p.left.dataType == StringType &&
p.right.dataType != StringType =>
p.makeCopy(Array(Cast(p.left, DoubleType), p.right))
case p: BinaryComparison if p.left.dataType != StringType &&
p.right.dataType == StringType =>
p.makeCopy(Array(p.left, Cast(p.right, DoubleType)))

case i @ In(a, b) if a.dataType == DateType &&
b.forall(_.dataType == StringType) =>
case p @ BinaryComparison(left @ StringType(), right @ DateType()) =>
p.makeCopy(Array(left, Cast(right, StringType)))
case p @ BinaryComparison(left @ DateType(), right @ StringType()) =>
p.makeCopy(Array(Cast(left, StringType), right))
case p @ BinaryComparison(left @ StringType(), right @ TimestampType()) =>
p.makeCopy(Array(Cast(left, TimestampType), right))
case p @ BinaryComparison(left @ TimestampType(), right @ StringType()) =>
p.makeCopy(Array(left, Cast(right, TimestampType)))
case p @ BinaryComparison(left @ TimestampType(), right @ DateType()) =>
p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType)))
case p @ BinaryComparison(left @ DateType(), right @ TimestampType()) =>
p.makeCopy(Array(Cast(left, StringType), Cast(right, StringType)))

case p @ BinaryComparison(left @ StringType(), right) if right.dataType != StringType =>
p.makeCopy(Array(Cast(left, DoubleType), right))
case p @ BinaryComparison(left, right @ StringType()) if left.dataType != StringType =>
p.makeCopy(Array(left, Cast(right, DoubleType)))

case i @ In(a @ DateType(), b) if b.forall(_.dataType == StringType) =>
i.makeCopy(Array(Cast(a, StringType), b))
case i @ In(a, b) if a.dataType == TimestampType &&
b.forall(_.dataType == StringType) =>
case i @ In(a @ TimestampType(), b) if b.forall(_.dataType == StringType) =>
i.makeCopy(Array(a, b.map(Cast(_, TimestampType))))
case i @ In(a, b) if a.dataType == DateType &&
b.forall(_.dataType == TimestampType) =>
case i @ In(a @ DateType(), b) if b.forall(_.dataType == TimestampType) =>
i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))
case i @ In(a, b) if a.dataType == TimestampType &&
b.forall(_.dataType == DateType) =>
case i @ In(a @ TimestampType(), b) if b.forall(_.dataType == DateType) =>
i.makeCopy(Array(Cast(a, StringType), b.map(Cast(_, StringType))))

case Sum(e) if e.dataType == StringType =>
Sum(Cast(e, DoubleType))
case Average(e) if e.dataType == StringType =>
Average(Cast(e, DoubleType))
case Sqrt(e) if e.dataType == StringType =>
Sqrt(Cast(e, DoubleType))
case Sum(e @ StringType()) => Sum(Cast(e, DoubleType))
case Average(e @ StringType()) => Average(Cast(e, DoubleType))
case Sqrt(e @ StringType()) => Sqrt(Cast(e, DoubleType))
}
}

Expand Down Expand Up @@ -379,22 +366,22 @@ trait HiveTypeCoercion {
// fix decimal precision for union
case u @ Union(left, right) if u.childrenResolved && !u.resolved =>
val castedInput = left.output.zip(right.output).map {
case (l, r) if l.dataType != r.dataType =>
(l.dataType, r.dataType) match {
case (lhs, rhs) if lhs.dataType != rhs.dataType =>
(lhs.dataType, rhs.dataType) match {
case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) =>
// Union decimals with precision/scale p1/s2 and p2/s2 will be promoted to
// DecimalType(max(s1, s2) + max(p1-s1, p2-s2), max(s1, s2))
val fixedType = DecimalType(max(s1, s2) + max(p1 - s1, p2 - s2), max(s1, s2))
(Alias(Cast(l, fixedType), l.name)(), Alias(Cast(r, fixedType), r.name)())
(Alias(Cast(lhs, fixedType), lhs.name)(), Alias(Cast(rhs, fixedType), rhs.name)())
case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
(Alias(Cast(l, intTypeToFixed(t)), l.name)(), r)
(Alias(Cast(lhs, intTypeToFixed(t)), lhs.name)(), rhs)
case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) =>
(l, Alias(Cast(r, intTypeToFixed(t)), r.name)())
(lhs, Alias(Cast(rhs, intTypeToFixed(t)), rhs.name)())
case (t, DecimalType.Fixed(p, s)) if floatTypeToFixed.contains(t) =>
(Alias(Cast(l, floatTypeToFixed(t)), l.name)(), r)
(Alias(Cast(lhs, floatTypeToFixed(t)), lhs.name)(), rhs)
case (DecimalType.Fixed(p, s), t) if floatTypeToFixed.contains(t) =>
(l, Alias(Cast(r, floatTypeToFixed(t)), r.name)())
case _ => (l, r)
(lhs, Alias(Cast(rhs, floatTypeToFixed(t)), rhs.name)())
case _ => (lhs, rhs)
}
case other => other
}
Expand Down Expand Up @@ -467,16 +454,16 @@ trait HiveTypeCoercion {

// Promote integers inside a binary expression with fixed-precision decimals to decimals,
// and fixed-precision decimals in an expression with floats / doubles to doubles
case b: BinaryExpression if b.left.dataType != b.right.dataType =>
(b.left.dataType, b.right.dataType) match {
case b @ BinaryExpression(left, right) if left.dataType != right.dataType =>
(left.dataType, right.dataType) match {
case (t, DecimalType.Fixed(p, s)) if intTypeToFixed.contains(t) =>
b.makeCopy(Array(Cast(b.left, intTypeToFixed(t)), b.right))
b.makeCopy(Array(Cast(left, intTypeToFixed(t)), right))
case (DecimalType.Fixed(p, s), t) if intTypeToFixed.contains(t) =>
b.makeCopy(Array(b.left, Cast(b.right, intTypeToFixed(t))))
b.makeCopy(Array(left, Cast(right, intTypeToFixed(t))))
case (t, DecimalType.Fixed(p, s)) if isFloat(t) =>
b.makeCopy(Array(b.left, Cast(b.right, DoubleType)))
b.makeCopy(Array(left, Cast(right, DoubleType)))
case (DecimalType.Fixed(p, s), t) if isFloat(t) =>
b.makeCopy(Array(Cast(b.left, DoubleType), b.right))
b.makeCopy(Array(Cast(left, DoubleType), right))
case _ =>
b
}
Expand Down Expand Up @@ -525,31 +512,31 @@ trait HiveTypeCoercion {
// all other cases are considered as false.

// We may simplify the expression if one side is literal numeric values
case EqualTo(l @ BooleanType(), Literal(value, _: NumericType))
if trueValues.contains(value) => l
case EqualTo(l @ BooleanType(), Literal(value, _: NumericType))
if falseValues.contains(value) => Not(l)
case EqualTo(Literal(value, _: NumericType), r @ BooleanType())
if trueValues.contains(value) => r
case EqualTo(Literal(value, _: NumericType), r @ BooleanType())
if falseValues.contains(value) => Not(r)
case EqualNullSafe(l @ BooleanType(), Literal(value, _: NumericType))
if trueValues.contains(value) => And(IsNotNull(l), l)
case EqualNullSafe(l @ BooleanType(), Literal(value, _: NumericType))
if falseValues.contains(value) => And(IsNotNull(l), Not(l))
case EqualNullSafe(Literal(value, _: NumericType), r @ BooleanType())
if trueValues.contains(value) => And(IsNotNull(r), r)
case EqualNullSafe(Literal(value, _: NumericType), r @ BooleanType())
if falseValues.contains(value) => And(IsNotNull(r), Not(r))

case EqualTo(l @ BooleanType(), r @ NumericType()) =>
transform(l , r)
case EqualTo(l @ NumericType(), r @ BooleanType()) =>
transform(r, l)
case EqualNullSafe(l @ BooleanType(), r @ NumericType()) =>
transformNullSafe(l, r)
case EqualNullSafe(l @ NumericType(), r @ BooleanType()) =>
transformNullSafe(r, l)
case EqualTo(left @ BooleanType(), Literal(value, _: NumericType))
if trueValues.contains(value) => left
case EqualTo(left @ BooleanType(), Literal(value, _: NumericType))
if falseValues.contains(value) => Not(left)
case EqualTo(Literal(value, _: NumericType), right @ BooleanType())
if trueValues.contains(value) => right
case EqualTo(Literal(value, _: NumericType), right @ BooleanType())
if falseValues.contains(value) => Not(right)
case EqualNullSafe(left @ BooleanType(), Literal(value, _: NumericType))
if trueValues.contains(value) => And(IsNotNull(left), left)
case EqualNullSafe(left @ BooleanType(), Literal(value, _: NumericType))
if falseValues.contains(value) => And(IsNotNull(left), Not(left))
case EqualNullSafe(Literal(value, _: NumericType), right @ BooleanType())
if trueValues.contains(value) => And(IsNotNull(right), right)
case EqualNullSafe(Literal(value, _: NumericType), right @ BooleanType())
if falseValues.contains(value) => And(IsNotNull(right), Not(right))

case EqualTo(left @ BooleanType(), right @ NumericType()) =>
transform(left , right)
case EqualTo(left @ NumericType(), right @ BooleanType()) =>
transform(right, left)
case EqualNullSafe(left @ BooleanType(), right @ NumericType()) =>
transformNullSafe(left, right)
case EqualNullSafe(left @ NumericType(), right @ BooleanType()) =>
transformNullSafe(right, left)
}
}

Expand Down Expand Up @@ -630,7 +617,7 @@ trait HiveTypeCoercion {
case d: Divide if d.dataType == DoubleType => d
case d: Divide if d.dataType.isInstanceOf[DecimalType] => d

case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType))
case Divide(left, right) => Divide(Cast(left, DoubleType), Cast(right, DoubleType))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
override def toString: String = s"($left $symbol $right)"
}

private[sql] object BinaryExpression {
def unapply(e: BinaryExpression): Option[(Expression, Expression)] = Some((e.left, e.right))
}

abstract class LeafExpression extends Expression with trees.LeafNode[Expression] {
self: Product =>
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,10 @@ abstract class BinaryArithmetic extends BinaryExpression {
sys.error(s"BinaryArithmetics must override either eval or evalInternal")
}

private[sql] object BinaryArithmetic {
def unapply(e: BinaryArithmetic): Option[(Expression, Expression)] = Some((e.left, e.right))
}

case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
override def symbol: String = "+"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,8 @@ abstract class BinaryComparison extends BinaryExpression with Predicate {
sys.error(s"BinaryComparisons must override either eval or evalInternal")
}

object BinaryComparison {
def unapply(b: BinaryComparison): Option[(Expression, Expression)] =
Some((b.left, b.right))
private[sql] object BinaryComparison {
def unapply(e: BinaryComparison): Option[(Expression, Expression)] = Some((e.left, e.right))
}

case class EqualTo(left: Expression, right: Expression) extends BinaryComparison {
Expand Down
Loading