diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index e149bf2f4976..dd9024bf519d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -415,13 +415,6 @@ object TypeCoercion { if left.dataType != CalendarIntervalType => a.makeCopy(Array(left, Cast(right, DoubleType))) - // For equality between string and timestamp we cast the string to a timestamp - // so that things like rounding of subsecond precision does not affect the comparison. - case p @ Equality(left @ StringType(), right @ TimestampType()) => - p.makeCopy(Array(Cast(left, TimestampType), right)) - case p @ Equality(left @ TimestampType(), right @ StringType()) => - p.makeCopy(Array(left, Cast(right, TimestampType))) - case p @ BinaryComparison(left, right) if findCommonTypeForBinaryComparison(left.dataType, right.dataType, conf).isDefined => val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType, conf).get @@ -491,10 +484,21 @@ object TypeCoercion { i } - case i @ In(a, b) if b.exists(_.dataType != a.dataType) => - findWiderCommonType(i.children.map(_.dataType)) match { - case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType))) - case None => i + case i @ In(value, list) if list.exists(_.dataType != value.dataType) => + if (conf.getConf(SQLConf.LEGACY_IN_PREDICATE_FOLLOW_BINARY_COMPARISON_TYPE_COERCION)) { + findWiderCommonType(list.map(_.dataType)) match { + case Some(listType) => + val finalDataType = findCommonTypeForBinaryComparison(value.dataType, listType, conf) + .orElse(findWiderTypeForDecimal(value.dataType, listType)) + .orElse(findTightestCommonType(value.dataType, listType)) + finalDataType.map(t => i.withNewChildren(i.children.map(Cast(_, t)))).getOrElse(i) + case None => i + } + } else { + findWiderCommonType(i.children.map(_.dataType)) match { + case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType))) + case None => i + } } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 9a524defb281..8b1bc5df392f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2254,6 +2254,13 @@ object SQLConf { .booleanConf .createWithDefault(false) + val LEGACY_IN_PREDICATE_FOLLOW_BINARY_COMPARISON_TYPE_COERCION = + buildConf("spark.sql.legacy.inPredicateFollowBinaryComparisonTypeCoercion") + .internal() + .doc("When set to true, the in predicate follows binary comparison type coercion.") + .booleanConf + .createWithDefault(true) + val TRUNCATE_TABLE_IGNORE_PERMISSION_ACL = buildConf("spark.sql.truncateTable.ignorePermissionAcl.enabled") .internal() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 0d6f9bcedb6a..e573a51e6343 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.analysis -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.analysis.TypeCoercion._ import org.apache.spark.sql.catalyst.dsl.expressions._ @@ -1421,6 +1421,44 @@ class TypeCoercionSuite extends AnalysisTest { In(Cast(Literal("a"), StringType), Seq(Cast(Literal(1), StringType), Cast(Literal("b"), StringType))) ) + Seq(true, false).foreach { follow => + val decimalCase = if (follow) { + In(Cast(Decimal(3.13), DoubleType), + Seq(Cast(Literal("1"), DoubleType), Cast(Literal(2), DoubleType))) + } else { + In(Cast(Decimal(3.13), StringType), + Seq(Cast(Literal("1"), StringType), Cast(Literal(2), StringType))) + } + val dateCase = if (follow) { + In(Cast(Literal(Date.valueOf("2017-03-01")), DateType), + Seq(Cast(Literal("2017-03-01"), DateType))) + } else { + In(Cast(Literal(Date.valueOf("2017-03-01")), StringType), + Seq(Cast(Literal("2017-03-01"), StringType))) + } + val timestampCase = if (follow) { + In(Cast(Literal(new Timestamp(0)), TimestampType), + Seq(Cast(Literal("1"), TimestampType), Cast(Literal(2), TimestampType))) + } else { + In(Cast(Literal(new Timestamp(0)), StringType), + Seq(Cast(Literal("1"), StringType), Cast(Literal(2), StringType))) + } + withSQLConf( + SQLConf.LEGACY_IN_PREDICATE_FOLLOW_BINARY_COMPARISON_TYPE_COERCION.key -> s"$follow") { + ruleTest( + inConversion, + In(Literal(Decimal(3.13)), Seq(Literal("1"), Literal(2))), + decimalCase) + ruleTest( + inConversion, + In(Literal(Date.valueOf("2017-03-01")), Seq(Literal("2017-03-01"))), + dateCase) + ruleTest( + inConversion, + In(Literal(new Timestamp(0)), Seq(Literal("1"), Literal(2))), + timestampCase) + } + } } test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal " + diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/inConversion.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/inConversion.sql.out index 21d0a0e0fef4..917e3b0a18b4 100644 --- a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/inConversion.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/inConversion.sql.out @@ -69,7 +69,7 @@ true -- !query SELECT cast(1 as tinyint) in (cast(1 as string)) FROM t -- !query schema -struct<(CAST(CAST(1 AS TINYINT) AS STRING) IN (CAST(CAST(1 AS STRING) AS STRING))):boolean> +struct<(CAST(CAST(1 AS TINYINT) AS TINYINT) IN (CAST(CAST(1 AS STRING) AS TINYINT))):boolean> -- !query output true @@ -169,7 +169,7 @@ true -- !query SELECT cast(1 as smallint) in (cast(1 as string)) FROM t -- !query schema -struct<(CAST(CAST(1 AS SMALLINT) AS STRING) IN (CAST(CAST(1 AS STRING) AS STRING))):boolean> +struct<(CAST(CAST(1 AS SMALLINT) AS SMALLINT) IN (CAST(CAST(1 AS STRING) AS SMALLINT))):boolean> -- !query output true @@ -269,7 +269,7 @@ true -- !query SELECT cast(1 as int) in (cast(1 as string)) FROM t -- !query schema -struct<(CAST(CAST(1 AS INT) AS STRING) IN (CAST(CAST(1 AS STRING) AS STRING))):boolean> +struct<(CAST(CAST(1 AS INT) AS INT) IN (CAST(CAST(1 AS STRING) AS INT))):boolean> -- !query output true @@ -369,7 +369,7 @@ true -- !query SELECT cast(1 as bigint) in (cast(1 as string)) FROM t -- !query schema -struct<(CAST(CAST(1 AS BIGINT) AS STRING) IN (CAST(CAST(1 AS STRING) AS STRING))):boolean> +struct<(CAST(CAST(1 AS BIGINT) AS BIGINT) IN (CAST(CAST(1 AS STRING) AS BIGINT))):boolean> -- !query output true @@ -469,9 +469,9 @@ true -- !query SELECT cast(1 as float) in (cast(1 as string)) FROM t -- !query schema -struct<(CAST(CAST(1 AS FLOAT) AS STRING) IN (CAST(CAST(1 AS STRING) AS STRING))):boolean> +struct<(CAST(CAST(1 AS FLOAT) AS FLOAT) IN (CAST(CAST(1 AS STRING) AS FLOAT))):boolean> -- !query output -false +true -- !query @@ -569,9 +569,9 @@ true -- !query SELECT cast(1 as double) in (cast(1 as string)) FROM t -- !query schema -struct<(CAST(CAST(1 AS DOUBLE) AS STRING) IN (CAST(CAST(1 AS STRING) AS STRING))):boolean> +struct<(CAST(CAST(1 AS DOUBLE) AS DOUBLE) IN (CAST(CAST(1 AS STRING) AS DOUBLE))):boolean> -- !query output -false +true -- !query @@ -669,7 +669,7 @@ true -- !query SELECT cast(1 as decimal(10, 0)) in (cast(1 as string)) FROM t -- !query schema -struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS STRING) IN (CAST(CAST(1 AS STRING) AS STRING))):boolean> +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) IN (CAST(CAST(1 AS STRING) AS DOUBLE))):boolean> -- !query output true @@ -713,7 +713,7 @@ cannot resolve '(CAST(1 AS DECIMAL(10,0)) IN (CAST('2017-12-11 09:30:00' AS DATE -- !query SELECT cast(1 as string) in (cast(1 as tinyint)) FROM t -- !query schema -struct<(CAST(CAST(1 AS STRING) AS STRING) IN (CAST(CAST(1 AS TINYINT) AS STRING))):boolean> +struct<(CAST(CAST(1 AS STRING) AS TINYINT) IN (CAST(CAST(1 AS TINYINT) AS TINYINT))):boolean> -- !query output true @@ -721,7 +721,7 @@ true -- !query SELECT cast(1 as string) in (cast(1 as smallint)) FROM t -- !query schema -struct<(CAST(CAST(1 AS STRING) AS STRING) IN (CAST(CAST(1 AS SMALLINT) AS STRING))):boolean> +struct<(CAST(CAST(1 AS STRING) AS SMALLINT) IN (CAST(CAST(1 AS SMALLINT) AS SMALLINT))):boolean> -- !query output true @@ -729,7 +729,7 @@ true -- !query SELECT cast(1 as string) in (cast(1 as int)) FROM t -- !query schema -struct<(CAST(CAST(1 AS STRING) AS STRING) IN (CAST(CAST(1 AS INT) AS STRING))):boolean> +struct<(CAST(CAST(1 AS STRING) AS INT) IN (CAST(CAST(1 AS INT) AS INT))):boolean> -- !query output true @@ -737,7 +737,7 @@ true -- !query SELECT cast(1 as string) in (cast(1 as bigint)) FROM t -- !query schema -struct<(CAST(CAST(1 AS STRING) AS STRING) IN (CAST(CAST(1 AS BIGINT) AS STRING))):boolean> +struct<(CAST(CAST(1 AS STRING) AS BIGINT) IN (CAST(CAST(1 AS BIGINT) AS BIGINT))):boolean> -- !query output true @@ -745,23 +745,23 @@ true -- !query SELECT cast(1 as string) in (cast(1 as float)) FROM t -- !query schema -struct<(CAST(CAST(1 AS STRING) AS STRING) IN (CAST(CAST(1 AS FLOAT) AS STRING))):boolean> +struct<(CAST(CAST(1 AS STRING) AS FLOAT) IN (CAST(CAST(1 AS FLOAT) AS FLOAT))):boolean> -- !query output -false +true -- !query SELECT cast(1 as string) in (cast(1 as double)) FROM t -- !query schema -struct<(CAST(CAST(1 AS STRING) AS STRING) IN (CAST(CAST(1 AS DOUBLE) AS STRING))):boolean> +struct<(CAST(CAST(1 AS STRING) AS DOUBLE) IN (CAST(CAST(1 AS DOUBLE) AS DOUBLE))):boolean> -- !query output -false +true -- !query SELECT cast(1 as string) in (cast(1 as decimal(10, 0))) FROM t -- !query schema -struct<(CAST(CAST(1 AS STRING) AS STRING) IN (CAST(CAST(1 AS DECIMAL(10,0)) AS STRING))):boolean> +struct<(CAST(CAST(1 AS STRING) AS DOUBLE) IN (CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE))):boolean> -- !query output true @@ -777,35 +777,33 @@ true -- !query SELECT cast(1 as string) in (cast('1' as binary)) FROM t -- !query schema -struct<> +struct<(CAST(CAST(1 AS STRING) AS BINARY) IN (CAST(CAST(1 AS BINARY) AS BINARY))):boolean> -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS STRING) IN (CAST('1' AS BINARY)))' due to data type mismatch: Arguments must be same type but were: string != binary; line 1 pos 25 +true -- !query SELECT cast(1 as string) in (cast(1 as boolean)) FROM t -- !query schema -struct<> +struct<(CAST(CAST(1 AS STRING) AS BOOLEAN) IN (CAST(CAST(1 AS BOOLEAN) AS BOOLEAN))):boolean> -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve '(CAST(1 AS STRING) IN (CAST(1 AS BOOLEAN)))' due to data type mismatch: Arguments must be same type but were: string != boolean; line 1 pos 25 +true -- !query SELECT cast(1 as string) in (cast('2017-12-11 09:30:00.0' as timestamp)) FROM t -- !query schema -struct<(CAST(CAST(1 AS STRING) AS STRING) IN (CAST(CAST(2017-12-11 09:30:00.0 AS TIMESTAMP) AS STRING))):boolean> +struct<(CAST(CAST(1 AS STRING) AS TIMESTAMP) IN (CAST(CAST(2017-12-11 09:30:00.0 AS TIMESTAMP) AS TIMESTAMP))):boolean> -- !query output -false +NULL -- !query SELECT cast(1 as string) in (cast('2017-12-11 09:30:00' as date)) FROM t -- !query schema -struct<(CAST(CAST(1 AS STRING) AS STRING) IN (CAST(CAST(2017-12-11 09:30:00 AS DATE) AS STRING))):boolean> +struct<(CAST(CAST(1 AS STRING) AS DATE) IN (CAST(CAST(2017-12-11 09:30:00 AS DATE) AS DATE))):boolean> -- !query output -false +NULL -- !query @@ -874,10 +872,9 @@ cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS DECIMAL(10,0))))' due to data -- !query SELECT cast('1' as binary) in (cast(1 as string)) FROM t -- !query schema -struct<> +struct<(CAST(CAST(1 AS BINARY) AS BINARY) IN (CAST(CAST(1 AS STRING) AS BINARY))):boolean> -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve '(CAST('1' AS BINARY) IN (CAST(1 AS STRING)))' due to data type mismatch: Arguments must be same type but were: binary != string; line 1 pos 27 +true -- !query @@ -981,10 +978,9 @@ cannot resolve '(true IN (CAST(1 AS DECIMAL(10,0))))' due to data type mismatch: -- !query SELECT true in (cast(1 as string)) FROM t -- !query schema -struct<> +struct<(CAST(true AS BOOLEAN) IN (CAST(CAST(1 AS STRING) AS BOOLEAN))):boolean> -- !query output -org.apache.spark.sql.AnalysisException -cannot resolve '(true IN (CAST(1 AS STRING)))' due to data type mismatch: Arguments must be same type but were: boolean != string; line 1 pos 12 +true -- !query @@ -1088,9 +1084,9 @@ cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST(2 AS DECIMA -- !query SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast(2 as string)) FROM t -- !query schema -struct<(CAST(CAST(2017-12-12 09:30:00.0 AS TIMESTAMP) AS STRING) IN (CAST(CAST(2 AS STRING) AS STRING))):boolean> +struct<(CAST(CAST(2017-12-12 09:30:00.0 AS TIMESTAMP) AS TIMESTAMP) IN (CAST(CAST(2 AS STRING) AS TIMESTAMP))):boolean> -- !query output -false +NULL -- !query @@ -1193,9 +1189,9 @@ cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST(2 AS DECIMAL(10,0) -- !query SELECT cast('2017-12-12 09:30:00' as date) in (cast(2 as string)) FROM t -- !query schema -struct<(CAST(CAST(2017-12-12 09:30:00 AS DATE) AS STRING) IN (CAST(CAST(2 AS STRING) AS STRING))):boolean> +struct<(CAST(CAST(2017-12-12 09:30:00 AS DATE) AS DATE) IN (CAST(CAST(2 AS STRING) AS DATE))):boolean> -- !query output -false +NULL -- !query @@ -1291,7 +1287,7 @@ true -- !query SELECT cast(1 as tinyint) in (cast(1 as tinyint), cast(1 as string)) FROM t -- !query schema -struct<(CAST(CAST(1 AS TINYINT) AS STRING) IN (CAST(CAST(1 AS TINYINT) AS STRING), CAST(CAST(1 AS STRING) AS STRING))):boolean> +struct<(CAST(CAST(1 AS TINYINT) AS TINYINT) IN (CAST(CAST(1 AS TINYINT) AS TINYINT), CAST(CAST(1 AS STRING) AS TINYINT))):boolean> -- !query output true @@ -1391,7 +1387,7 @@ true -- !query SELECT cast(1 as smallint) in (cast(1 as smallint), cast(1 as string)) FROM t -- !query schema -struct<(CAST(CAST(1 AS SMALLINT) AS STRING) IN (CAST(CAST(1 AS SMALLINT) AS STRING), CAST(CAST(1 AS STRING) AS STRING))):boolean> +struct<(CAST(CAST(1 AS SMALLINT) AS SMALLINT) IN (CAST(CAST(1 AS SMALLINT) AS SMALLINT), CAST(CAST(1 AS STRING) AS SMALLINT))):boolean> -- !query output true @@ -1491,7 +1487,7 @@ true -- !query SELECT cast(1 as int) in (cast(1 as int), cast(1 as string)) FROM t -- !query schema -struct<(CAST(CAST(1 AS INT) AS STRING) IN (CAST(CAST(1 AS INT) AS STRING), CAST(CAST(1 AS STRING) AS STRING))):boolean> +struct<(CAST(CAST(1 AS INT) AS INT) IN (CAST(CAST(1 AS INT) AS INT), CAST(CAST(1 AS STRING) AS INT))):boolean> -- !query output true @@ -1591,7 +1587,7 @@ true -- !query SELECT cast(1 as bigint) in (cast(1 as bigint), cast(1 as string)) FROM t -- !query schema -struct<(CAST(CAST(1 AS BIGINT) AS STRING) IN (CAST(CAST(1 AS BIGINT) AS STRING), CAST(CAST(1 AS STRING) AS STRING))):boolean> +struct<(CAST(CAST(1 AS BIGINT) AS BIGINT) IN (CAST(CAST(1 AS BIGINT) AS BIGINT), CAST(CAST(1 AS STRING) AS BIGINT))):boolean> -- !query output true @@ -1691,7 +1687,7 @@ true -- !query SELECT cast(1 as float) in (cast(1 as float), cast(1 as string)) FROM t -- !query schema -struct<(CAST(CAST(1 AS FLOAT) AS STRING) IN (CAST(CAST(1 AS FLOAT) AS STRING), CAST(CAST(1 AS STRING) AS STRING))):boolean> +struct<(CAST(CAST(1 AS FLOAT) AS FLOAT) IN (CAST(CAST(1 AS FLOAT) AS FLOAT), CAST(CAST(1 AS STRING) AS FLOAT))):boolean> -- !query output true @@ -1791,7 +1787,7 @@ true -- !query SELECT cast(1 as double) in (cast(1 as double), cast(1 as string)) FROM t -- !query schema -struct<(CAST(CAST(1 AS DOUBLE) AS STRING) IN (CAST(CAST(1 AS DOUBLE) AS STRING), CAST(CAST(1 AS STRING) AS STRING))):boolean> +struct<(CAST(CAST(1 AS DOUBLE) AS DOUBLE) IN (CAST(CAST(1 AS DOUBLE) AS DOUBLE), CAST(CAST(1 AS STRING) AS DOUBLE))):boolean> -- !query output true @@ -1891,7 +1887,7 @@ true -- !query SELECT cast(1 as decimal(10, 0)) in (cast(1 as decimal(10, 0)), cast(1 as string)) FROM t -- !query schema -struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS STRING) IN (CAST(CAST(1 AS DECIMAL(10,0)) AS STRING), CAST(CAST(1 AS STRING) AS STRING))):boolean> +struct<(CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE) IN (CAST(CAST(1 AS DECIMAL(10,0)) AS DOUBLE), CAST(CAST(1 AS STRING) AS DOUBLE))):boolean> -- !query output true @@ -2310,7 +2306,7 @@ cannot resolve '(CAST('2017-12-12 09:30:00.0' AS TIMESTAMP) IN (CAST('2017-12-12 -- !query SELECT cast('2017-12-12 09:30:00.0' as timestamp) in (cast('2017-12-12 09:30:00.0' as timestamp), cast(1 as string)) FROM t -- !query schema -struct<(CAST(CAST(2017-12-12 09:30:00.0 AS TIMESTAMP) AS STRING) IN (CAST(CAST(2017-12-12 09:30:00.0 AS TIMESTAMP) AS STRING), CAST(CAST(1 AS STRING) AS STRING))):boolean> +struct<(CAST(CAST(2017-12-12 09:30:00.0 AS TIMESTAMP) AS TIMESTAMP) IN (CAST(CAST(2017-12-12 09:30:00.0 AS TIMESTAMP) AS TIMESTAMP), CAST(CAST(1 AS STRING) AS TIMESTAMP))):boolean> -- !query output true @@ -2415,7 +2411,7 @@ cannot resolve '(CAST('2017-12-12 09:30:00' AS DATE) IN (CAST('2017-12-12 09:30: -- !query SELECT cast('2017-12-12 09:30:00' as date) in (cast('2017-12-12 09:30:00' as date), cast(1 as string)) FROM t -- !query schema -struct<(CAST(CAST(2017-12-12 09:30:00 AS DATE) AS STRING) IN (CAST(CAST(2017-12-12 09:30:00 AS DATE) AS STRING), CAST(CAST(1 AS STRING) AS STRING))):boolean> +struct<(CAST(CAST(2017-12-12 09:30:00 AS DATE) AS DATE) IN (CAST(CAST(2017-12-12 09:30:00 AS DATE) AS DATE), CAST(CAST(1 AS STRING) AS DATE))):boolean> -- !query output true