-
Notifications
You must be signed in to change notification settings - Fork 29k
[SPARK-9068][SQL] refactor the implicit type cast code #7420
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -669,10 +669,10 @@ object HiveTypeCoercion { | |
| case b @ BinaryOperator(left, right) if left.dataType != right.dataType => | ||
| findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { commonType => | ||
| if (b.inputType.acceptsType(commonType)) { | ||
| // If the expression accepts the tighest common type, cast to that. | ||
| // If the expression accepts the tightest common type, cast to that. | ||
| val newLeft = if (left.dataType == commonType) left else Cast(left, commonType) | ||
| val newRight = if (right.dataType == commonType) right else Cast(right, commonType) | ||
| b.makeCopy(Array(newLeft, newRight)) | ||
| b.withNewChildren(Seq(newLeft, newRight)) | ||
| } else { | ||
| // Otherwise, don't do anything with the expression. | ||
| b | ||
|
|
@@ -691,7 +691,7 @@ object HiveTypeCoercion { | |
| // general implicit casting. | ||
| val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) => | ||
| if (in.dataType == NullType && !expected.acceptsType(NullType)) { | ||
| Cast(in, expected.defaultConcreteType) | ||
| Literal.create(null, expected.defaultConcreteType) | ||
| } else { | ||
| in | ||
| } | ||
|
|
@@ -713,27 +713,22 @@ object HiveTypeCoercion { | |
| @Nullable val ret: Expression = (inType, expectedType) match { | ||
|
|
||
| // If the expected type is already a parent of the input type, no need to cast. | ||
| case _ if expectedType.isSameType(inType) => e | ||
| case _ if expectedType.acceptsType(inType) => e | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why we need
It looks to me we only need
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. there was a reason, but i can't remember why right now. let me think about it. cc @marmbrus do you remember why? |
||
|
|
||
| // Cast null type (usually from null literals) into target types | ||
| case (NullType, target) => Cast(e, target.defaultConcreteType) | ||
|
|
||
| // If the function accepts any numeric type (i.e. the ADT `NumericType`) and the input is | ||
| // already a number, leave it as is. | ||
| case (_: NumericType, NumericType) => e | ||
|
|
||
| // If the function accepts any numeric type and the input is a string, we follow the hive | ||
| // convention and cast that input into a double | ||
| case (StringType, NumericType) => Cast(e, NumericType.defaultConcreteType) | ||
|
|
||
| // Implicit cast among numeric types | ||
| // Implicit cast among numeric types. When we reach here, input type is not acceptable. | ||
|
|
||
| // If input is a numeric type but not decimal, and we expect a decimal type, | ||
| // cast the input to unlimited precision decimal. | ||
| case (_: NumericType, DecimalType) if !inType.isInstanceOf[DecimalType] => | ||
| Cast(e, DecimalType.Unlimited) | ||
| case (_: NumericType, DecimalType) => Cast(e, DecimalType.Unlimited) | ||
| // For any other numeric types, implicitly cast to each other, e.g. long -> int, int -> long | ||
| case (_: NumericType, target: NumericType) if e.dataType != target => Cast(e, target) | ||
| case (_: NumericType, target: NumericType) => e | ||
| case (_: NumericType, target: NumericType) => Cast(e, target) | ||
|
|
||
| // Implicit cast between date time types | ||
| case (DateType, TimestampType) => Cast(e, TimestampType) | ||
|
|
@@ -747,15 +742,9 @@ object HiveTypeCoercion { | |
| case (StringType, BinaryType) => Cast(e, BinaryType) | ||
| case (any, StringType) if any != StringType => Cast(e, StringType) | ||
|
|
||
| // Type collection. | ||
| // First see if we can find our input type in the type collection. If we can, then just | ||
| // use the current expression; otherwise, find the first one we can implicitly cast. | ||
| case (_, TypeCollection(types)) => | ||
| if (types.exists(_.isSameType(inType))) { | ||
| e | ||
| } else { | ||
| types.flatMap(implicitCast(e, _)).headOption.orNull | ||
| } | ||
| // When we reach here, input type is not acceptable for any types in this type collection, | ||
| // try to find the first one we can implicitly cast. | ||
| case (_, TypeCollection(types)) => types.flatMap(implicitCast(e, _)).headOption.orNull | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this correct? Imagine we have TypeCollection(LongType, StringType), and the input is StringType Wouldn't this just cast input to a longtype?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the first rule is: |
||
|
|
||
| // Else, just return the same input expression | ||
| case _ => null | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -369,17 +369,15 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { | |
| override def inputTypes: Seq[AbstractDataType] = Seq(inputType, inputType) | ||
|
|
||
| override def checkInputDataTypes(): TypeCheckResult = { | ||
| // First call the checker for ExpectsInputTypes, and then check whether left and right have | ||
| // the same type. | ||
| super.checkInputDataTypes() match { | ||
| case TypeCheckResult.TypeCheckSuccess => | ||
| if (left.dataType != right.dataType) { | ||
| TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " + | ||
| s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).") | ||
| } else { | ||
| TypeCheckResult.TypeCheckSuccess | ||
| } | ||
| case TypeCheckResult.TypeCheckFailure(msg) => TypeCheckResult.TypeCheckFailure(msg) | ||
| // First check whether left and right have the same type, then check if the type is acceptable. | ||
| if (left.dataType != right.dataType) { | ||
| TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " + | ||
| s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).") | ||
| } else if (!inputType.acceptsType(left.dataType)) { | ||
| TypeCheckResult.TypeCheckFailure(s"'$prettyString' accepts ${inputType.simpleString} type," + | ||
| s" not ${left.dataType.simpleString}") | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it's better to generate a different error message for |
||
| } else { | ||
| TypeCheckResult.TypeCheckSuccess | ||
| } | ||
| } | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -34,32 +34,18 @@ private[sql] abstract class AbstractDataType { | |
| private[sql] def defaultConcreteType: DataType | ||
|
|
||
| /** | ||
| * Returns true if this data type is the same type as `other`. This is different that equality | ||
| * as equality will also consider data type parametrization, such as decimal precision. | ||
| * Returns true if `other` is an acceptable input type for a function that expects this, | ||
| * possibly abstract DataType. | ||
| * | ||
| * {{{ | ||
| * // this should return true | ||
| * DecimalType.isSameType(DecimalType(10, 2)) | ||
| * | ||
| * // this should return false | ||
| * NumericType.isSameType(DecimalType(10, 2)) | ||
| * }}} | ||
| */ | ||
| private[sql] def isSameType(other: DataType): Boolean | ||
|
|
||
| /** | ||
| * Returns true if `other` is an acceptable input type for a function that expectes this, | ||
| * possibly abstract, DataType. | ||
| * | ||
| * {{{ | ||
| * // this should return true | ||
| * DecimalType.isSameType(DecimalType(10, 2)) | ||
| * DecimalType.acceptsType(DecimalType(10, 2)) | ||
| * | ||
| * // this should return true as well | ||
| * NumericType.acceptsType(DecimalType(10, 2)) | ||
| * }}} | ||
| */ | ||
| private[sql] def acceptsType(other: DataType): Boolean = isSameType(other) | ||
| private[sql] def acceptsType(other: DataType): Boolean | ||
|
|
||
| /** Readable string representation for the type. */ | ||
| private[sql] def simpleString: String | ||
|
|
@@ -83,10 +69,8 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType]) | |
|
|
||
| override private[sql] def defaultConcreteType: DataType = types.head.defaultConcreteType | ||
|
|
||
| override private[sql] def isSameType(other: DataType): Boolean = false | ||
|
|
||
| override private[sql] def acceptsType(other: DataType): Boolean = | ||
| types.exists(_.isSameType(other)) | ||
| types.exists(_.acceptsType(other)) | ||
|
|
||
| override private[sql] def simpleString: String = { | ||
| types.map(_.simpleString).mkString("(", " or ", ")") | ||
|
|
@@ -107,13 +91,6 @@ private[sql] object TypeCollection { | |
| TimestampType, DateType, | ||
| StringType, BinaryType) | ||
|
|
||
| /** | ||
| * Types that can be used in bitwise operations. | ||
| */ | ||
| val Bitwise = TypeCollection( | ||
| BooleanType, | ||
| ByteType, ShortType, IntegerType, LongType) | ||
|
|
||
| def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types) | ||
|
|
||
| def unapply(typ: AbstractDataType): Option[Seq[AbstractDataType]] = typ match { | ||
|
|
@@ -134,8 +111,6 @@ protected[sql] object AnyDataType extends AbstractDataType { | |
|
|
||
| override private[sql] def simpleString: String = "any" | ||
|
|
||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I checked with hive, hive's bitwise operations only accept integral type, boolean is not supported. |
||
| override private[sql] def isSameType(other: DataType): Boolean = false | ||
|
|
||
| override private[sql] def acceptsType(other: DataType): Boolean = true | ||
| } | ||
|
|
||
|
|
@@ -183,13 +158,11 @@ private[sql] object NumericType extends AbstractDataType { | |
|
|
||
| override private[sql] def simpleString: String = "numeric" | ||
|
|
||
| override private[sql] def isSameType(other: DataType): Boolean = false | ||
|
|
||
| override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[NumericType] | ||
| } | ||
|
|
||
|
|
||
| private[sql] object IntegralType { | ||
| private[sql] object IntegralType extends AbstractDataType { | ||
| /** | ||
| * Enables matching against IntegralType for expressions: | ||
| * {{{ | ||
|
|
@@ -198,6 +171,12 @@ private[sql] object IntegralType { | |
| * }}} | ||
| */ | ||
| def unapply(e: Expression): Boolean = e.dataType.isInstanceOf[IntegralType] | ||
|
|
||
| override private[sql] def defaultConcreteType: DataType = IntegerType | ||
|
|
||
| override private[sql] def simpleString: String = "integral" | ||
|
|
||
| override private[sql] def acceptsType(other: DataType): Boolean = other.isInstanceOf[IntegralType] | ||
| } | ||
|
|
||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cast
NullTypeto any type will result a literal null, so we can just write down the literal null here.