1616 */
1717package org .apache .spark .sql .internal
1818
19+ import java .io .Serializable
20+
1921import scala .reflect .ClassTag
2022import scala .reflect .runtime .universe
2123
@@ -33,13 +35,14 @@ import org.apache.spark.sql.catalyst.rules.Rule
3335import org .apache .spark .sql .connector .catalog .CatalogManager
3436import org .apache .spark .sql .errors .{QueryCompilationErrors , QueryExecutionErrors }
3537import org .apache .spark .sql .execution .{ColumnarRule , CommandExecutionMode , QueryExecution , SparkOptimizer , SparkPlan , SparkPlanner , SparkSqlParser }
36- import org .apache .spark .sql .execution .aggregate .{ResolveEncodersInScalaAgg , ScalaAggregator , ScalaUDAF }
38+ import org .apache .spark .sql .execution .aggregate .{ResolveEncodersInScalaAgg , ScalaUDAF }
3739import org .apache .spark .sql .execution .analysis .DetectAmbiguousSelfJoin
3840import org .apache .spark .sql .execution .command .CommandCheck
3941import org .apache .spark .sql .execution .datasources ._
4042import org .apache .spark .sql .execution .datasources .v2 .{TableCapabilityCheck , V2SessionCatalog }
4143import org .apache .spark .sql .execution .streaming .ResolveWriteToStream
42- import org .apache .spark .sql .expressions .{Aggregator , UserDefinedAggregateFunction }
44+ import org .apache .spark .sql .expressions .{Aggregator , UserDefinedAggregateFunction , UserDefinedAggregator , UserDefinedFunction }
45+ import org .apache .spark .sql .functions .udaf
4346import org .apache .spark .sql .streaming .StreamingQueryManager
4447import org .apache .spark .sql .util .ExecutionListenerManager
4548
@@ -423,25 +426,25 @@ class SparkUDFExpressionBuilder extends FunctionExpressionBuilder {
423426 clazz.getCanonicalName)
424427 }
425428 val aggregator =
426- noParameterConstructor.get.newInstance().asInstanceOf [Aggregator [Any , Any , Any ]]
429+ noParameterConstructor.get.newInstance().asInstanceOf [Aggregator [Serializable , Any , Any ]]
427430
428431 // Construct the input encoder
429432 val mirror = universe.runtimeMirror(clazz.getClassLoader)
430433 val classType = mirror.classSymbol(clazz)
431- val baseClassType = universe.typeOf[Aggregator [_, _, _ ]].typeSymbol.asClass
434+ val baseClassType = universe.typeOf[Aggregator [Serializable , Any , Any ]].typeSymbol.asClass
432435 val baseType = universe.internal.thisType(classType).baseType(baseClassType)
433436 val tpe = baseType.typeArgs.head
434- val cls = mirror.runtimeClass(tpe)
435437 val serializer = ScalaReflection .serializerForType(tpe)
436438 val deserializer = ScalaReflection .deserializerForType(tpe)
437- val inputEncoder = new ExpressionEncoder [Any ](serializer, deserializer, ClassTag (cls))
439+ val cls = mirror.runtimeClass(tpe)
440+ val inputEncoder =
441+ new ExpressionEncoder [Serializable ](serializer, deserializer, ClassTag (cls))
438442
439- val expr = ScalaAggregator [Any , Any , Any ](
440- input,
441- aggregator,
442- inputEncoder,
443- aggregator.bufferEncoder.asInstanceOf [ExpressionEncoder [Any ]],
444- aggregatorName = Some (name))
443+ val udf : UserDefinedFunction = udaf[Serializable , Any , Any ](aggregator, inputEncoder)
444+ assert(udf.isInstanceOf [UserDefinedAggregator [_, _, _]])
445+ val udfAgg : UserDefinedAggregator [_, _, _] = udf.asInstanceOf [UserDefinedAggregator [_, _, _]]
446+
447+ val expr = udfAgg.scalaAggregator(input)
445448 // Check input argument size
446449 if (expr.inputTypes.size != input.size) {
447450 throw QueryCompilationErrors .invalidFunctionArgumentsError(
0 commit comments