Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -415,13 +415,6 @@ object TypeCoercion {
if left.dataType != CalendarIntervalType =>
a.makeCopy(Array(left, Cast(right, DoubleType)))

// For equality between string and timestamp we cast the string to a timestamp
// so that things like rounding of subsecond precision does not affect the comparison.
case p @ Equality(left @ StringType(), right @ TimestampType()) =>
p.makeCopy(Array(Cast(left, TimestampType), right))
case p @ Equality(left @ TimestampType(), right @ StringType()) =>
p.makeCopy(Array(left, Cast(right, TimestampType)))

case p @ BinaryComparison(left, right)
if findCommonTypeForBinaryComparison(left.dataType, right.dataType, conf).isDefined =>
val commonType = findCommonTypeForBinaryComparison(left.dataType, right.dataType, conf).get
Expand Down Expand Up @@ -491,10 +484,21 @@ object TypeCoercion {
i
}

case i @ In(a, b) if b.exists(_.dataType != a.dataType) =>
findWiderCommonType(i.children.map(_.dataType)) match {
case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType)))
case None => i
case i @ In(value, list) if list.exists(_.dataType != value.dataType) =>
if (conf.getConf(SQLConf.LEGACY_IN_PREDICATE_FOLLOW_BINARY_COMPARISON_TYPE_COERCION)) {
findWiderCommonType(list.map(_.dataType)) match {
case Some(listType) =>
val finalDataType = findCommonTypeForBinaryComparison(value.dataType, listType, conf)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@wangyum, the behaviours between decimals and strings look good. But what about other types affected here?

If we think about interpreting IN as = with OR, we should think about other rules applied to equality comparison, for example:

      // For equality between string and timestamp we cast the string to a timestamp
      // so that things like rounding of subsecond precision does not affect the comparison.
      case p @ Equality(left @ StringType(), right @ TimestampType()) =>
        p.makeCopy(Array(Cast(left, TimestampType), right))
      case p @ Equality(left @ TimestampType(), right @ StringType()) =>
        p.makeCopy(Array(left, Cast(right, TimestampType)))

What do you think about fixing this issue completely rather than fixing cases one by one? I didn't check ANSI or other DBMSs yet but I know IN is able to be rewritten to = with OR. Considering that, I suspect the type coercion will be similar too.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can remove TypeCoercion.scala#L418-L423 because we have added the same logic to findCommonTypeForBinaryComparison.

.orElse(findWiderTypeForDecimal(value.dataType, listType))
.orElse(findTightestCommonType(value.dataType, listType))
finalDataType.map(t => i.withNewChildren(i.children.map(Cast(_, t)))).getOrElse(i)
case None => i
}
} else {
findWiderCommonType(i.children.map(_.dataType)) match {
case Some(finalDataType) => i.withNewChildren(i.children.map(Cast(_, finalDataType)))
case None => i
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2254,6 +2254,13 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val LEGACY_IN_PREDICATE_FOLLOW_BINARY_COMPARISON_TYPE_COERCION =
buildConf("spark.sql.legacy.inPredicateFollowBinaryComparisonTypeCoercion")
.internal()
.doc("When set to true, the in predicate follows binary comparison type coercion.")
.booleanConf
.createWithDefault(true)

val TRUNCATE_TABLE_IGNORE_PERMISSION_ACL =
buildConf("spark.sql.truncateTable.ignorePermissionAcl.enabled")
.internal()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.analysis

import java.sql.Timestamp
import java.sql.{Date, Timestamp}

import org.apache.spark.sql.catalyst.analysis.TypeCoercion._
import org.apache.spark.sql.catalyst.dsl.expressions._
Expand Down Expand Up @@ -1421,6 +1421,44 @@ class TypeCoercionSuite extends AnalysisTest {
In(Cast(Literal("a"), StringType),
Seq(Cast(Literal(1), StringType), Cast(Literal("b"), StringType)))
)
Seq(true, false).foreach { follow =>
val decimalCase = if (follow) {
In(Cast(Decimal(3.13), DoubleType),
Seq(Cast(Literal("1"), DoubleType), Cast(Literal(2), DoubleType)))
} else {
In(Cast(Decimal(3.13), StringType),
Seq(Cast(Literal("1"), StringType), Cast(Literal(2), StringType)))
}
val dateCase = if (follow) {
In(Cast(Literal(Date.valueOf("2017-03-01")), DateType),
Seq(Cast(Literal("2017-03-01"), DateType)))
} else {
In(Cast(Literal(Date.valueOf("2017-03-01")), StringType),
Seq(Cast(Literal("2017-03-01"), StringType)))
}
val timestampCase = if (follow) {
In(Cast(Literal(new Timestamp(0)), TimestampType),
Seq(Cast(Literal("1"), TimestampType), Cast(Literal(2), TimestampType)))
} else {
In(Cast(Literal(new Timestamp(0)), StringType),
Seq(Cast(Literal("1"), StringType), Cast(Literal(2), StringType)))
}
withSQLConf(
SQLConf.LEGACY_IN_PREDICATE_FOLLOW_BINARY_COMPARISON_TYPE_COERCION.key -> s"$follow") {
ruleTest(
inConversion,
In(Literal(Decimal(3.13)), Seq(Literal("1"), Literal(2))),
decimalCase)
ruleTest(
inConversion,
In(Literal(Date.valueOf("2017-03-01")), Seq(Literal("2017-03-01"))),
dateCase)
ruleTest(
inConversion,
In(Literal(new Timestamp(0)), Seq(Literal("1"), Literal(2))),
timestampCase)
}
}
}

test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal " +
Expand Down
Loading