diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index e517376bc5fc0..fe6307b5bbe86 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -530,7 +530,7 @@ case class ScalaAggregator[IN, BUF, OUT](
 
   def eval(buffer: BUF): Any = {
     val row = outputSerializer(agg.finish(buffer))
-    if (outputEncoder.isSerializedAsStruct) row else row.get(0, dataType)
+    if (outputEncoder.isSerializedAsStructForTopLevel) row else row.get(0, dataType)
   }
 
   private[this] lazy val bufferRow = new UnsafeRow(bufferEncoder.namedExpressions.length)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
index 0bd6b1403d39c..31d0452c70617 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
@@ -60,6 +60,22 @@ object LongProductSumAgg extends Aggregator[(jlLong, jlLong), Long, jlLong] {
   def outputEncoder: Encoder[jlLong] = Encoders.LONG
 }
 
+final case class Reduce[T: Encoder](r: (T, T) => T)(implicit i: Encoder[Option[T]])
+  extends Aggregator[T, Option[T], T] {
+  def zero: Option[T] = None
+  def reduce(b: Option[T], a: T): Option[T] = Some(b.fold(a)(r(_, a)))
+  def merge(b1: Option[T], b2: Option[T]): Option[T] =
+    (b1, b2) match {
+      case (Some(a), Some(b)) => Some(r(a, b))
+      case (Some(a), None) => Some(a)
+      case (None, Some(b)) => Some(b)
+      case (None, None) => None
+    }
+  def finish(reduction: Option[T]): T = reduction.get
+  def bufferEncoder: Encoder[Option[T]] = implicitly
+  def outputEncoder: Encoder[T] = implicitly
+}
+
 @SQLUserDefinedType(udt = classOf[CountSerDeUDT])
 case class CountSerDeSQL(nSer: Int, nDeSer: Int, sum: Int)
 
@@ -180,6 +196,9 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSi
     val data4 = Seq[Boolean](true, false, true).toDF("boolvalues")
     data4.write.saveAsTable("agg4")
 
+    val data5 = Seq[(Int, (Int, Int))]((1, (2, 3))).toDF("key", "value")
+    data5.write.saveAsTable("agg5")
+
     val emptyDF = spark.createDataFrame(
       sparkContext.emptyRDD[Row],
       StructType(StructField("key", StringType) :: StructField("value", IntegerType) :: Nil))
@@ -190,6 +209,9 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSi
     spark.udf.register("mydoubleavg", udaf(MyDoubleAvgAgg))
     spark.udf.register("longProductSum", udaf(LongProductSumAgg))
    spark.udf.register("arraysum", udaf(ArrayDataAgg))
+    spark.udf.register("reduceOptionPair", udaf(Reduce[Option[(Int, Int)]](
+      (opt1, opt2) =>
+        opt1.zip(opt2).map { case ((a1, b1), (a2, b2)) => (a1 + a2, b1 + b2) }.headOption)))
   }
 
   override def afterAll(): Unit = {
@@ -371,6 +393,12 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSi
       Row(Seq(12.0, 15.0, 18.0)) :: Nil)
   }
 
+  test("SPARK-52023: Returning Option[Product] from udaf") {
+    checkAnswer(
+      spark.sql("SELECT reduceOptionPair(value) FROM agg5 GROUP BY key"),
+      Row(Row(2, 3)) :: Nil)
+  }
+
   test("verify aggregator ser/de behavior") {
     val data = sparkContext.parallelize((1 to 100).toSeq, 3).toDF("value1")
     val agg = udaf(CountSerDeAgg)
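
Why the one-line change in ScalaAggregator.eval matters (a minimal sketch, not part of the patch, assuming Spark's internal ExpressionEncoder API): for an Option-of-Product output such as Option[(Int, Int)], the payload is serialized as a struct, but at the top level it occupies a single nullable struct column rather than being flattened into separate columns, so eval() must extract that one field instead of returning the whole serialized row.

    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

    // Hypothetical REPL probe illustrating the distinction the fix relies on.
    val enc = ExpressionEncoder[Option[(Int, Int)]]()
    // The value itself is serialized as a struct...
    assert(enc.isSerializedAsStruct)
    // ...but it is not treated as a struct at the top level, because the
    // outer class is an Option (one nullable struct field, not flattened).
    assert(!enc.isSerializedAsStructForTopLevel)
    // Hence eval() takes the row.get(0, dataType) branch for this encoder,
    // where the broader isSerializedAsStruct check did not.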