Skip to content

Commit 629ec84

Browse files
committed
[SPARK-25415] ArrayPosition function may return an incorrect result when the right expression is implicitly downcast.
1 parent 3030b82 commit 629ec84

File tree

2 files changed

+54
-9
lines changed

2 files changed

+54
-9
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2062,18 +2062,23 @@ case class ArrayPosition(left: Expression, right: Expression)
20622062
override def dataType: DataType = LongType
20632063

20642064
override def inputTypes: Seq[AbstractDataType] = {
2065-
val elementType = left.dataType match {
2066-
case t: ArrayType => t.elementType
2067-
case _ => AnyDataType
2065+
(left.dataType, right.dataType) match {
2066+
case (ArrayType(e1, hasNull), e2) =>
2067+
TypeCoercion.findTightestCommonType(e1, e2) match {
2068+
case Some(dt) => Seq(ArrayType(dt, hasNull), dt)
2069+
case _ => Seq.empty
2070+
}
2071+
case _ => Seq.empty
20682072
}
2069-
Seq(ArrayType, elementType)
20702073
}
20712074

20722075
override def checkInputDataTypes(): TypeCheckResult = {
2073-
super.checkInputDataTypes() match {
2074-
case f: TypeCheckResult.TypeCheckFailure => f
2075-
case TypeCheckResult.TypeCheckSuccess =>
2076+
(left.dataType, right.dataType) match {
2077+
case (ArrayType(e1, _), e2) if e1.sameType(e2) =>
20762078
TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName")
2079+
case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " +
2080+
s"been ${ArrayType.simpleString} followed by a value with same element type, but it's " +
2081+
s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].")
20772082
}
20782083
}
20792084

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1045,6 +1045,31 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
10451045
Seq(Row(null), Row(null))
10461046
)
10471047

1048+
checkAnswer(
1049+
df.selectExpr("array_position(array(1), 1.23D)"),
1050+
Seq(Row(0L), Row(0L))
1051+
)
1052+
1053+
checkAnswer(
1054+
df.selectExpr("array_position(array(1), 1.0D)"),
1055+
Seq(Row(1L), Row(1L))
1056+
)
1057+
1058+
checkAnswer(
1059+
df.selectExpr("array_position(array(1.23D), 1)"),
1060+
Seq(Row(0L), Row(0L))
1061+
)
1062+
1063+
checkAnswer(
1064+
df.selectExpr("array_position(array(array(1)), array(1.0D))"),
1065+
Seq(Row(1L), Row(1L))
1066+
)
1067+
1068+
checkAnswer(
1069+
df.selectExpr("array_position(array(array(1)), array(1.23D))"),
1070+
Seq(Row(0L), Row(0L))
1071+
)
1072+
10481073
checkAnswer(
10491074
df.selectExpr("array_position(array(array(1), null)[0], 1)"),
10501075
Seq(Row(1L), Row(1L))
@@ -1054,10 +1079,25 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
10541079
Seq(Row(1L), Row(1L))
10551080
)
10561081

1057-
val e = intercept[AnalysisException] {
1082+
val e1 = intercept[AnalysisException] {
10581083
Seq(("a string element", "a")).toDF().selectExpr("array_position(_1, _2)")
10591084
}
1060-
assert(e.message.contains("argument 1 requires array type, however, '`_1`' is of string type"))
1085+
val errorMsg1 =
1086+
s"""
1087+
|Input to function array_position should have been array followed by a
1088+
|value with same element type, but it's [string, string].
1089+
""".stripMargin.replace("\n", " ").trim()
1090+
assert(e1.message.contains(errorMsg1))
1091+
1092+
val e2 = intercept[AnalysisException] {
1093+
df.selectExpr("array_position(array(1), '1')")
1094+
}
1095+
val errorMsg2 =
1096+
s"""
1097+
|Input to function array_position should have been array followed by a
1098+
|value with same element type, but it's [array<int>, string].
1099+
""".stripMargin.replace("\n", " ").trim()
1100+
assert(e2.message.contains(errorMsg2))
10611101
}
10621102

10631103
test("element_at function") {

0 commit comments

Comments
 (0)