Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -530,7 +530,7 @@ case class ScalaAggregator[IN, BUF, OUT](

/**
 * Finishes the aggregation held in `buffer` and returns it in serialized form.
 *
 * The finished value is passed through `outputSerializer`, which yields a row. When the
 * output encoder reports `isSerializedAsStructForTopLevel`, that row is returned as is;
 * otherwise the value at ordinal 0 is extracted from it using `dataType`.
 *
 * @param buffer the aggregation buffer to finish
 * @return either the serialized row or the value read from its first field
 */
def eval(buffer: BUF): Any = {
  val row = outputSerializer(agg.finish(buffer))
  // Only the top-level check decides the result; an earlier `isSerializedAsStruct`-based
  // expression whose value was discarded has been dropped.
  if (outputEncoder.isSerializedAsStructForTopLevel) row else row.get(0, dataType)
}

// Lazily built UnsafeRow whose field count equals the number of named expressions
// exposed by the buffer encoder.
private[this] lazy val bufferRow = {
  val fieldCount = bufferEncoder.namedExpressions.length
  new UnsafeRow(fieldCount)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,22 @@ object LongProductSumAgg extends Aggregator[(jlLong, jlLong), Long, jlLong] {
def outputEncoder: Encoder[jlLong] = Encoders.LONG
}

/**
 * Aggregator that folds every input of a group together with the binary function `r`.
 *
 * The buffer is an `Option[T]`: it is `None` until the first input arrives and afterwards
 * holds `Some` of the running reduction.
 *
 * @param r combining function applied to the running reduction and the next value
 * @param i encoder for the `Option[T]` buffer
 * @tparam T input and output type; an implicit [[Encoder]] for it must be available
 */
final case class Reduce[T: Encoder](r: (T, T) => T)(implicit i: Encoder[Option[T]])
  extends Aggregator[T, Option[T], T] {

  /** Empty buffer: nothing has been reduced yet. */
  def zero: Option[T] = None

  /** Folds `a` into the buffer; the first input becomes the reduction unchanged. */
  def reduce(b: Option[T], a: T): Option[T] = Some(b.fold(a)(r(_, a)))

  /** Combines two partial buffers, applying `r` only when both hold a value. */
  def merge(b1: Option[T], b2: Option[T]): Option[T] =
    (b1, b2) match {
      case (Some(a), Some(b)) => Some(r(a, b))
      case (Some(a), None) => Some(a)
      case (None, Some(b)) => Some(b)
      case (None, None) => None
    }

  /**
   * Extracts the final reduction.
   *
   * @throws NoSuchElementException if the buffer is still empty, i.e. no input was reduced
   */
  def finish(reduction: Option[T]): T =
    // Same exception type as `Option.get`, but with a message that names the cause.
    reduction.getOrElse(
      throw new NoSuchElementException(
        "Reduce.finish called on an empty buffer: no input was aggregated"))

  def bufferEncoder: Encoder[Option[T]] = implicitly
  def outputEncoder: Encoder[T] = implicitly
}

/**
 * Value type with three `Int` fields, annotated so that `CountSerDeUDT` is used as its SQL
 * user-defined type.
 *
 * NOTE(review): `nSer` / `nDeSer` presumably count serialize / deserialize calls; confirm
 * against `CountSerDeUDT`, which is not visible in this excerpt.
 */
@SQLUserDefinedType(udt = classOf[CountSerDeUDT])
case class CountSerDeSQL(nSer: Int, nDeSer: Int, sum: Int)

Expand Down Expand Up @@ -180,6 +196,9 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSi
val data4 = Seq[Boolean](true, false, true).toDF("boolvalues")
data4.write.saveAsTable("agg4")

val data5 = Seq[(Int, (Int, Int))]((1, (2, 3))).toDF("key", "value")
data5.write.saveAsTable("agg5")

val emptyDF = spark.createDataFrame(
sparkContext.emptyRDD[Row],
StructType(StructField("key", StringType) :: StructField("value", IntegerType) :: Nil))
Expand All @@ -190,6 +209,9 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSi
spark.udf.register("mydoubleavg", udaf(MyDoubleAvgAgg))
spark.udf.register("longProductSum", udaf(LongProductSumAgg))
spark.udf.register("arraysum", udaf(ArrayDataAgg))
spark.udf.register("reduceOptionPair", udaf(Reduce[Option[(Int, Int)]](
(opt1, opt2) =>
opt1.zip(opt2).map { case ((a1, b1), (a2, b2)) => (a1 + a2, b1 + b2) }.headOption)))
}

override def afterAll(): Unit = {
Expand Down Expand Up @@ -371,6 +393,12 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSi
Row(Seq(12.0, 15.0, 18.0)) :: Nil)
}

test("SPARK-52023: Returning Option[Product] from udaf") {
  // Expect a single result row whose only column is the nested struct (2, 3).
  val expected = Seq(Row(Row(2, 3)))
  // The query stays inline so it is evaluated wherever checkAnswer evaluates its argument.
  checkAnswer(spark.sql("SELECT reduceOptionPair(value) FROM agg5 GROUP BY key"), expected)
}

test("verify aggregator ser/de behavior") {
val data = sparkContext.parallelize((1 to 100).toSeq, 3).toDF("value1")
val agg = udaf(CountSerDeAgg)
Expand Down