Skip to content

Commit e6505d1

Browse files
add the unit test for the map column in group by
1 parent a989080 commit e6505d1

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,7 @@ class AnalysisErrorSuite extends AnalysisTest {
588588
FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5),
589589
DateType, TimestampType,
590590
ArrayType(IntegerType),
591+
MapType(StringType, LongType),
591592
new StructType()
592593
.add("f1", FloatType, nullable = true)
593594
.add("f2", StringType, nullable = true),
@@ -600,7 +601,6 @@ class AnalysisErrorSuite extends AnalysisTest {
600601
}
601602

602603
val unsupportedDataTypes = Seq(
603-
MapType(StringType, LongType),
604604
new StructType()
605605
.add("f1", FloatType, nullable = true)
606606
.add("f2", MapType(StringType, LongType), nullable = true),

sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1427,6 +1427,33 @@ class DataFrameAggregateSuite extends QueryTest
14271427
assert (df.schema == expectedSchema)
14281428
checkAnswer(df, Seq(Row(LocalDateTime.parse(ts1), 2), Row(LocalDateTime.parse(ts2), 1)))
14291429
}
1430+
1431+
test("SPARK-36452: Support Map Type column in group by") {
1432+
var df = Seq((1, Map(1 -> 2)), (2, Map(1 -> 2))).toDF("id", "mapInfo")
1433+
// group by map column
1434+
checkAnswer(df.groupBy("mapInfo").count(), Seq(Row(Map[Any, Any](1 -> 2), 2)))
1435+
// group by map column and other column
1436+
checkAnswer(df.groupBy("id", "mapInfo").count(),
1437+
Seq(Row(1, Map[Any, Any](1 -> 2), 1), Row(2, Map[Any, Any](1 -> 2), 1)))
1438+
checkAnswer(df.groupBy("mapInfo").agg(avg("id")),
1439+
Seq(Row(Map[Any, Any](1 -> 2), 1.5)))
1440+
// Does not support if the map type if present in the aggregated expression
1441+
var error = intercept[IllegalStateException] {
1442+
df.groupBy("mapInfo").agg(max(map_keys(col("mapinfo")))).collect
1443+
}
1444+
assert(error.getMessage.contains("grouping/join/window partition keys cannot be map type."))
1445+
// Does not support if the map type with float/double keys or value
1446+
df = Seq((1, Map(1 -> 2.0)), (2, Map(1 -> 2.0))).toDF("id", "mapInfo")
1447+
error = intercept[IllegalStateException] {
1448+
df.groupBy("mapInfo").agg(max(map_keys(col("mapinfo")))).collect
1449+
}
1450+
assert(error.getMessage.contains("grouping/join/window partition keys cannot be map type."))
1451+
df = Seq((1, Map(1.1 -> 2.0)), (2, Map(1.1 -> 2.0))).toDF("id", "mapInfo")
1452+
error = intercept[IllegalStateException] {
1453+
df.groupBy("mapInfo").agg(max(map_keys(col("mapinfo")))).collect
1454+
}
1455+
assert(error.getMessage.contains("grouping/join/window partition keys cannot be map type."))
1456+
}
14301457
}
14311458

14321459
case class B(c: Option[Double])

0 commit comments

Comments
 (0)