From 40281551f461ecb5f3c1720d1ed45d885e5353a6 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Fri, 21 Jul 2017 00:59:09 -0700 Subject: [PATCH 01/11] support UDAF --- .../sql/catalyst/catalog/SessionCatalog.scala | 51 ++++++++++++++++-- .../spark/sql/execution/aggregate/udaf.scala | 7 ++- .../test/resources/sql-tests/inputs/udaf.sql | 13 +++++ .../resources/sql-tests/results/udaf.sql.out | 54 +++++++++++++++++++ .../spark/sql/hive/HiveSessionCatalog.scala | 4 ++ .../sql/hive/execution/HiveUDAFSuite.scala | 13 +++++ 6 files changed, 137 insertions(+), 5 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/udaf.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/udaf.sql.out diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index b44d2ee69e1d..10eff5a42e47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.catalog +import java.lang.reflect.InvocationTargetException import java.net.URI import java.util.Locale import java.util.concurrent.Callable @@ -24,6 +25,7 @@ import javax.annotation.concurrent.GuardedBy import scala.collection.mutable import scala.util.{Failure, Success, Try} +import scala.util.control.NonFatal import com.google.common.cache.{Cache, CacheBuilder} import org.apache.hadoop.conf.Configuration @@ -40,6 +42,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias, import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils object SessionCatalog { val DEFAULT_DATABASE = "default" @@ -1096,8 +1099,43 @@ class SessionCatalog( * This performs reflection to decide what type of [[Expression]] to return in the builder. */ protected def makeFunctionBuilder(name: String, functionClassName: String): FunctionBuilder = { - // TODO: at least support UDAFs here - throw new UnsupportedOperationException("Use sqlContext.udf.register(...) instead.") + makeFunctionBuilder(name, Utils.classForName(functionClassName)) + } + + /** + * Construct a [[FunctionBuilder]] based on the provided class that represents a function. + */ + private def makeFunctionBuilder(name: String, clazz: Class[_]): FunctionBuilder = { + // When we instantiate ScalaUDAF class, we may throw exception if the input + // expressions don't satisfy the UDAF, such as type mismatch, input number + // mismatch, etc. Here we catch the exception and throw AnalysisException instead. + (children: Seq[Expression]) => { + try { + val clsForUDAF = + Utils.classForName("org.apache.spark.sql.expressions.UserDefinedAggregateFunction") + if (clsForUDAF.isAssignableFrom(clazz)) { + val cls = Utils.classForName("org.apache.spark.sql.execution.aggregate.ScalaUDAF") + // val ctor = classOf[Integer].getConstructor(classOf[Int]) + cls.getConstructor(classOf[Seq[Expression]], clsForUDAF, classOf[Int], classOf[Int]) + .newInstance(children, clazz.newInstance().asInstanceOf[Object], Int.box(1), Int.box(1)) + .asInstanceOf[Expression] + } else { + throw new UnsupportedOperationException("Use sqlContext.udf.register(...) 
instead.") + } + } catch { + case NonFatal(exception) => + val e = exception match { + // Since we are using shim, the exceptions thrown by the underlying method of + // Method.invoke() are wrapped by InvocationTargetException + case i: InvocationTargetException => i.getCause + case o => o + } + val analysisException = + new AnalysisException(s"No handler for UDAF '${clazz.getCanonicalName}': $e") + analysisException.setStackTrace(e.getStackTrace) + throw analysisException + } + } } /** @@ -1116,12 +1154,17 @@ class SessionCatalog( overrideIfExists: Boolean, functionBuilder: Option[FunctionBuilder] = None): Unit = { val func = funcDefinition.identifier + val className = funcDefinition.className if (functionRegistry.functionExists(func) && !overrideIfExists) { throw new AnalysisException(s"Function $func already exists") } - val info = new ExpressionInfo(funcDefinition.className, func.database.orNull, func.funcName) + if (!Utils.classIsLoadable(className)) { + throw new AnalysisException(s"Can not load class '$className' when registering " + + s"the function '$func', please make sure it is on the classpath") + } + val info = new ExpressionInfo(className, func.database.orNull, func.funcName) val builder = - functionBuilder.getOrElse(makeFunctionBuilder(func.unquotedString, funcDefinition.className)) + functionBuilder.getOrElse(makeFunctionBuilder(func.unquotedString, className)) functionRegistry.registerFunction(func, info, builder) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index ae5e2c6bece2..a5984ba121de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.internal.Logging -import org.apache.spark.sql.Row +import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, _} import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate @@ -326,6 +326,11 @@ case class ScalaUDAF( inputAggBufferOffset: Int = 0) extends ImperativeAggregate with NonSQLExpression with Logging with ImplicitCastInputTypes { + if (children.length != udaf.inputSchema.length) { + throw new AnalysisException(s"Invalid number of arguments for the function " + + s"Expected: ${udaf.inputSchema.length}; Found: ${children.length}") + } + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) diff --git a/sql/core/src/test/resources/sql-tests/inputs/udaf.sql b/sql/core/src/test/resources/sql-tests/inputs/udaf.sql new file mode 100644 index 000000000000..2183ba23afc3 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/udaf.sql @@ -0,0 +1,13 @@ +CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES +(1), (2), (3), (4) +as t1(int_col1); + +CREATE FUNCTION myDoubleAvg AS 'test.org.apache.spark.sql.MyDoubleAvg'; + +SELECT default.myDoubleAvg(int_col1) as my_avg from t1; + +SELECT default.myDoubleAvg(int_col1, 3) as my_avg from t1; + +CREATE FUNCTION udaf1 AS 'test.non.existent.udaf'; + +SELECT default.udaf1(int_col1) as udaf1 from t1; diff --git a/sql/core/src/test/resources/sql-tests/results/udaf.sql.out 
b/sql/core/src/test/resources/sql-tests/results/udaf.sql.out new file mode 100644 index 000000000000..bd4364ae05bf --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/udaf.sql.out @@ -0,0 +1,54 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 6 + + +-- !query 0 +CREATE OR REPLACE TEMPORARY VIEW t1 AS SELECT * FROM VALUES +(1), (2), (3), (4) +as t1(int_col1) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE FUNCTION myDoubleAvg AS 'test.org.apache.spark.sql.MyDoubleAvg' +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +SELECT default.myDoubleAvg(int_col1) as my_avg from t1 +-- !query 2 schema +struct +-- !query 2 output +102.5 + + +-- !query 3 +SELECT default.myDoubleAvg(int_col1, 3) as my_avg from t1 +-- !query 3 schema +struct<> +-- !query 3 output +org.apache.spark.sql.AnalysisException +No handler for UDAF 'test.org.apache.spark.sql.MyDoubleAvg': org.apache.spark.sql.AnalysisException: Invalid number of arguments for the function Expected: 1; Found: 2;; line 1 pos 7 + + +-- !query 4 +CREATE FUNCTION udaf1 AS 'test.non.existent.udaf' +-- !query 4 schema +struct<> +-- !query 4 output + + + +-- !query 5 +SELECT default.udaf1(int_col1) as udaf1 from t1 +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +Can not load class 'test.non.existent.udaf' when registering the function 'default.udaf1', please make sure it is on the classpath; line 1 pos 7 diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 0d0269f69430..f689329c986a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -34,6 +34,8 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, FunctionResourceLoader, GlobalTempViewManager, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.execution.aggregate.ScalaUDAF +import org.apache.spark.sql.expressions.UserDefinedAggregateFunction import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DecimalType, DoubleType} @@ -95,6 +97,8 @@ private[sql] class HiveSessionCatalog( val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(clazz.getName), children) udtf.elementSchema // Force it to check input data types. 
udtf + } else if (classOf[UserDefinedAggregateFunction].isAssignableFrom(clazz)) { + ScalaUDAF(children, clazz.newInstance().asInstanceOf[UserDefinedAggregateFunction]) } else { throw new AnalysisException(s"No handler for Hive UDF '${clazz.getCanonicalName}'") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala index 479ca1e8def5..8986fb58c646 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDAFSuite.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.hive.ql.util.JavaDataModel import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo +import test.org.apache.spark.sql.MyDoubleAvg import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec @@ -86,6 +87,18 @@ class HiveUDAFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { )) } + test("call JAVA UDAF") { + withTempView("temp") { + withUserDefinedFunction("myDoubleAvg" -> false) { + spark.range(1, 10).toDF("value").createOrReplaceTempView("temp") + sql(s"CREATE FUNCTION myDoubleAvg AS '${classOf[MyDoubleAvg].getName}'") + checkAnswer( + spark.sql("SELECT default.myDoubleAvg(value) as my_avg from temp"), + Row(105.0)) + } + } + } + test("non-deterministic children expressions of UDAF") { withTempView("view1") { spark.range(1).selectExpr("id as x", "id as y").createTempView("view1") From a65607c2e31762b7f17c7ea64248c504557350d8 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 12 Aug 2017 18:00:07 -0700 Subject: [PATCH 02/11] fix. 
--- .../sql/catalyst/catalog/SessionCatalog.scala | 16 +++++++++------- .../spark/sql/execution/aggregate/udaf.scala | 7 +------ .../resources/sql-tests/results/udaf.sql.out | 4 ++-- 3 files changed, 12 insertions(+), 15 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 10eff5a42e47..03694b04bdbe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1154,17 +1154,19 @@ class SessionCatalog( overrideIfExists: Boolean, functionBuilder: Option[FunctionBuilder] = None): Unit = { val func = funcDefinition.identifier - val className = funcDefinition.className if (functionRegistry.functionExists(func) && !overrideIfExists) { throw new AnalysisException(s"Function $func already exists") } - if (!Utils.classIsLoadable(className)) { - throw new AnalysisException(s"Can not load class '$className' when registering " + - s"the function '$func', please make sure it is on the classpath") - } - val info = new ExpressionInfo(className, func.database.orNull, func.funcName) + val info = new ExpressionInfo(funcDefinition.className, func.database.orNull, func.funcName) val builder = - functionBuilder.getOrElse(makeFunctionBuilder(func.unquotedString, className)) + functionBuilder.getOrElse { + val className = funcDefinition.className + if (!Utils.classIsLoadable(className)) { + throw new AnalysisException(s"Can not load class '$className' when registering " + + s"the function '$func', please make sure it is on the classpath") + } + makeFunctionBuilder(func.unquotedString, className) + } functionRegistry.registerFunction(func, info, builder) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index a5984ba121de..ae5e2c6bece2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.internal.Logging -import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, _} import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate @@ -326,11 +326,6 @@ case class ScalaUDAF( inputAggBufferOffset: Int = 0) extends ImperativeAggregate with NonSQLExpression with Logging with ImplicitCastInputTypes { - if (children.length != udaf.inputSchema.length) { - throw new AnalysisException(s"Invalid number of arguments for the function " + - s"Expected: ${udaf.inputSchema.length}; Found: ${children.length}") - } - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) diff --git a/sql/core/src/test/resources/sql-tests/results/udaf.sql.out b/sql/core/src/test/resources/sql-tests/results/udaf.sql.out index bd4364ae05bf..4815a578b102 100644 --- a/sql/core/src/test/resources/sql-tests/results/udaf.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udaf.sql.out @@ -33,8 +33,8 @@ SELECT default.myDoubleAvg(int_col1, 3) as 
my_avg from t1 -- !query 3 schema struct<> -- !query 3 output -org.apache.spark.sql.AnalysisException -No handler for UDAF 'test.org.apache.spark.sql.MyDoubleAvg': org.apache.spark.sql.AnalysisException: Invalid number of arguments for the function Expected: 1; Found: 2;; line 1 pos 7 +java.lang.AssertionError +assertion failed: Incorrect number of children -- !query 4 From 12cefc24ebd73c5f986bea50d0d30ca7ff791ae2 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 12 Aug 2017 18:02:11 -0700 Subject: [PATCH 03/11] fix. --- .../org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 03694b04bdbe..58df686f9749 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1115,7 +1115,6 @@ class SessionCatalog( Utils.classForName("org.apache.spark.sql.expressions.UserDefinedAggregateFunction") if (clsForUDAF.isAssignableFrom(clazz)) { val cls = Utils.classForName("org.apache.spark.sql.execution.aggregate.ScalaUDAF") - // val ctor = classOf[Integer].getConstructor(classOf[Int]) cls.getConstructor(classOf[Seq[Expression]], clsForUDAF, classOf[Int], classOf[Int]) .newInstance(children, clazz.newInstance().asInstanceOf[Object], Int.box(1), Int.box(1)) .asInstanceOf[Expression] From bd5ae2616d67d42eacfbab110b684f2219afb6d6 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 21 Aug 2017 00:31:25 -0700 Subject: [PATCH 04/11] fix. --- .../sql/catalyst/catalog/SessionCatalog.scala | 39 +++++---- .../spark/sql/hive/HiveSessionCatalog.scala | 85 ++++++++++--------- 2 files changed, 66 insertions(+), 58 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 58df686f9749..0e6fafd9bb94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1099,26 +1099,10 @@ class SessionCatalog( * This performs reflection to decide what type of [[Expression]] to return in the builder. */ protected def makeFunctionBuilder(name: String, functionClassName: String): FunctionBuilder = { - makeFunctionBuilder(name, Utils.classForName(functionClassName)) - } - - /** - * Construct a [[FunctionBuilder]] based on the provided class that represents a function. - */ - private def makeFunctionBuilder(name: String, clazz: Class[_]): FunctionBuilder = { - // When we instantiate ScalaUDAF class, we may throw exception if the input - // expressions don't satisfy the UDAF, such as type mismatch, input number - // mismatch, etc. Here we catch the exception and throw AnalysisException instead. 
+ val clazz = Utils.classForName(functionClassName) (children: Seq[Expression]) => { try { - val clsForUDAF = - Utils.classForName("org.apache.spark.sql.expressions.UserDefinedAggregateFunction") - if (clsForUDAF.isAssignableFrom(clazz)) { - val cls = Utils.classForName("org.apache.spark.sql.execution.aggregate.ScalaUDAF") - cls.getConstructor(classOf[Seq[Expression]], clsForUDAF, classOf[Int], classOf[Int]) - .newInstance(children, clazz.newInstance().asInstanceOf[Object], Int.box(1), Int.box(1)) - .asInstanceOf[Expression] - } else { + makeFunctionExpression(name, Utils.classForName(functionClassName), children).getOrElse { throw new UnsupportedOperationException("Use sqlContext.udf.register(...) instead.") } } catch { @@ -1137,6 +1121,25 @@ class SessionCatalog( } } + /** + * Construct a [[FunctionBuilder]] based on the provided class that represents a function. + */ + protected def makeFunctionExpression( + name: String, + clazz: Class[_], + children: Seq[Expression]): Option[Expression] = { + val clsForUDAF = + Utils.classForName("org.apache.spark.sql.expressions.UserDefinedAggregateFunction") + if (clsForUDAF.isAssignableFrom(clazz)) { + val cls = Utils.classForName("org.apache.spark.sql.execution.aggregate.ScalaUDAF") + Some(cls.getConstructor(classOf[Seq[Expression]], clsForUDAF, classOf[Int], classOf[Int]) + .newInstance(children, clazz.newInstance().asInstanceOf[Object], Int.box(1), Int.box(1)) + .asInstanceOf[Expression]) + } else { + None + } + } + /** * Loads resources such as JARs and Files for a function. Every resource is represented * by a tuple (resource type, resource uri). diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index f689329c986a..57b85ee54c7a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -34,8 +34,6 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, FunctionResourceLoader, GlobalTempViewManager, SessionCatalog} import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} import org.apache.spark.sql.catalyst.parser.ParserInterface -import org.apache.spark.sql.execution.aggregate.ScalaUDAF -import org.apache.spark.sql.expressions.UserDefinedAggregateFunction import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DecimalType, DoubleType} @@ -60,46 +58,11 @@ private[sql] class HiveSessionCatalog( parser, functionResourceLoader) { - override def makeFunctionBuilder(funcName: String, className: String): FunctionBuilder = { - makeFunctionBuilder(funcName, Utils.classForName(className)) - } - - /** - * Construct a [[FunctionBuilder]] based on the provided class that represents a function. - */ - private def makeFunctionBuilder(name: String, clazz: Class[_]): FunctionBuilder = { - // When we instantiate hive UDF wrapper class, we may throw exception if the input - // expressions don't satisfy the hive UDF, such as type mismatch, input number - // mismatch, etc. Here we catch the exception and throw AnalysisException instead. 
+ override def makeFunctionBuilder(name: String, functionClassName: String): FunctionBuilder = { + val clazz = Utils.classForName(functionClassName) (children: Seq[Expression]) => { try { - if (classOf[UDF].isAssignableFrom(clazz)) { - val udf = HiveSimpleUDF(name, new HiveFunctionWrapper(clazz.getName), children) - udf.dataType // Force it to check input data types. - udf - } else if (classOf[GenericUDF].isAssignableFrom(clazz)) { - val udf = HiveGenericUDF(name, new HiveFunctionWrapper(clazz.getName), children) - udf.dataType // Force it to check input data types. - udf - } else if (classOf[AbstractGenericUDAFResolver].isAssignableFrom(clazz)) { - val udaf = HiveUDAFFunction(name, new HiveFunctionWrapper(clazz.getName), children) - udaf.dataType // Force it to check input data types. - udaf - } else if (classOf[UDAF].isAssignableFrom(clazz)) { - val udaf = HiveUDAFFunction( - name, - new HiveFunctionWrapper(clazz.getName), - children, - isUDAFBridgeRequired = true) - udaf.dataType // Force it to check input data types. - udaf - } else if (classOf[GenericUDTF].isAssignableFrom(clazz)) { - val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(clazz.getName), children) - udtf.elementSchema // Force it to check input data types. - udtf - } else if (classOf[UserDefinedAggregateFunction].isAssignableFrom(clazz)) { - ScalaUDAF(children, clazz.newInstance().asInstanceOf[UserDefinedAggregateFunction]) - } else { + makeFunctionExpression(name, Utils.classForName(functionClassName), children).getOrElse { throw new AnalysisException(s"No handler for Hive UDF '${clazz.getCanonicalName}'") } } catch { @@ -114,6 +77,48 @@ private[sql] class HiveSessionCatalog( } } + /** + * Construct a [[FunctionBuilder]] based on the provided class that represents a function. + */ + override def makeFunctionExpression( + name: String, + clazz: Class[_], + children: Seq[Expression]): Option[Expression] = { + + super.makeFunctionExpression(name, clazz, children).orElse { + // When we instantiate hive UDF wrapper class, we may throw exception if the input + // expressions don't satisfy the hive UDF, such as type mismatch, input number + // mismatch, etc. Here we catch the exception and throw AnalysisException instead. + if (classOf[UDF].isAssignableFrom(clazz)) { + val udf = HiveSimpleUDF(name, new HiveFunctionWrapper(clazz.getName), children) + udf.dataType // Force it to check input data types. + Some(udf) + } else if (classOf[GenericUDF].isAssignableFrom(clazz)) { + val udf = HiveGenericUDF(name, new HiveFunctionWrapper(clazz.getName), children) + udf.dataType // Force it to check input data types. + Some(udf) + } else if (classOf[AbstractGenericUDAFResolver].isAssignableFrom(clazz)) { + val udaf = HiveUDAFFunction(name, new HiveFunctionWrapper(clazz.getName), children) + udaf.dataType // Force it to check input data types. + Some(udaf) + } else if (classOf[UDAF].isAssignableFrom(clazz)) { + val udaf = HiveUDAFFunction( + name, + new HiveFunctionWrapper(clazz.getName), + children, + isUDAFBridgeRequired = true) + udaf.dataType // Force it to check input data types. + Some(udaf) + } else if (classOf[GenericUDTF].isAssignableFrom(clazz)) { + val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(clazz.getName), children) + udtf.elementSchema // Force it to check input data types. 
+ Some(udtf) + } else { + None + } + } + } + override def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression = { try { lookupFunction0(name, children) From 7251be9386656e590c86c75be466779bdd2e076d Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Mon, 21 Aug 2017 10:14:03 -0700 Subject: [PATCH 05/11] fix. --- .../sql/catalyst/catalog/SessionCatalog.scala | 31 ++++++---- .../spark/sql/hive/HiveSessionCatalog.scala | 21 ------- .../sql/hive/execution/HiveUDFSuite.scala | 61 ++++++------------- 3 files changed, 37 insertions(+), 76 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 0e6fafd9bb94..5fc5c25436d1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias, View} import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -129,6 +130,13 @@ class SessionCatalog( if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT) } + /** + * Checks whether the Hive metastore is being used + */ + private def isUsingHiveMetastore: Boolean = { + conf.getConf(CATALOG_IMPLEMENTATION).toLowerCase(Locale.ROOT) == "hive" + } + private val tableRelationCache: Cache[QualifiedTableName, LogicalPlan] = { val cacheSize = conf.tableRelationCacheSize CacheBuilder.newBuilder().maximumSize(cacheSize).build[QualifiedTableName, LogicalPlan]() @@ -1094,27 +1102,24 @@ class SessionCatalog( // ---------------------------------------------------------------- /** - * Construct a [[FunctionBuilder]] based on the provided class that represents a function. - * - * This performs reflection to decide what type of [[Expression]] to return in the builder. + * Constructs a [[FunctionBuilder]] based on the provided class that represents a function. */ protected def makeFunctionBuilder(name: String, functionClassName: String): FunctionBuilder = { val clazz = Utils.classForName(functionClassName) (children: Seq[Expression]) => { try { makeFunctionExpression(name, Utils.classForName(functionClassName), children).getOrElse { - throw new UnsupportedOperationException("Use sqlContext.udf.register(...) instead.") + val extraMsg = + if (!isUsingHiveMetastore) "Use sparkSession.udf.register(...) instead." else "" + throw new AnalysisException( + s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}'. 
$extraMsg") } } catch { - case NonFatal(exception) => - val e = exception match { - // Since we are using shim, the exceptions thrown by the underlying method of - // Method.invoke() are wrapped by InvocationTargetException - case i: InvocationTargetException => i.getCause - case o => o - } + case ae: AnalysisException => + throw ae + case NonFatal(e) => val analysisException = - new AnalysisException(s"No handler for UDAF '${clazz.getCanonicalName}': $e") + new AnalysisException(s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}': $e") analysisException.setStackTrace(e.getStackTrace) throw analysisException } @@ -1123,6 +1128,8 @@ class SessionCatalog( /** * Construct a [[FunctionBuilder]] based on the provided class that represents a function. + * + * This performs reflection to decide what type of [[Expression]] to return in the builder. */ protected def makeFunctionExpression( name: String, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 57b85ee54c7a..d4071ded85f3 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -27,7 +27,6 @@ import org.apache.hadoop.hive.ql.exec.{UDAF, UDF} import org.apache.hadoop.hive.ql.exec.{FunctionRegistry => HiveFunctionRegistry} import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF} -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder @@ -37,7 +36,6 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DecimalType, DoubleType} -import org.apache.spark.util.Utils private[sql] class HiveSessionCatalog( @@ -58,25 +56,6 @@ private[sql] class HiveSessionCatalog( parser, functionResourceLoader) { - override def makeFunctionBuilder(name: String, functionClassName: String): FunctionBuilder = { - val clazz = Utils.classForName(functionClassName) - (children: Seq[Expression]) => { - try { - makeFunctionExpression(name, Utils.classForName(functionClassName), children).getOrElse { - throw new AnalysisException(s"No handler for Hive UDF '${clazz.getCanonicalName}'") - } - } catch { - case ae: AnalysisException => - throw ae - case NonFatal(e) => - val analysisException = - new AnalysisException(s"No handler for Hive UDF '${clazz.getCanonicalName}': $e") - analysisException.setStackTrace(e.getStackTrace) - throw analysisException - } - } - } - /** * Construct a [[FunctionBuilder]] based on the provided class that represents a function. 
*/ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index cae338c0ab0a..8322b82e7733 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -404,59 +404,34 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } test("Hive UDFs with insufficient number of input arguments should trigger an analysis error") { - Seq((1, 2)).toDF("a", "b").createOrReplaceTempView("testUDF") + withTempView("testUDF") { + Seq((1, 2)).toDF("a", "b").createOrReplaceTempView("testUDF") + + def testErrorMsgForFunc(funcName: String, className: String): Unit = { + withUserDefinedFunction(funcName -> true) { + sql(s"CREATE TEMPORARY FUNCTION $funcName AS '$className'") + val message = intercept[AnalysisException] { + sql(s"SELECT $funcName() FROM testUDF") + }.getMessage + assert(message.contains(s"No handler for UDF/UDAF/UDTF '$className'")) + } + } - { // HiveSimpleUDF - sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'") - val message = intercept[AnalysisException] { - sql("SELECT testUDFTwoListList() FROM testUDF") - }.getMessage - assert(message.contains("No handler for Hive UDF")) - sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList") - } + testErrorMsgForFunc("testUDFTwoListList", classOf[UDFTwoListList].getName) - { // HiveGenericUDF - sql(s"CREATE TEMPORARY FUNCTION testUDFAnd AS '${classOf[GenericUDFOPAnd].getName}'") - val message = intercept[AnalysisException] { - sql("SELECT testUDFAnd() FROM testUDF") - }.getMessage - assert(message.contains("No handler for Hive UDF")) - sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFAnd") - } + testErrorMsgForFunc("testUDFAnd", classOf[GenericUDFOPAnd].getName) - { // Hive UDAF - sql(s"CREATE TEMPORARY FUNCTION testUDAFPercentile AS '${classOf[UDAFPercentile].getName}'") - val message = intercept[AnalysisException] { - sql("SELECT testUDAFPercentile(a) FROM testUDF GROUP BY b") - }.getMessage - assert(message.contains("No handler for Hive UDF")) - sql("DROP TEMPORARY FUNCTION IF EXISTS testUDAFPercentile") - } + testErrorMsgForFunc("testUDAFPercentile", classOf[UDAFPercentile].getName) - { // AbstractGenericUDAFResolver - sql(s"CREATE TEMPORARY FUNCTION testUDAFAverage AS '${classOf[GenericUDAFAverage].getName}'") - val message = intercept[AnalysisException] { - sql("SELECT testUDAFAverage() FROM testUDF GROUP BY b") - }.getMessage - assert(message.contains("No handler for Hive UDF")) - sql("DROP TEMPORARY FUNCTION IF EXISTS testUDAFAverage") - } + testErrorMsgForFunc("testUDAFAverage", classOf[GenericUDAFAverage].getName) - { - // Hive UDTF - sql(s"CREATE TEMPORARY FUNCTION testUDTFExplode AS '${classOf[GenericUDTFExplode].getName}'") - val message = intercept[AnalysisException] { - sql("SELECT testUDTFExplode() FROM testUDF") - }.getMessage - assert(message.contains("No handler for Hive UDF")) - sql("DROP TEMPORARY FUNCTION IF EXISTS testUDTFExplode") + // AbstractGenericUDAFResolver + testErrorMsgForFunc("testUDTFExplode", classOf[GenericUDTFExplode].getName) } - - spark.catalog.dropTempView("testUDF") } test("Hive UDF in group by") { From d3fbdc5d0f67a422395b76bd1035fe7fb95f7de1 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 22 Aug 2017 00:21:14 -0700 Subject: [PATCH 06/11] typo --- 
.../org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 5fc5c25436d1..9e991dcfe148 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1108,7 +1108,7 @@ class SessionCatalog( val clazz = Utils.classForName(functionClassName) (children: Seq[Expression]) => { try { - makeFunctionExpression(name, Utils.classForName(functionClassName), children).getOrElse { + makeFunctionExpression(name, clazz, children).getOrElse { val extraMsg = if (!isUsingHiveMetastore) "Use sparkSession.udf.register(...) instead." else "" throw new AnalysisException( From 57607b5e175894b488e87752e41ff60cc2700045 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 22 Aug 2017 08:03:51 -0700 Subject: [PATCH 07/11] fix. --- .../sql/catalyst/catalog/SessionCatalog.scala | 32 ++------ .../spark/sql/hive/HiveSessionCatalog.scala | 73 +++++++++++-------- 2 files changed, 50 insertions(+), 55 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 9e991dcfe148..b00513f4ec82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1104,26 +1104,9 @@ class SessionCatalog( /** * Constructs a [[FunctionBuilder]] based on the provided class that represents a function. */ - protected def makeFunctionBuilder(name: String, functionClassName: String): FunctionBuilder = { + private def makeFunctionBuilder(name: String, functionClassName: String): FunctionBuilder = { val clazz = Utils.classForName(functionClassName) - (children: Seq[Expression]) => { - try { - makeFunctionExpression(name, clazz, children).getOrElse { - val extraMsg = - if (!isUsingHiveMetastore) "Use sparkSession.udf.register(...) instead." else "" - throw new AnalysisException( - s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}'. 
$extraMsg") - } - } catch { - case ae: AnalysisException => - throw ae - case NonFatal(e) => - val analysisException = - new AnalysisException(s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}': $e") - analysisException.setStackTrace(e.getStackTrace) - throw analysisException - } - } + (input: Seq[Expression]) => makeFunctionExpression(name, clazz, input) } /** @@ -1134,16 +1117,17 @@ class SessionCatalog( protected def makeFunctionExpression( name: String, clazz: Class[_], - children: Seq[Expression]): Option[Expression] = { + input: Seq[Expression]): Expression = { val clsForUDAF = Utils.classForName("org.apache.spark.sql.expressions.UserDefinedAggregateFunction") if (clsForUDAF.isAssignableFrom(clazz)) { val cls = Utils.classForName("org.apache.spark.sql.execution.aggregate.ScalaUDAF") - Some(cls.getConstructor(classOf[Seq[Expression]], clsForUDAF, classOf[Int], classOf[Int]) - .newInstance(children, clazz.newInstance().asInstanceOf[Object], Int.box(1), Int.box(1)) - .asInstanceOf[Expression]) + cls.getConstructor(classOf[Seq[Expression]], clsForUDAF, classOf[Int], classOf[Int]) + .newInstance(input, clazz.newInstance().asInstanceOf[Object], Int.box(1), Int.box(1)) + .asInstanceOf[Expression] } else { - None + throw new AnalysisException(s"No handler for UDAF '${clazz.getCanonicalName}'. " + + s"Use sparkSession.udf.register(...) instead.") } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index d4071ded85f3..daddf8c8c787 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.hive.ql.exec.{UDAF, UDF} import org.apache.hadoop.hive.ql.exec.{FunctionRegistry => HiveFunctionRegistry} import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF} +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder @@ -62,38 +63,48 @@ private[sql] class HiveSessionCatalog( override def makeFunctionExpression( name: String, clazz: Class[_], - children: Seq[Expression]): Option[Expression] = { + input: Seq[Expression]): Expression = { - super.makeFunctionExpression(name, clazz, children).orElse { - // When we instantiate hive UDF wrapper class, we may throw exception if the input - // expressions don't satisfy the hive UDF, such as type mismatch, input number - // mismatch, etc. Here we catch the exception and throw AnalysisException instead. - if (classOf[UDF].isAssignableFrom(clazz)) { - val udf = HiveSimpleUDF(name, new HiveFunctionWrapper(clazz.getName), children) - udf.dataType // Force it to check input data types. - Some(udf) - } else if (classOf[GenericUDF].isAssignableFrom(clazz)) { - val udf = HiveGenericUDF(name, new HiveFunctionWrapper(clazz.getName), children) - udf.dataType // Force it to check input data types. - Some(udf) - } else if (classOf[AbstractGenericUDAFResolver].isAssignableFrom(clazz)) { - val udaf = HiveUDAFFunction(name, new HiveFunctionWrapper(clazz.getName), children) - udaf.dataType // Force it to check input data types. 
- Some(udaf) - } else if (classOf[UDAF].isAssignableFrom(clazz)) { - val udaf = HiveUDAFFunction( - name, - new HiveFunctionWrapper(clazz.getName), - children, - isUDAFBridgeRequired = true) - udaf.dataType // Force it to check input data types. - Some(udaf) - } else if (classOf[GenericUDTF].isAssignableFrom(clazz)) { - val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(clazz.getName), children) - udtf.elementSchema // Force it to check input data types. - Some(udtf) - } else { - None + Try(super.makeFunctionExpression(name, clazz, input)).getOrElse { + try { + // When we instantiate hive UDF wrapper class, we may throw exception if the input + // expressions don't satisfy the hive UDF, such as type mismatch, input number + // mismatch, etc. Here we catch the exception and throw AnalysisException instead. + if (classOf[UDF].isAssignableFrom(clazz)) { + val udf = HiveSimpleUDF(name, new HiveFunctionWrapper(clazz.getName), input) + udf.dataType // Force it to check input data types. + udf + } else if (classOf[GenericUDF].isAssignableFrom(clazz)) { + val udf = HiveGenericUDF(name, new HiveFunctionWrapper(clazz.getName), input) + udf.dataType // Force it to check input data types. + udf + } else if (classOf[AbstractGenericUDAFResolver].isAssignableFrom(clazz)) { + val udaf = HiveUDAFFunction(name, new HiveFunctionWrapper(clazz.getName), input) + udaf.dataType // Force it to check input data types. + udaf + } else if (classOf[UDAF].isAssignableFrom(clazz)) { + val udaf = HiveUDAFFunction( + name, + new HiveFunctionWrapper(clazz.getName), + input, + isUDAFBridgeRequired = true) + udaf.dataType // Force it to check input data types. + udaf + } else if (classOf[GenericUDTF].isAssignableFrom(clazz)) { + val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(clazz.getName), input) + udtf.elementSchema // Force it to check input data types. + udtf + } else { + throw new AnalysisException(s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}'") + } + } catch { + case ae: AnalysisException => + throw ae + case NonFatal(e) => + val analysisException = + new AnalysisException(s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}': $e") + analysisException.setStackTrace(e.getStackTrace) + throw analysisException } } } From 05e8168af15abf3fe3a8448a73b1ff41d4a9d682 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 22 Aug 2017 09:07:23 -0700 Subject: [PATCH 08/11] fix. --- .../sql/catalyst/catalog/SessionCatalog.scala | 2 +- .../spark/sql/hive/HiveSessionCatalog.scala | 37 ++++++++----------- 2 files changed, 17 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index b00513f4ec82..07660e12d938 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1110,7 +1110,7 @@ class SessionCatalog( } /** - * Construct a [[FunctionBuilder]] based on the provided class that represents a function. + * Constructs a [[FunctionBuilder]] based on the provided class that represents a function. * * This performs reflection to decide what type of [[Expression]] to return in the builder. 
*/ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index daddf8c8c787..04d80926583a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -58,7 +58,7 @@ private[sql] class HiveSessionCatalog( functionResourceLoader) { /** - * Construct a [[FunctionBuilder]] based on the provided class that represents a function. + * Constructs a [[FunctionBuilder]] based on the provided class that represents a function. */ override def makeFunctionExpression( name: String, @@ -66,46 +66,41 @@ private[sql] class HiveSessionCatalog( input: Seq[Expression]): Expression = { Try(super.makeFunctionExpression(name, clazz, input)).getOrElse { + var udfExpr: Option[Expression] = None try { // When we instantiate hive UDF wrapper class, we may throw exception if the input // expressions don't satisfy the hive UDF, such as type mismatch, input number // mismatch, etc. Here we catch the exception and throw AnalysisException instead. if (classOf[UDF].isAssignableFrom(clazz)) { - val udf = HiveSimpleUDF(name, new HiveFunctionWrapper(clazz.getName), input) - udf.dataType // Force it to check input data types. - udf + udfExpr = Some(HiveSimpleUDF(name, new HiveFunctionWrapper(clazz.getName), input)) + udfExpr.get.dataType // Force it to check input data types. } else if (classOf[GenericUDF].isAssignableFrom(clazz)) { - val udf = HiveGenericUDF(name, new HiveFunctionWrapper(clazz.getName), input) - udf.dataType // Force it to check input data types. - udf + udfExpr = Some(HiveGenericUDF(name, new HiveFunctionWrapper(clazz.getName), input)) + udfExpr.get.dataType // Force it to check input data types. } else if (classOf[AbstractGenericUDAFResolver].isAssignableFrom(clazz)) { - val udaf = HiveUDAFFunction(name, new HiveFunctionWrapper(clazz.getName), input) - udaf.dataType // Force it to check input data types. - udaf + udfExpr = Some(HiveUDAFFunction(name, new HiveFunctionWrapper(clazz.getName), input)) + udfExpr.get.dataType // Force it to check input data types. } else if (classOf[UDAF].isAssignableFrom(clazz)) { - val udaf = HiveUDAFFunction( + udfExpr = Some(HiveUDAFFunction( name, new HiveFunctionWrapper(clazz.getName), input, - isUDAFBridgeRequired = true) - udaf.dataType // Force it to check input data types. - udaf + isUDAFBridgeRequired = true)) + udfExpr.get.dataType // Force it to check input data types. } else if (classOf[GenericUDTF].isAssignableFrom(clazz)) { - val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(clazz.getName), input) - udtf.elementSchema // Force it to check input data types. - udtf - } else { - throw new AnalysisException(s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}'") + udfExpr = Some(HiveGenericUDTF(name, new HiveFunctionWrapper(clazz.getName), input)) + udfExpr.get.asInstanceOf[HiveGenericUDTF].elementSchema // Force it to check data types. 
} } catch { - case ae: AnalysisException => - throw ae case NonFatal(e) => val analysisException = new AnalysisException(s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}': $e") analysisException.setStackTrace(e.getStackTrace) throw analysisException } + udfExpr.getOrElse { + throw new AnalysisException(s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}'") + } } } From aff8f9efd490296005607310c0faaf7970cc352d Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 22 Aug 2017 09:14:15 -0700 Subject: [PATCH 09/11] fix. --- .../sql/hive/execution/HiveUDFSuite.scala | 40 ++++++++++++++++++ .../sql/hive/execution/SQLQuerySuite.scala | 42 +------------------ 2 files changed, 41 insertions(+), 41 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 8322b82e7733..383d41f907c6 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -596,6 +596,46 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { } } } + + test("UDTF") { + withUserDefinedFunction("udtf_count2" -> true) { + sql(s"ADD JAR ${hiveContext.getHiveFile("TestUDTF.jar").getCanonicalPath}") + // The function source code can be found at: + // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF + sql( + """ + |CREATE TEMPORARY FUNCTION udtf_count2 + |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' + """.stripMargin) + + checkAnswer( + sql("SELECT key, cc FROM src LATERAL VIEW udtf_count2(value) dd AS cc"), + Row(97, 500) :: Row(97, 500) :: Nil) + + checkAnswer( + sql("SELECT udtf_count2(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"), + Row(3) :: Row(3) :: Nil) + } + } + + test("permanent UDTF") { + withUserDefinedFunction("udtf_count_temp" -> false) { + sql( + s""" + |CREATE FUNCTION udtf_count_temp + |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' + |USING JAR '${hiveContext.getHiveFile("TestUDTF.jar").toURI}' + """.stripMargin) + + checkAnswer( + sql("SELECT key, cc FROM src LATERAL VIEW udtf_count_temp(value) dd AS cc"), + Row(97, 500) :: Row(97, 500) :: Nil) + + checkAnswer( + sql("SELECT udtf_count_temp(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"), + Row(3) :: Row(3) :: Nil) + } + } } class TestPair(x: Int, y: Int) extends Writable with Serializable { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index a949e5e829e1..b5e34705dbc1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.TestUtils import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, FunctionRegistry, NoSuchPartitionException} +import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, FunctionRegistry} import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTableType, CatalogUtils} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} @@ -98,46 +98,6 @@ class SQLQuerySuite extends QueryTest with 
SQLTestUtils with TestHiveSingleton { checkAnswer(query1, Row("x1_y1") :: Row("x2_y2") :: Nil) } - test("UDTF") { - withUserDefinedFunction("udtf_count2" -> true) { - sql(s"ADD JAR ${hiveContext.getHiveFile("TestUDTF.jar").getCanonicalPath()}") - // The function source code can be found at: - // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF - sql( - """ - |CREATE TEMPORARY FUNCTION udtf_count2 - |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' - """.stripMargin) - - checkAnswer( - sql("SELECT key, cc FROM src LATERAL VIEW udtf_count2(value) dd AS cc"), - Row(97, 500) :: Row(97, 500) :: Nil) - - checkAnswer( - sql("SELECT udtf_count2(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"), - Row(3) :: Row(3) :: Nil) - } - } - - test("permanent UDTF") { - withUserDefinedFunction("udtf_count_temp" -> false) { - sql( - s""" - |CREATE FUNCTION udtf_count_temp - |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' - |USING JAR '${hiveContext.getHiveFile("TestUDTF.jar").toURI}' - """.stripMargin) - - checkAnswer( - sql("SELECT key, cc FROM src LATERAL VIEW udtf_count_temp(value) dd AS cc"), - Row(97, 500) :: Row(97, 500) :: Nil) - - checkAnswer( - sql("SELECT udtf_count_temp(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"), - Row(3) :: Row(3) :: Nil) - } - } - test("SPARK-6835: udtf in lateral view") { val df = Seq((1, 1)).toDF("c1", "c2") df.createOrReplaceTempView("table1") From 7d9aabdaa2411ad6f51354f9fe4d89874337681b Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 22 Aug 2017 09:21:00 -0700 Subject: [PATCH 10/11] fix. --- .../apache/spark/sql/catalyst/catalog/SessionCatalog.scala | 7 ------- 1 file changed, 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 07660e12d938..999301c592dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -130,13 +130,6 @@ class SessionCatalog( if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT) } - /** - * Checks whether the Hive metastore is being used - */ - private def isUsingHiveMetastore: Boolean = { - conf.getConf(CATALOG_IMPLEMENTATION).toLowerCase(Locale.ROOT) == "hive" - } - private val tableRelationCache: Cache[QualifiedTableName, LogicalPlan] = { val cacheSize = conf.tableRelationCacheSize CacheBuilder.newBuilder().maximumSize(cacheSize).build[QualifiedTableName, LogicalPlan]() From 50224a7d2bf7db7f6b042681cb658e60967261b1 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 22 Aug 2017 09:24:44 -0700 Subject: [PATCH 11/11] fix. --- .../apache/spark/sql/catalyst/catalog/SessionCatalog.scala | 2 +- .../scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 37d48a724a02..0908d68d2564 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1087,7 +1087,7 @@ class SessionCatalog( } /** - * Constructs a [[FunctionBuilder]] based on the provided class that represents a function. 
+   * Constructs an [[Expression]] based on the provided class that represents a function.
    *
    * This performs reflection to decide what type of [[Expression]] to return in the builder.
    */
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index 04d80926583a..b352bf6971ba 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -58,7 +58,9 @@ private[sql] class HiveSessionCatalog(
     functionResourceLoader) {
 
   /**
-   * Constructs a [[FunctionBuilder]] based on the provided class that represents a function.
+   * Constructs an [[Expression]] based on the provided class that represents a function.
+   *
+   * This performs reflection to decide what type of [[Expression]] to return in the builder.
    */
   override def makeFunctionExpression(
       name: String,
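
// ---------------------------------------------------------------------------
// Illustrative sketch, not part of the patch series above: a minimal UDAF of
// the kind these commits make registrable through CREATE FUNCTION. The class
// name MyDoubleSum is a hypothetical stand-in for test classes such as
// test.org.apache.spark.sql.MyDoubleAvg; the API shown
// (org.apache.spark.sql.expressions.UserDefinedAggregateFunction) is the one
// makeFunctionExpression detects via reflection. Note the implicit no-arg
// constructor: the catalog instantiates the class with clazz.newInstance().
// ---------------------------------------------------------------------------
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

class MyDoubleSum extends UserDefinedAggregateFunction {
  // Exactly one DOUBLE input column; calling the function with a different
  // number of arguments, as query 3 of udaf.sql does, fails during analysis.
  override def inputSchema: StructType = StructType(StructField("value", DoubleType) :: Nil)

  // A single running-total slot in the aggregation buffer.
  override def bufferSchema: StructType = StructType(StructField("sum", DoubleType) :: Nil)

  override def dataType: DataType = DoubleType

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0.0
  }

  // Fold one input row into the partial aggregate, skipping nulls.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) buffer(0) = buffer.getDouble(0) + input.getDouble(0)
  }

  // Combine two partial aggregates produced on different partitions.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
  }

  override def evaluate(buffer: Row): Any = buffer.getDouble(0)
}

// With the class on the classpath, registration and lookup go through the
// patched SessionCatalog path rather than sparkSession.udf.register, mirroring
// the udaf.sql test above (function and table names here are hypothetical):
//
//   spark.sql("CREATE FUNCTION myDoubleSum AS 'MyDoubleSum'")
//   spark.sql("SELECT default.myDoubleSum(int_col1) FROM t1").show()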