From aa472a5ef144ccb3e22658c4c2f93146527a9265 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Thu, 21 Oct 2021 00:38:44 +0800
Subject: [PATCH] move UDAF expression building from sql/catalyst to sql/core

---
 .../catalog/FunctionExpressionBuilder.scala   | 31 ++++++
 .../sql/catalyst/catalog/SessionCatalog.scala | 42 +-------
 .../internal/BaseSessionStateBuilder.scala    | 29 +++++-
 .../spark/sql/hive/HiveSessionCatalog.scala   | 96 ++-----------------
 .../sql/hive/HiveSessionStateBuilder.scala    | 81 +++++++++++++++-
 5 files changed, 149 insertions(+), 130 deletions(-)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/FunctionExpressionBuilder.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/FunctionExpressionBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/FunctionExpressionBuilder.scala
new file mode 100644
index 000000000000..bf3d790b86c0
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/FunctionExpressionBuilder.scala
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.catalog
+
+import org.apache.spark.sql.catalyst.expressions.Expression
+
+// A builder to create an `Expression` from a function's name, class, and input expressions.
+trait FunctionExpressionBuilder { + def makeExpression(name: String, clazz: Class[_], input: Seq[Expression]): Expression +} + +object DummyFunctionExpressionBuilder extends FunctionExpressionBuilder { + override def makeExpression(name: String, clazz: Class[_], input: Seq[Expression]): Expression = { + throw new UnsupportedOperationException + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 8bba6bd08e01..c3cc78e7a96b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, ExpressionInfo, ImplicitCastInputTypes, UpCast} +import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, ExpressionInfo, UpCast} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, SubqueryAlias, View} import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, StringUtils} @@ -66,6 +66,7 @@ class SessionCatalog( hadoopConf: Configuration, parser: ParserInterface, functionResourceLoader: FunctionResourceLoader, + functionExpressionBuilder: FunctionExpressionBuilder, cacheSize: Int = SQLConf.get.tableRelationCacheSize, cacheTTL: Long = SQLConf.get.metadataCacheTTL) extends SQLConfHelper with Logging { import SessionCatalog._ @@ -85,6 +86,7 @@ class SessionCatalog( new Configuration(), new CatalystSqlParser(), DummyFunctionResourceLoader, + DummyFunctionExpressionBuilder, conf.tableRelationCacheSize, conf.metadataCacheTTL) } @@ -1437,43 +1439,7 @@ class SessionCatalog( */ private def makeFunctionBuilder(name: String, functionClassName: String): FunctionBuilder = { val clazz = Utils.classForName(functionClassName) - (input: Seq[Expression]) => makeFunctionExpression(name, clazz, input) - } - - /** - * Constructs a [[Expression]] based on the provided class that represents a function. - * - * This performs reflection to decide what type of [[Expression]] to return in the builder. - */ - protected def makeFunctionExpression( - name: String, - clazz: Class[_], - input: Seq[Expression]): Expression = { - // Unfortunately we need to use reflection here because UserDefinedAggregateFunction - // and ScalaUDAF are defined in sql/core module. 
- val clsForUDAF = - Utils.classForName("org.apache.spark.sql.expressions.UserDefinedAggregateFunction") - if (clsForUDAF.isAssignableFrom(clazz)) { - val cls = Utils.classForName("org.apache.spark.sql.execution.aggregate.ScalaUDAF") - val e = cls.getConstructor( - classOf[Seq[Expression]], clsForUDAF, classOf[Int], classOf[Int], classOf[Option[String]]) - .newInstance( - input, - clazz.getConstructor().newInstance().asInstanceOf[Object], - Int.box(1), - Int.box(1), - Some(name)) - .asInstanceOf[ImplicitCastInputTypes] - - // Check input argument size - if (e.inputTypes.size != input.size) { - throw QueryCompilationErrors.invalidFunctionArgumentsError( - name, e.inputTypes.size.toString, input.size) - } - e - } else { - throw QueryCompilationErrors.noHandlerForUDAFError(clazz.getCanonicalName) - } + (input: Seq[Expression]) => functionExpressionBuilder.makeExpression(name, clazz, input) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 8289819644a0..7c80b4a80dae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -19,19 +19,22 @@ package org.apache.spark.sql.internal import org.apache.spark.annotation.Unstable import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _} import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry, ResolveSessionCatalog, TableFunctionRegistry} -import org.apache.spark.sql.catalyst.catalog.SessionCatalog +import org.apache.spark.sql.catalyst.catalog.{FunctionExpressionBuilder, SessionCatalog} +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.connector.catalog.CatalogManager +import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.{ColumnarRule, CommandExecutionMode, QueryExecution, SparkOptimizer, SparkPlan, SparkPlanner, SparkSqlParser} -import org.apache.spark.sql.execution.aggregate.ResolveEncodersInScalaAgg +import org.apache.spark.sql.execution.aggregate.{ResolveEncodersInScalaAgg, ScalaUDAF} import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin import org.apache.spark.sql.execution.command.CommandCheck import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.v2.{TableCapabilityCheck, V2SessionCatalog} import org.apache.spark.sql.execution.streaming.ResolveWriteToStream +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction import org.apache.spark.sql.streaming.StreamingQueryManager import org.apache.spark.sql.util.ExecutionListenerManager @@ -153,7 +156,8 @@ abstract class BaseSessionStateBuilder( tableFunctionRegistry, SessionState.newHadoopConf(session.sparkContext.hadoopConfiguration, conf), sqlParser, - resourceLoader) + resourceLoader, + new SparkUDFExpressionBuilder) parentState.foreach(_.catalog.copyStateTo(catalog)) catalog } @@ -392,3 +396,22 @@ private[sql] trait WithTestConf { self: BaseSessionStateBuilder => } } } + +class SparkUDFExpressionBuilder extends FunctionExpressionBuilder { + override def makeExpression(name: String, clazz: Class[_], input: 
Seq[Expression]): Expression = { + if (classOf[UserDefinedAggregateFunction].isAssignableFrom(clazz)) { + val expr = ScalaUDAF( + input, + clazz.getConstructor().newInstance().asInstanceOf[UserDefinedAggregateFunction], + udafName = Some(name)) + // Check input argument size + if (expr.inputTypes.size != input.size) { + throw QueryCompilationErrors.invalidFunctionArgumentsError( + name, expr.inputTypes.size.toString, input.size) + } + expr + } else { + throw QueryCompilationErrors.noHandlerForUDAFError(clazz.getCanonicalName) + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 56818b519133..488890a74a10 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -17,27 +17,20 @@ package org.apache.spark.sql.hive -import java.lang.reflect.InvocationTargetException import java.util.Locale import scala.util.{Failure, Success, Try} import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.hive.ql.exec.{UDAF, UDF} import org.apache.hadoop.hive.ql.exec.{FunctionRegistry => HiveFunctionRegistry} -import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF} -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TableFunctionRegistry} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} import org.apache.spark.sql.catalyst.parser.ParserInterface -import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper import org.apache.spark.sql.types.{DecimalType, DoubleType} -import org.apache.spark.util.Utils - private[sql] class HiveSessionCatalog( externalCatalogBuilder: () => ExternalCatalog, @@ -47,86 +40,17 @@ private[sql] class HiveSessionCatalog( tableFunctionRegistry: TableFunctionRegistry, hadoopConf: Configuration, parser: ParserInterface, - functionResourceLoader: FunctionResourceLoader) + functionResourceLoader: FunctionResourceLoader, + functionExpressionBuilder: FunctionExpressionBuilder) extends SessionCatalog( - externalCatalogBuilder, - globalTempViewManagerBuilder, - functionRegistry, - tableFunctionRegistry, - hadoopConf, - parser, - functionResourceLoader) { - - private def makeHiveFunctionExpression( - name: String, - clazz: Class[_], - input: Seq[Expression]): Expression = { - var udfExpr: Option[Expression] = None - try { - // When we instantiate hive UDF wrapper class, we may throw exception if the input - // expressions don't satisfy the hive UDF, such as type mismatch, input number - // mismatch, etc. Here we catch the exception and throw AnalysisException instead. - if (classOf[UDF].isAssignableFrom(clazz)) { - udfExpr = Some(HiveSimpleUDF(name, new HiveFunctionWrapper(clazz.getName), input)) - udfExpr.get.dataType // Force it to check input data types. - } else if (classOf[GenericUDF].isAssignableFrom(clazz)) { - udfExpr = Some(HiveGenericUDF(name, new HiveFunctionWrapper(clazz.getName), input)) - udfExpr.get.dataType // Force it to check input data types. - } else if (classOf[AbstractGenericUDAFResolver].isAssignableFrom(clazz)) { - udfExpr = Some(HiveUDAFFunction(name, new HiveFunctionWrapper(clazz.getName), input)) - udfExpr.get.dataType // Force it to check input data types. 
- } else if (classOf[UDAF].isAssignableFrom(clazz)) { - udfExpr = Some(HiveUDAFFunction( - name, - new HiveFunctionWrapper(clazz.getName), - input, - isUDAFBridgeRequired = true)) - udfExpr.get.dataType // Force it to check input data types. - } else if (classOf[GenericUDTF].isAssignableFrom(clazz)) { - udfExpr = Some(HiveGenericUDTF(name, new HiveFunctionWrapper(clazz.getName), input)) - // Force it to check data types. - udfExpr.get.asInstanceOf[HiveGenericUDTF].elementSchema - } - } catch { - case NonFatal(exception) => - val e = exception match { - case i: InvocationTargetException => i.getCause - case o => o - } - val errorMsg = s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}': $e" - val analysisException = new AnalysisException(errorMsg) - analysisException.setStackTrace(e.getStackTrace) - throw analysisException - } - udfExpr.getOrElse { - throw new InvalidUDFClassException( - s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}'") - } - } - - /** - * Constructs a [[Expression]] based on the provided class that represents a function. - * - * This performs reflection to decide what type of [[Expression]] to return in the builder. - */ - override def makeFunctionExpression( - name: String, - clazz: Class[_], - input: Seq[Expression]): Expression = { - // Current thread context classloader may not be the one loaded the class. Need to switch - // context classloader to initialize instance properly. - Utils.withContextClassLoader(clazz.getClassLoader) { - try { - super.makeFunctionExpression(name, clazz, input) - } catch { - // If `super.makeFunctionExpression` throw `InvalidUDFClassException`, we construct - // Hive UDF/UDAF/UDTF with function definition. Otherwise, we just throw it earlier. - case _: InvalidUDFClassException => - makeHiveFunctionExpression(name, clazz, input) - case NonFatal(e) => throw e - } - } - } + externalCatalogBuilder, + globalTempViewManagerBuilder, + functionRegistry, + tableFunctionRegistry, + hadoopConf, + parser, + functionResourceLoader, + functionExpressionBuilder) { override def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression = { try { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 7bf9b283de99..4f8f15ff4617 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -17,9 +17,17 @@ package org.apache.spark.sql.hive +import java.lang.reflect.InvocationTargetException + +import scala.util.control.NonFatal + +import org.apache.hadoop.hive.ql.exec.{UDAF, UDF} +import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF} + import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.{Analyzer, ResolveSessionCatalog} -import org.apache.spark.sql.catalyst.catalog.ExternalCatalogWithListener +import org.apache.spark.sql.catalyst.catalog.{ExternalCatalogWithListener, InvalidUDFClassException} +import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.SparkPlanner @@ -29,9 +37,10 @@ import org.apache.spark.sql.execution.command.CommandCheck import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.v2.TableCapabilityCheck 
 import org.apache.spark.sql.execution.streaming.ResolveWriteToStream
+import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper
 import org.apache.spark.sql.hive.client.HiveClient
 import org.apache.spark.sql.hive.execution.PruneHiveTablePartitions
-import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLoader, SessionState}
+import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLoader, SessionState, SparkUDFExpressionBuilder}
 import org.apache.spark.util.Utils
 
 /**
@@ -64,7 +73,8 @@ class HiveSessionStateBuilder(
       tableFunctionRegistry,
       SessionState.newHadoopConf(session.sparkContext.hadoopConfiguration, conf),
       sqlParser,
-      resourceLoader)
+      resourceLoader,
+      HiveUDFExpressionBuilder)
     parentState.foreach(_.catalog.copyStateTo(catalog))
     catalog
   }
@@ -133,3 +143,68 @@ class HiveSessionResourceLoader(
     }
   }
 }
+
+object HiveUDFExpressionBuilder extends SparkUDFExpressionBuilder {
+  override def makeExpression(name: String, clazz: Class[_], input: Seq[Expression]): Expression = {
+    // The current thread's context classloader may not be the one that loaded the class, so we
+    // need to switch the context classloader to initialize the instance properly.
+    Utils.withContextClassLoader(clazz.getClassLoader) {
+      try {
+        super.makeExpression(name, clazz, input)
+      } catch {
+        // If `super.makeExpression` throws `InvalidUDFClassException`, we construct a Hive
+        // UDF/UDAF/UDTF expression from the function definition. Other exceptions are rethrown.
+        case _: InvalidUDFClassException =>
+          makeHiveFunctionExpression(name, clazz, input)
+        case NonFatal(e) => throw e
+      }
+    }
+  }
+
+  private def makeHiveFunctionExpression(
+      name: String,
+      clazz: Class[_],
+      input: Seq[Expression]): Expression = {
+    var udfExpr: Option[Expression] = None
+    try {
+      // Instantiating the Hive UDF wrapper class may throw an exception if the input
+      // expressions don't satisfy the Hive UDF, e.g. a type mismatch or a wrong number of
+      // inputs. Here we catch the exception and throw an AnalysisException instead.
+      if (classOf[UDF].isAssignableFrom(clazz)) {
+        udfExpr = Some(HiveSimpleUDF(name, new HiveFunctionWrapper(clazz.getName), input))
+        udfExpr.get.dataType // Force it to check input data types.
+      } else if (classOf[GenericUDF].isAssignableFrom(clazz)) {
+        udfExpr = Some(HiveGenericUDF(name, new HiveFunctionWrapper(clazz.getName), input))
+        udfExpr.get.dataType // Force it to check input data types.
+      } else if (classOf[AbstractGenericUDAFResolver].isAssignableFrom(clazz)) {
+        udfExpr = Some(HiveUDAFFunction(name, new HiveFunctionWrapper(clazz.getName), input))
+        udfExpr.get.dataType // Force it to check input data types.
+      } else if (classOf[UDAF].isAssignableFrom(clazz)) {
+        udfExpr = Some(HiveUDAFFunction(
+          name,
+          new HiveFunctionWrapper(clazz.getName),
+          input,
+          isUDAFBridgeRequired = true))
+        udfExpr.get.dataType // Force it to check input data types.
+      } else if (classOf[GenericUDTF].isAssignableFrom(clazz)) {
+        udfExpr = Some(HiveGenericUDTF(name, new HiveFunctionWrapper(clazz.getName), input))
+        // Force it to check data types.
+ udfExpr.get.asInstanceOf[HiveGenericUDTF].elementSchema + } + } catch { + case NonFatal(exception) => + val e = exception match { + case i: InvocationTargetException => i.getCause + case o => o + } + val errorMsg = s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}': $e" + val analysisException = new AnalysisException(errorMsg) + analysisException.setStackTrace(e.getStackTrace) + throw analysisException + } + udfExpr.getOrElse { + throw new InvalidUDFClassException( + s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}'") + } + } +}
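
Note for reviewers (not part of the patch): the new `FunctionExpressionBuilder` trait turns expression construction into a pluggable seam. `SessionCatalog.makeFunctionBuilder` now only closes over the injected builder, so a session state builder can supply its own construction strategy without sql/catalyst depending on sql/core or sql/hive. Below is a minimal sketch of a custom builder; `MyScalarUDF` and its `buildExpression` method are hypothetical stand-ins for a third-party function framework, not real Spark APIs.

```scala
import org.apache.spark.sql.catalyst.catalog.FunctionExpressionBuilder
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.errors.QueryCompilationErrors

// Hypothetical third-party function interface.
trait MyScalarUDF {
  def buildExpression(input: Seq[Expression]): Expression
}

// A custom builder: recognize our own UDF classes, and fall back to the same
// "no handler" error that SparkUDFExpressionBuilder raises for unknown classes.
class MyUDFExpressionBuilder extends FunctionExpressionBuilder {
  override def makeExpression(name: String, clazz: Class[_], input: Seq[Expression]): Expression = {
    if (classOf[MyScalarUDF].isAssignableFrom(clazz)) {
      // Instantiate via the no-arg constructor, mirroring SparkUDFExpressionBuilder.
      clazz.getConstructor().newInstance().asInstanceOf[MyScalarUDF].buildExpression(input)
    } else {
      throw QueryCompilationErrors.noHandlerForUDAFError(clazz.getCanonicalName)
    }
  }
}
```

Wiring such a builder in means passing it to the `SessionCatalog` constructor from an overridden `catalog` method in a `BaseSessionStateBuilder` subclass, exactly as `HiveSessionStateBuilder` passes `HiveUDFExpressionBuilder` above.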
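For the sql/core path itself, a hedged end-to-end sketch: registering a permanent function whose class extends `UserDefinedAggregateFunction` should now resolve through `SparkUDFExpressionBuilder`, which wraps an instance of the class in `ScalaUDAF` and validates the argument count. `com.example.MyAverage` below is an assumed UDAF on the session classpath, not something this patch ships.

```scala
// Assumes `spark` is an active SparkSession and com.example.MyAverage extends
// UserDefinedAggregateFunction with a single numeric input column.
spark.sql("CREATE FUNCTION my_avg AS 'com.example.MyAverage'")
spark.sql("SELECT my_avg(value) FROM VALUES (1), (2), (3) AS t(value)").show()
// Calling my_avg with the wrong arity, e.g. my_avg(value, value), should fail with
// the invalidFunctionArgumentsError raised in SparkUDFExpressionBuilder.
```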