Skip to content

Commit 297f406

Browse files
amanomer authored and cloud-fan committed
[SPARK-29600][SQL] ArrayContains function may return incorrect result for DecimalType
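### What changes were proposed in this pull request?
Use `TypeCoercion.findWiderTypeWithoutStringPromotionForTwo()` instead of `TypeCoercion.findTightestCommonType()` while preprocessing `inputTypes` in `ArrayContains`. To make the helper callable from `ArrayContains`, its visibility is widened from `private[analysis]` to `private[catalyst]`.

### Why are the changes needed?
`TypeCoercion.findWiderTypeWithoutStringPromotionForTwo()` also handles cases for `DecimalType`. For example, `array_contains(array(1.10), 1.1)` is expected to return `true`.

### Does this PR introduce any user-facing change?
No

### How was this patch tested?
A test case was added to `DataFrameFunctionsSuite`.

Closes #26811 from amanomer/29600.

Authored-by: Aman Omer <amanomer1996@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>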
1 parent fac6b9b commit 297f406

File tree

3 files changed

+20
-3
lines changed

3 files changed

+20
-3
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -243,7 +243,7 @@ object TypeCoercion {
243243
* string. If the wider decimal type exceeds system limitation, this rule will truncate
244244
* the decimal type before return it.
245245
*/
246-
private[analysis] def findWiderTypeWithoutStringPromotionForTwo(
246+
private[catalyst] def findWiderTypeWithoutStringPromotionForTwo(
247247
t1: DataType,
248248
t2: DataType): Option[DataType] = {
249249
findTightestCommonType(t1, t2)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1081,7 +1081,7 @@ case class ArrayContains(left: Expression, right: Expression)
10811081
(left.dataType, right.dataType) match {
10821082
case (_, NullType) => Seq.empty
10831083
case (ArrayType(e1, hasNull), e2) =>
1084-
TypeCoercion.findTightestCommonType(e1, e2) match {
1084+
TypeCoercion.findWiderTypeWithoutStringPromotionForTwo(e1, e2) match {
10851085
case Some(dt) => Seq(ArrayType(dt, hasNull), dt)
10861086
case _ => Seq.empty
10871087
}

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 18 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -850,7 +850,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
850850
val errorMsg1 =
851851
s"""
852852
|Input to function array_contains should have been array followed by a
853-
|value with same element type, but it's [array<int>, decimal(29,29)].
853+
|value with same element type, but it's [array<int>, decimal(38,29)].
854854
""".stripMargin.replace("\n", " ").trim()
855855
assert(e1.message.contains(errorMsg1))
856856

@@ -865,6 +865,23 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
865865
assert(e2.message.contains(errorMsg2))
866866
}
867867

868+
test("SPARK-29600: ArrayContains function may return incorrect result for DecimalType") {
869+
checkAnswer(
870+
sql("select array_contains(array(1.10), 1.1)"),
871+
Seq(Row(true))
872+
)
873+
874+
checkAnswer(
875+
sql("SELECT array_contains(array(1.1), 1.10)"),
876+
Seq(Row(true))
877+
)
878+
879+
checkAnswer(
880+
sql("SELECT array_contains(array(1.11), 1.1)"),
881+
Seq(Row(false))
882+
)
883+
}
884+
868885
test("arrays_overlap function") {
869886
val df = Seq(
870887
(Seq[Option[Int]](Some(1), Some(2)), Seq[Option[Int]](Some(-1), Some(10))),

0 commit comments

Comments
 (0)