diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
index 380289ba5feef..51bfd2b986220 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
@@ -59,10 +59,10 @@ case class Mode(
   override def update(
       buffer: OpenHashMap[AnyRef, Long],
       input: InternalRow): OpenHashMap[AnyRef, Long] = {
-    val key = child.eval(input).asInstanceOf[AnyRef]
+    val key = child.eval(input)
     if (key != null) {
-      buffer.changeValue(key, 1L, _ + 1L)
+      buffer.changeValue(InternalRow.copyValue(key).asInstanceOf[AnyRef], 1L, _ + 1L)
     }
     buffer
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index a22abd505ca00..e9daa825dd46c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -418,4 +418,18 @@ class DatasetAggregatorSuite extends QueryTest with SharedSparkSession {
     assert(err.contains("cannot be passed in untyped `select` API. " +
       "Use the typed `Dataset.select` API instead."))
   }
+
+  test("SPARK-40906: Mode should copy keys before inserting into Map") {
+    val df = spark.sparkContext.parallelize(Seq.empty[Int], 4)
+      .mapPartitionsWithIndex { (idx, iter) =>
+        if (idx == 3) {
+          Iterator("3", "3", "3", "3", "4")
+        } else {
+          Iterator("0", "1", "2", "3", "4")
+        }
+      }.toDF("a")
+
+    val agg = df.select(mode(col("a"))).as[String]
+    checkDataset(agg, "3")
+  }
 }
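
Context for the fix above, plus a minimal sketch that is not part of the patch: child.eval() can return values such as UTF8String or unsafe row/array/map data that are views over memory the scan reuses for the next row, so inserting them into the hash-map buffer without a copy lets later rows rewrite keys that are already stored, and the mode result comes out wrong. InternalRow.copyValue detaches such values before they are used as map keys. The sketch assumes Spark's catalyst and unsafe classes are on the classpath; the ModeKeyCopySketch object and the use of scala.collection.mutable.HashMap as a stand-in for the operator's OpenHashMap buffer are illustrative only.

import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.unsafe.types.UTF8String

object ModeKeyCopySketch {
  def main(args: Array[String]): Unit = {
    // A byte buffer that stands in for storage the scan reuses between rows.
    val reused = "3".getBytes("UTF-8")
    // Stand-in for the aggregation buffer; the real operator uses OpenHashMap[AnyRef, Long].
    val counts = mutable.HashMap.empty[AnyRef, Long]

    // Without a copy, the key is a live view over `reused`.
    val unsafeKey = UTF8String.fromBytes(reused)
    counts(unsafeKey) = counts.getOrElse(unsafeKey, 0L) + 1L

    // The next "row" overwrites the shared buffer, which also rewrites the key
    // that is already stored in the map.
    reused(0) = '4'.toByte
    println(counts.keys.mkString(", "))   // prints 4, not 3

    // The patched update() copies the value before inserting, so the stored key
    // is detached from the reused buffer and keeps its original content.
    val safeKey = InternalRow.copyValue(UTF8String.fromBytes(reused)).asInstanceOf[AnyRef]
    counts(safeKey) = counts.getOrElse(safeKey, 0L) + 1L
  }
}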