From fc47d96c53b0469e973c5d9d0aa320b3782eee6a Mon Sep 17 00:00:00 2001 From: gengjiaan Date: Fri, 15 Oct 2021 19:12:45 +0800 Subject: [PATCH 1/5] test --- .../spark/sql/hive/HiveSessionCatalog.scala | 42 ++++++++++++++++--- .../sql/hive/execution/HiveSQLViewSuite.scala | 28 +++++++++++++ 2 files changed, 65 insertions(+), 5 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 56818b519133..9a6f8b52e919 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -22,20 +22,21 @@ import java.util.Locale import scala.util.{Failure, Success, Try} import scala.util.control.NonFatal - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.ql.exec.{UDAF, UDF} import org.apache.hadoop.hive.ql.exec.{FunctionRegistry => HiveFunctionRegistry} import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF} - import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TableFunctionRegistry} import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} +import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ImplicitCastInputTypes} import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper -import org.apache.spark.sql.types.{DecimalType, DoubleType} +import org.apache.spark.sql.types.{DecimalType, DoubleType, StructField, StructType} import org.apache.spark.util.Utils @@ -122,7 +123,38 @@ private[sql] class HiveSessionCatalog( // If `super.makeFunctionExpression` throw `InvalidUDFClassException`, we construct // Hive UDF/UDAF/UDTF with function definition. Otherwise, we just throw it earlier. case _: InvalidUDFClassException => - makeHiveFunctionExpression(name, clazz, input) + val clsForAggregator = + Utils.classForName("org.apache.spark.sql.expressions.Aggregator") + if (clsForAggregator.isAssignableFrom(clazz)) { + val cls = Utils.classForName("org.apache.spark.sql.execution.aggregate.ScalaAggregator") + val clsForEncoder = + Utils.classForName("org.apache.spark.sql.catalyst.encoders.ExpressionEncoder") + val aggregator = clazz.getConstructor().newInstance().asInstanceOf[Aggregator[_, _, _]] + val schema = StructType(input.map(e => StructField("temp", e.dataType))) + val e = cls.getConstructor(classOf[Seq[Expression]], clsForAggregator, + clsForEncoder, clsForEncoder, classOf[Boolean], classOf[Boolean], + classOf[Int], classOf[Int], classOf[Option[String]]) + .newInstance( + input, + aggregator, + RowEncoder(schema), + aggregator.bufferEncoder, + Boolean.box(true), + Boolean.box(true), + Int.box(1), + Int.box(1), + Some(name)) + .asInstanceOf[ImplicitCastInputTypes] + + // Check input argument size + if (e.inputTypes.size != input.size) { + throw QueryCompilationErrors.invalidFunctionArgumentsError( + name, e.inputTypes.size.toString, input.size) + } + e + } else { + makeHiveFunctionExpression(name, clazz, input) + } case NonFatal(e) => throw e } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala index 26ea87bcc1cf..bdf15b10897b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala @@ -179,4 +179,32 @@ class HiveSQLViewSuite extends SQLViewSuite with TestHiveSingleton { } } } + + test("SPARK-33692: 2") { + val avgFuncClass = "test.org.apache.spark.sql.MyDoubleAverage" + val sumFuncClass = "test.org.apache.spark.sql.MyDoubleSum" + val functionName = "test_udf" + withTempDatabase { dbName => + withUserDefinedFunction( + s"default.$functionName" -> false, + s"$dbName.$functionName" -> false, + functionName -> true) { + // create a function in default database + sql("USE DEFAULT") + sql(s"CREATE FUNCTION $functionName AS '$avgFuncClass'") + // create a view using a function in 'default' database +// val viewName = createView( +// "v1", s"SELECT $functionName(col1) AS func FROM VALUES (1), (2), (3)") +// // create function in another database with the same function name +// sql(s"USE $dbName") +// sql(s"CREATE FUNCTION $functionName AS '$sumFuncClass'") +// // create temporary function with the same function name +// sql(s"CREATE TEMPORARY FUNCTION $functionName AS '$sumFuncClass'") +// withView(viewName) { +// // view v1 should still using function defined in `default` database +// checkViewOutput(viewName, Seq(Row(102.0))) +// } + } + } + } } From 3702b9b1ac974810fff4491d2fb0130712578306 Mon Sep 17 00:00:00 2001 From: beliefer Date: Sun, 17 Oct 2021 00:57:04 +0800 Subject: [PATCH 2/5] Update code --- .../spark/sql/hive/HiveSessionCatalog.scala | 39 ++++++++++++++----- .../sql/hive/execution/HiveSQLViewSuite.scala | 22 ++++------- 2 files changed, 38 insertions(+), 23 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 9a6f8b52e919..97741d9ac6bf 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -20,23 +20,27 @@ package org.apache.spark.sql.hive import java.lang.reflect.InvocationTargetException import java.util.Locale +import scala.reflect.ClassTag +import scala.reflect.runtime.universe._ import scala.util.{Failure, Success, Try} import scala.util.control.NonFatal + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.ql.exec.{UDAF, UDF} import org.apache.hadoop.hive.ql.exec.{FunctionRegistry => HiveFunctionRegistry} import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF} + import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.{FunctionIdentifier, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TableFunctionRegistry} import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.encoders.RowEncoder +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ImplicitCastInputTypes} import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper -import org.apache.spark.sql.types.{DecimalType, DoubleType, StructField, StructType} +import org.apache.spark.sql.types.{DecimalType, DoubleType} import org.apache.spark.util.Utils @@ -120,24 +124,41 @@ private[sql] class HiveSessionCatalog( try { super.makeFunctionExpression(name, clazz, input) } catch { - // If `super.makeFunctionExpression` throw `InvalidUDFClassException`, we construct - // Hive UDF/UDAF/UDTF with function definition. Otherwise, we just throw it earlier. + // If `super.makeFunctionExpression` throw `InvalidUDFClassException`, we try to construct + // ScalaAggregator or Hive UDF/UDAF/UDTF with function definition. Otherwise, + // we just throw it earlier. + // Unfortunately we need to use reflection here because Aggregator + // and ScalaAggregator are defined in sql/core module. case _: InvalidUDFClassException => val clsForAggregator = Utils.classForName("org.apache.spark.sql.expressions.Aggregator") if (clsForAggregator.isAssignableFrom(clazz)) { - val cls = Utils.classForName("org.apache.spark.sql.execution.aggregate.ScalaAggregator") + val clsForScalaAggregator = + Utils.classForName("org.apache.spark.sql.execution.aggregate.ScalaAggregator") val clsForEncoder = Utils.classForName("org.apache.spark.sql.catalyst.encoders.ExpressionEncoder") val aggregator = clazz.getConstructor().newInstance().asInstanceOf[Aggregator[_, _, _]] - val schema = StructType(input.map(e => StructField("temp", e.dataType))) - val e = cls.getConstructor(classOf[Seq[Expression]], clsForAggregator, + // Construct the input encoder + val mirror = runtimeMirror(clazz.getClassLoader) + val classType = mirror.classSymbol(clazz) + val baseClassType = typeOf[Aggregator[_, _, _]].typeSymbol.asClass + val baseType = internal.thisType(classType).baseType(baseClassType) + val tpe = baseType.typeArgs.head + val cls = mirror.runtimeClass(tpe) + val serializer = ScalaReflection.serializerForType(tpe) + val deserializer = ScalaReflection.deserializerForType(tpe) + val inputEncoder = new ExpressionEncoder( + serializer, + deserializer, + ClassTag(cls)) + + val e = clsForScalaAggregator.getConstructor(classOf[Seq[Expression]], clsForAggregator, clsForEncoder, clsForEncoder, classOf[Boolean], classOf[Boolean], classOf[Int], classOf[Int], classOf[Option[String]]) .newInstance( input, aggregator, - RowEncoder(schema), + inputEncoder, aggregator.bufferEncoder, Boolean.box(true), Boolean.box(true), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala index bdf15b10897b..f262ba50c428 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala @@ -25,6 +25,8 @@ import org.apache.spark.sql.hive.HiveUtils import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.types.{NullType, StructType} +class MyDoubleAverage extends MyDoubleAvgAggBase + /** * A test suite for Hive view related functionality. */ @@ -180,9 +182,8 @@ class HiveSQLViewSuite extends SQLViewSuite with TestHiveSingleton { } } - test("SPARK-33692: 2") { - val avgFuncClass = "test.org.apache.spark.sql.MyDoubleAverage" - val sumFuncClass = "test.org.apache.spark.sql.MyDoubleSum" + test("SPARK-37018: Spark SQL should support create function with Aggregator") { + val avgFuncClass = "org.apache.spark.sql.hive.execution.MyDoubleAverage" val functionName = "test_udf" withTempDatabase { dbName => withUserDefinedFunction( @@ -193,17 +194,10 @@ class HiveSQLViewSuite extends SQLViewSuite with TestHiveSingleton { sql("USE DEFAULT") sql(s"CREATE FUNCTION $functionName AS '$avgFuncClass'") // create a view using a function in 'default' database -// val viewName = createView( -// "v1", s"SELECT $functionName(col1) AS func FROM VALUES (1), (2), (3)") -// // create function in another database with the same function name -// sql(s"USE $dbName") -// sql(s"CREATE FUNCTION $functionName AS '$sumFuncClass'") -// // create temporary function with the same function name -// sql(s"CREATE TEMPORARY FUNCTION $functionName AS '$sumFuncClass'") -// withView(viewName) { -// // view v1 should still using function defined in `default` database -// checkViewOutput(viewName, Seq(Row(102.0))) -// } + withView("v1") { + sql(s"CREATE VIEW v1 AS SELECT $functionName(col1) AS func FROM VALUES (1), (2), (3)") + checkAnswer(sql(s"SELECT * FROM v1"), Seq(Row(102.0))) + } } } } From 7b11c6e67e4dac72b35c08647855de1da01b0490 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Mon, 18 Oct 2021 12:05:15 +0800 Subject: [PATCH 3/5] Update sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala Co-authored-by: Hyukjin Kwon --- .../org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala index f262ba50c428..8fe90449e1f7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveSQLViewSuite.scala @@ -196,7 +196,7 @@ class HiveSQLViewSuite extends SQLViewSuite with TestHiveSingleton { // create a view using a function in 'default' database withView("v1") { sql(s"CREATE VIEW v1 AS SELECT $functionName(col1) AS func FROM VALUES (1), (2), (3)") - checkAnswer(sql(s"SELECT * FROM v1"), Seq(Row(102.0))) + checkAnswer(sql("SELECT * FROM v1"), Seq(Row(102.0))) } } } From 136b8f50e699ea026144c434dda1e9cffe90ae64 Mon Sep 17 00:00:00 2001 From: gengjiaan Date: Mon, 18 Oct 2021 17:16:59 +0800 Subject: [PATCH 4/5] Update code --- .../spark/sql/hive/HiveSessionCatalog.scala | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 97741d9ac6bf..7fe7e401b099 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ImplicitCastInputTypes} import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.aggregate.ScalaAggregator import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper import org.apache.spark.sql.types.{DecimalType, DoubleType} @@ -127,16 +128,10 @@ private[sql] class HiveSessionCatalog( // If `super.makeFunctionExpression` throw `InvalidUDFClassException`, we try to construct // ScalaAggregator or Hive UDF/UDAF/UDTF with function definition. Otherwise, // we just throw it earlier. - // Unfortunately we need to use reflection here because Aggregator - // and ScalaAggregator are defined in sql/core module. case _: InvalidUDFClassException => - val clsForAggregator = - Utils.classForName("org.apache.spark.sql.expressions.Aggregator") + val clsForAggregator = classOf[Aggregator[_, _, _]] if (clsForAggregator.isAssignableFrom(clazz)) { - val clsForScalaAggregator = - Utils.classForName("org.apache.spark.sql.execution.aggregate.ScalaAggregator") - val clsForEncoder = - Utils.classForName("org.apache.spark.sql.catalyst.encoders.ExpressionEncoder") + val clsForEncoder = classOf[ExpressionEncoder[_]] val aggregator = clazz.getConstructor().newInstance().asInstanceOf[Aggregator[_, _, _]] // Construct the input encoder val mirror = runtimeMirror(clazz.getClassLoader) @@ -152,8 +147,8 @@ private[sql] class HiveSessionCatalog( deserializer, ClassTag(cls)) - val e = clsForScalaAggregator.getConstructor(classOf[Seq[Expression]], clsForAggregator, - clsForEncoder, clsForEncoder, classOf[Boolean], classOf[Boolean], + val e = classOf[ScalaAggregator[_, _, _]].getConstructor(classOf[Seq[Expression]], + clsForAggregator, clsForEncoder, clsForEncoder, classOf[Boolean], classOf[Boolean], classOf[Int], classOf[Int], classOf[Option[String]]) .newInstance( input, From 815c65c21c0a83b74a05ce42bf2e1c75b5bc4f56 Mon Sep 17 00:00:00 2001 From: beliefer Date: Tue, 19 Oct 2021 11:51:36 +0800 Subject: [PATCH 5/5] Update code --- .../spark/sql/hive/HiveSessionCatalog.scala | 25 ++++++------------- 1 file changed, 7 insertions(+), 18 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 7fe7e401b099..22c79731e4a7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TableFunctionRegistry} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ImplicitCastInputTypes} +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.aggregate.ScalaAggregator @@ -131,8 +131,8 @@ private[sql] class HiveSessionCatalog( case _: InvalidUDFClassException => val clsForAggregator = classOf[Aggregator[_, _, _]] if (clsForAggregator.isAssignableFrom(clazz)) { - val clsForEncoder = classOf[ExpressionEncoder[_]] - val aggregator = clazz.getConstructor().newInstance().asInstanceOf[Aggregator[_, _, _]] + val aggregator = + clazz.getConstructor().newInstance().asInstanceOf[Aggregator[Any, Any, Any]] // Construct the input encoder val mirror = runtimeMirror(clazz.getClassLoader) val classType = mirror.classSymbol(clazz) @@ -142,25 +142,14 @@ private[sql] class HiveSessionCatalog( val cls = mirror.runtimeClass(tpe) val serializer = ScalaReflection.serializerForType(tpe) val deserializer = ScalaReflection.deserializerForType(tpe) - val inputEncoder = new ExpressionEncoder( + val inputEncoder = new ExpressionEncoder[Any]( serializer, deserializer, ClassTag(cls)) - val e = classOf[ScalaAggregator[_, _, _]].getConstructor(classOf[Seq[Expression]], - clsForAggregator, clsForEncoder, clsForEncoder, classOf[Boolean], classOf[Boolean], - classOf[Int], classOf[Int], classOf[Option[String]]) - .newInstance( - input, - aggregator, - inputEncoder, - aggregator.bufferEncoder, - Boolean.box(true), - Boolean.box(true), - Int.box(1), - Int.box(1), - Some(name)) - .asInstanceOf[ImplicitCastInputTypes] + val e = new ScalaAggregator[Any, Any, Any](input, aggregator, inputEncoder, + aggregator.bufferEncoder.asInstanceOf[ExpressionEncoder[Any]], + aggregatorName = Some(name)) // Check input argument size if (e.inputTypes.size != input.size) {