diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index 13cc9b9c125e..a5c15dae5084 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -113,12 +113,21 @@ object TypeCoercion {
     case _ => None
   }
 
+  private def findCommonTypeForBinaryComparison(
+      dt1: DataType, dt2: DataType, conf: SQLConf): Option[DataType] = {
+    if (conf.isHiveTypeCoercionMode) {
+      findHiveCommonTypeForBinary(dt1, dt2)
+    } else {
+      findNativeCommonTypeForBinary(dt1, dt2, conf)
+    }
+  }
+
   /**
    * This function determines the target type of a comparison operator when one operand
    * is a String and the other is not. It also handles when one op is a Date and the
    * other is a Timestamp by making the target type to be String.
    */
-  private def findCommonTypeForBinaryComparison(
+  private def findNativeCommonTypeForBinary(
       dt1: DataType, dt2: DataType, conf: SQLConf): Option[DataType] = (dt1, dt2) match {
     // We should cast all relative timestamp/date/string comparison into string comparisons
     // This behaves as a user would expect because timestamp strings sort lexicographically.
@@ -204,6 +213,28 @@ object TypeCoercion {
     }
   }
 
+  /**
+   * This function follows Hive's behavior for binary comparisons:
+   * https://github.com/apache/hive/blob/rel/release-3.0.0/ql/src/java/
+   * org/apache/hadoop/hive/ql/exec/FunctionRegistry.java#L802
+   */
+  private def findHiveCommonTypeForBinary(
+      dt1: DataType, dt2: DataType): Option[DataType] = (dt1, dt2) match {
+    case (StringType, DateType) => Some(DateType)
+    case (DateType, StringType) => Some(DateType)
+    case (StringType, TimestampType) => Some(TimestampType)
+    case (TimestampType, StringType) => Some(TimestampType)
+    case (TimestampType, DateType) => Some(TimestampType)
+    case (DateType, TimestampType) => Some(TimestampType)
+    case (StringType, NullType) => Some(StringType)
+    case (NullType, StringType) => Some(StringType)
+    case (StringType | TimestampType, _: NumericType) => Some(DoubleType)
+    case (_: NumericType, StringType | TimestampType) => Some(DoubleType)
+    case (StringType, r: AtomicType) if r != StringType => Some(r)
+    case (l: AtomicType, StringType) if l != StringType => Some(l)
+    case _ => None
+  }
+
   /**
    * Case 2 type widening (see the classdoc comment above for TypeCoercion).
    *
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 6b301c3c9cb5..ae357a0bd38c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1334,6 +1334,15 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
+  val TYPE_COERCION_MODE =
+    buildConf("spark.sql.typeCoercion.mode")
+      .doc("Since Spark 2.4, the 'hive' mode is introduced for Hive compatibility. " +
+        "Spark SQL uses its native type coercion mode, which is enabled by default.")
+      .stringConf
+      .transform(_.toLowerCase(Locale.ROOT))
+      .checkValues(Set("default", "hive"))
+      .createWithDefault("default")
+
   val REPLACE_EXCEPT_WITH_FILTER = buildConf("spark.sql.optimizer.replaceExceptWithFilter")
     .internal()
     .doc("When true, the apply function of the rule verifies whether the right node of the" +
@@ -2010,6 +2019,8 @@ class SQLConf extends Serializable with Logging {
 
   def arrowSafeTypeConversion: Boolean = getConf(SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION)
 
+  def isHiveTypeCoercionMode: Boolean = getConf(SQLConf.TYPE_COERCION_MODE) == "hive"
+
   def replaceExceptWithFilter: Boolean = getConf(REPLACE_EXCEPT_WITH_FILTER)
 
   def decimalOperationsAllowPrecisionLoss: Boolean = getConf(DECIMAL_OPERATIONS_ALLOW_PREC_LOSS)