Skip to content

Commit f9a50e2

Browse files
committed
Update code
1 parent e742552 commit f9a50e2

File tree

1 file changed

+15
-12
lines changed

1 file changed

+15
-12
lines changed

sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
*/
1717
package org.apache.spark.sql.internal
1818

19+
import java.io.Serializable
20+
1921
import scala.reflect.ClassTag
2022
import scala.reflect.runtime.universe
2123

@@ -33,13 +35,14 @@ import org.apache.spark.sql.catalyst.rules.Rule
3335
import org.apache.spark.sql.connector.catalog.CatalogManager
3436
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
3537
import org.apache.spark.sql.execution.{ColumnarRule, CommandExecutionMode, QueryExecution, SparkOptimizer, SparkPlan, SparkPlanner, SparkSqlParser}
36-
import org.apache.spark.sql.execution.aggregate.{ResolveEncodersInScalaAgg, ScalaAggregator, ScalaUDAF}
38+
import org.apache.spark.sql.execution.aggregate.{ResolveEncodersInScalaAgg, ScalaUDAF}
3739
import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin
3840
import org.apache.spark.sql.execution.command.CommandCheck
3941
import org.apache.spark.sql.execution.datasources._
4042
import org.apache.spark.sql.execution.datasources.v2.{TableCapabilityCheck, V2SessionCatalog}
4143
import org.apache.spark.sql.execution.streaming.ResolveWriteToStream
42-
import org.apache.spark.sql.expressions.{Aggregator, UserDefinedAggregateFunction}
44+
import org.apache.spark.sql.expressions.{Aggregator, UserDefinedAggregateFunction, UserDefinedAggregator, UserDefinedFunction}
45+
import org.apache.spark.sql.functions.udaf
4346
import org.apache.spark.sql.streaming.StreamingQueryManager
4447
import org.apache.spark.sql.util.ExecutionListenerManager
4548

@@ -423,25 +426,25 @@ class SparkUDFExpressionBuilder extends FunctionExpressionBuilder {
423426
clazz.getCanonicalName)
424427
}
425428
val aggregator =
426-
noParameterConstructor.get.newInstance().asInstanceOf[Aggregator[Any, Any, Any]]
429+
noParameterConstructor.get.newInstance().asInstanceOf[Aggregator[Serializable, Any, Any]]
427430

428431
// Construct the input encoder
429432
val mirror = universe.runtimeMirror(clazz.getClassLoader)
430433
val classType = mirror.classSymbol(clazz)
431-
val baseClassType = universe.typeOf[Aggregator[_, _, _]].typeSymbol.asClass
434+
val baseClassType = universe.typeOf[Aggregator[Serializable, Any, Any]].typeSymbol.asClass
432435
val baseType = universe.internal.thisType(classType).baseType(baseClassType)
433436
val tpe = baseType.typeArgs.head
434-
val cls = mirror.runtimeClass(tpe)
435437
val serializer = ScalaReflection.serializerForType(tpe)
436438
val deserializer = ScalaReflection.deserializerForType(tpe)
437-
val inputEncoder = new ExpressionEncoder[Any](serializer, deserializer, ClassTag(cls))
439+
val cls = mirror.runtimeClass(tpe)
440+
val inputEncoder =
441+
new ExpressionEncoder[Serializable](serializer, deserializer, ClassTag(cls))
438442

439-
val expr = ScalaAggregator[Any, Any, Any](
440-
input,
441-
aggregator,
442-
inputEncoder,
443-
aggregator.bufferEncoder.asInstanceOf[ExpressionEncoder[Any]],
444-
aggregatorName = Some(name))
443+
val udf: UserDefinedFunction = udaf[Serializable, Any, Any](aggregator, inputEncoder)
444+
assert(udf.isInstanceOf[UserDefinedAggregator[_, _, _]])
445+
val udfAgg: UserDefinedAggregator[_, _, _] = udf.asInstanceOf[UserDefinedAggregator[_, _, _]]
446+
447+
val expr = udfAgg.scalaAggregator(input)
445448
// Check input argument size
446449
if (expr.inputTypes.size != input.size) {
447450
throw QueryCompilationErrors.invalidFunctionArgumentsError(

0 commit comments

Comments
 (0)