Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -669,10 +669,10 @@ object HiveTypeCoercion {
case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { commonType =>
if (b.inputType.acceptsType(commonType)) {
// If the expression accepts the tighest common type, cast to that.
// If the expression accepts the tightest common type, cast to that.
val newLeft = if (left.dataType == commonType) left else Cast(left, commonType)
val newRight = if (right.dataType == commonType) right else Cast(right, commonType)
b.makeCopy(Array(newLeft, newRight))
b.withNewChildren(Seq(newLeft, newRight))
} else {
// Otherwise, don't do anything with the expression.
b
Expand All @@ -691,7 +691,7 @@ object HiveTypeCoercion {
// general implicit casting.
val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) =>
if (in.dataType == NullType && !expected.acceptsType(NullType)) {
Cast(in, expected.defaultConcreteType)
Literal.create(null, expected.defaultConcreteType)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cast NullType to any type will result a literal null, so we can just write down the literal null here.

} else {
in
}
Expand All @@ -713,27 +713,22 @@ object HiveTypeCoercion {
@Nullable val ret: Expression = (inType, expectedType) match {

// If the expected type is already a parent of the input type, no need to cast.
case _ if expectedType.isSameType(inType) => e
case _ if expectedType.acceptsType(inType) => e
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why we need isSameType? I think the logic should be:

  1. first check whether the input type is acceptable(i.e. expectedType.acceptsType(inType) returns true). If it is, do nothing here.
  2. If input type is not acceptable, follow cast rules below to do implicit type cast.

It looks to me we only need acceptsType.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there was a reason, but i can't remember why right now. let me think about it.

cc @marmbrus do you remember why?


// Cast null type (usually from null literals) into target types
case (NullType, target) => Cast(e, target.defaultConcreteType)

// If the function accepts any numeric type (i.e. the ADT `NumericType`) and the input is
// already a number, leave it as is.
case (_: NumericType, NumericType) => e

// If the function accepts any numeric type and the input is a string, we follow the hive
// convention and cast that input into a double
case (StringType, NumericType) => Cast(e, NumericType.defaultConcreteType)

// Implicit cast among numeric types
// Implicit cast among numeric types. When we reach here, input type is not acceptable.

// If input is a numeric type but not decimal, and we expect a decimal type,
// cast the input to unlimited precision decimal.
case (_: NumericType, DecimalType) if !inType.isInstanceOf[DecimalType] =>
Cast(e, DecimalType.Unlimited)
case (_: NumericType, DecimalType) => Cast(e, DecimalType.Unlimited)
// For any other numeric types, implicitly cast to each other, e.g. long -> int, int -> long
case (_: NumericType, target: NumericType) if e.dataType != target => Cast(e, target)
case (_: NumericType, target: NumericType) => e
case (_: NumericType, target: NumericType) => Cast(e, target)

// Implicit cast between date time types
case (DateType, TimestampType) => Cast(e, TimestampType)
Expand All @@ -747,15 +742,9 @@ object HiveTypeCoercion {
case (StringType, BinaryType) => Cast(e, BinaryType)
case (any, StringType) if any != StringType => Cast(e, StringType)

// Type collection.
// First see if we can find our input type in the type collection. If we can, then just
// use the current expression; otherwise, find the first one we can implicitly cast.
case (_, TypeCollection(types)) =>
if (types.exists(_.isSameType(inType))) {
e
} else {
types.flatMap(implicitCast(e, _)).headOption.orNull
}
// When we reach here, input type is not acceptable for any types in this type collection,
// try to find the first one we can implicitly cast.
case (_, TypeCollection(types)) => types.flatMap(implicitCast(e, _)).headOption.orNull
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this correct? Imagine we have TypeCollection(LongType, StringType), and the input is StringType

Wouldn't this just cast input to a longtype?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the first rule is: case _ if expectedType.acceptsType(inType) => e.
So when we reach here, input type is not acceptable for any types in this type collection. see the inline comments above.


// Else, just return the same input expression
case _ => null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -369,17 +369,15 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(inputType, inputType)

override def checkInputDataTypes(): TypeCheckResult = {
// First call the checker for ExpectsInputTypes, and then check whether left and right have
// the same type.
super.checkInputDataTypes() match {
case TypeCheckResult.TypeCheckSuccess =>
if (left.dataType != right.dataType) {
TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).")
} else {
TypeCheckResult.TypeCheckSuccess
}
case TypeCheckResult.TypeCheckFailure(msg) => TypeCheckResult.TypeCheckFailure(msg)
// First check whether left and right have the same type, then check if the type is acceptable.
if (left.dataType != right.dataType) {
TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).")
} else if (!inputType.acceptsType(left.dataType)) {
TypeCheckResult.TypeCheckFailure(s"'$prettyString' accepts ${inputType.simpleString} type," +
s" not ${left.dataType.simpleString}")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's better to generate a different error message for BinaryOperator, as 2 children have same type and same expected type, we don't need to follow ExpectsInputTypes to report type error for both left and right.

} else {
TypeCheckResult.TypeCheckSuccess
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,6 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
}

override def symbol: String = "max"
override def prettyName: String = symbol
}

case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
Expand Down Expand Up @@ -375,5 +374,4 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
}

override def symbol: String = "min"
override def prettyName: String = symbol
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.spark.sql.types._
*/
case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic {

override def inputType: AbstractDataType = TypeCollection.Bitwise
override def inputType: AbstractDataType = IntegralType

override def symbol: String = "&"

Expand All @@ -53,7 +53,7 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme
*/
case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic {

override def inputType: AbstractDataType = TypeCollection.Bitwise
override def inputType: AbstractDataType = IntegralType

override def symbol: String = "|"

Expand All @@ -78,7 +78,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
*/
case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {

override def inputType: AbstractDataType = TypeCollection.Bitwise
override def inputType: AbstractDataType = IntegralType

override def symbol: String = "^"

Expand All @@ -101,7 +101,7 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme
*/
case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInputTypes {

override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.Bitwise)
override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType)

override def dataType: DataType = child.dataType

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
TypeCheckResult.TypeCheckFailure(
s"type of predicate expression in If should be boolean, not ${predicate.dataType}")
} else if (trueValue.dataType != falseValue.dataType) {
TypeCheckResult.TypeCheckFailure(
s"differing types in If (${trueValue.dataType} and ${falseValue.dataType}).")
TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
s"(${trueValue.dataType.simpleString} and ${falseValue.dataType.simpleString}).")
} else {
TypeCheckResult.TypeCheckSuccess
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,32 +34,18 @@ private[sql] abstract class AbstractDataType {
private[sql] def defaultConcreteType: DataType

/**
* Returns true if this data type is the same type as `other`. This is different that equality
* as equality will also consider data type parametrization, such as decimal precision.
* Returns true if `other` is an acceptable input type for a function that expects this,
* possibly abstract DataType.
*
* {{{
* // this should return true
* DecimalType.isSameType(DecimalType(10, 2))
*
* // this should return false
* NumericType.isSameType(DecimalType(10, 2))
* }}}
*/
private[sql] def isSameType(other: DataType): Boolean

/**
* Returns true if `other` is an acceptable input type for a function that expectes this,
* possibly abstract, DataType.
*
* {{{
* // this should return true
* DecimalType.isSameType(DecimalType(10, 2))
* DecimalType.acceptsType(DecimalType(10, 2))
*
* // this should return true as well
* NumericType.acceptsType(DecimalType(10, 2))
* }}}
*/
private[sql] def acceptsType(other: DataType): Boolean = isSameType(other)
private[sql] def acceptsType(other: DataType): Boolean

/** Readable string representation for the type. */
private[sql] def simpleString: String
Expand All @@ -83,10 +69,8 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType])

override private[sql] def defaultConcreteType: DataType = types.head.defaultConcreteType

override private[sql] def isSameType(other: DataType): Boolean = false

override private[sql] def acceptsType(other: DataType): Boolean =
types.exists(_.isSameType(other))
types.exists(_.acceptsType(other))

override private[sql] def simpleString: String = {
types.map(_.simpleString).mkString("(", " or ", ")")
Expand All @@ -107,13 +91,6 @@ private[sql] object TypeCollection {
TimestampType, DateType,
StringType, BinaryType)

/**
* Types that can be used in bitwise operations.
*/
val Bitwise = TypeCollection(
BooleanType,
ByteType, ShortType, IntegerType, LongType)

def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types)

def unapply(typ: AbstractDataType): Option[Seq[AbstractDataType]] = typ match {
Expand All @@ -134,8 +111,6 @@ protected[sql] object AnyDataType extends AbstractDataType {

override private[sql] def simpleString: String = "any"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked with hive, hive's bitwise operations only accept integral type, boolean is not supported.

override private[sql] def isSameType(other: DataType): Boolean = false

override private[sql] def acceptsType(other: DataType): Boolean = true
}

Expand Down Expand Up @@ -183,13 +158,11 @@ private[sql] object NumericType extends AbstractDataType {

override private[sql] def simpleString: String = "numeric"

override private[sql] def isSameType(other: DataType): Boolean = false

override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[NumericType]
}


private[sql] object IntegralType {
private[sql] object IntegralType extends AbstractDataType {
/**
* Enables matching against IntegralType for expressions:
* {{{
Expand All @@ -198,6 +171,12 @@ private[sql] object IntegralType {
* }}}
*/
def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[IntegralType]

override private[sql] def defaultConcreteType: DataType = IntegerType

override private[sql] def simpleString: String = "integral"

override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[IntegralType]
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ object ArrayType extends AbstractDataType {

override private[sql] def defaultConcreteType: DataType = ArrayType(NullType, containsNull = true)

override private[sql] def isSameType(other: DataType): Boolean = {
override private[sql] def acceptsType(other: DataType): Boolean = {
other.isInstanceOf[ArrayType]
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ abstract class DataType extends AbstractDataType {

override private[sql] def defaultConcreteType: DataType = this

override private[sql] def isSameType(other: DataType): Boolean = this == other
override private[sql] def acceptsType(other: DataType): Boolean = this == other
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ object DecimalType extends AbstractDataType {

override private[sql] def defaultConcreteType: DataType = Unlimited

override private[sql] def isSameType(other: DataType): Boolean = {
override private[sql] def acceptsType(other: DataType): Boolean = {
other.isInstanceOf[DecimalType]
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ object MapType extends AbstractDataType {

override private[sql] def defaultConcreteType: DataType = apply(NullType, NullType)

override private[sql] def isSameType(other: DataType): Boolean = {
override private[sql] def acceptsType(other: DataType): Boolean = {
other.isInstanceOf[MapType]
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ object StructType extends AbstractDataType {

override private[sql] def defaultConcreteType: DataType = new StructType

override private[sql] def isSameType(other: DataType): Boolean = {
override private[sql] def acceptsType(other: DataType): Boolean = {
other.isInstanceOf[StructType]
}

Expand Down
Loading