From 22220efca3a21b87cfd914f297c8befff707e9ac Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Thu, 9 Mar 2017 15:06:50 +0800 Subject: [PATCH 1/4] [SPARK-19439][PYSPARK][SQL] PySpark's registerJavaFunction Should Support UDAFs --- python/pyspark/sql/context.py | 24 +++++++++++++++++++ .../apache/spark/sql/UDFRegistration.scala | 15 ++++++++++++ 2 files changed, 39 insertions(+) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 426f07cd9410..5218d204b875 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -232,6 +232,24 @@ def registerJavaFunction(self, name, javaClassName, returnType=None): jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json()) self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt) + @ignore_unicode_prefix + @since(2.2) + def registerJavaUDAF(self, name, javaClassName): + """Register a java UDAF so it can be used in SQL statements. + + :param name: name of the UDF + :param javaClassName: fully qualified name of java class + + >>> sqlContext.registerJavaUDAF("javaUDAF", + ... "org.apache.spark.sql.hive.aggregate.MyDoubleAvg") + >>> df = sqlContext.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"]) + >>> df.registerTempTable("df") + >>> sqlContext.sql("SELECT name,javaUDAF(id) as avg from df group by name").collect() + [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)] + """ + + self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName) + # TODO(andrew): delete this once we refactor things to take in SparkSession def _inferSchema(self, rdd, samplingRatio=None): """ @@ -551,6 +569,12 @@ def __init__(self, sqlContext): def register(self, name, f, returnType=StringType()): return self.sqlContext.registerFunction(name, f, returnType) + def registerJavaFunction(self, name, javaClassName, returnType=None): + return self.sqlContext.registerJavaFunction(name, javaClassName, returnType) + + def registerJavaUDAF(self, name, javaClassName): + return self.sqlContext.registerJavaUDAF(name, javaClassName) + register.__doc__ = SQLContext.registerFunction.__doc__ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index ad01b889429c..c1647adde3e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -504,6 +504,21 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } + private[sql] def registerJavaUDAF(name: String, className: String): Unit = { + try { + val clazz = Utils.classForName(className) + if (!classOf[UserDefinedAggregateFunction].isAssignableFrom(clazz)) { + throw new IOException(s"class $className doesn't implement interface UserDefinedAggregateFunction") + } + val udaf = clazz.newInstance().asInstanceOf[UserDefinedAggregateFunction] + register(name, udaf) + } catch { + case e: ClassNotFoundException => logError(s"Can not load class ${className}, please make sure it is on the classpath") + case e @ (_: InstantiationException | _: IllegalArgumentException) => + logError(s"Can not instantiate class ${className}, please make sure it has public non argument constructor") + } + } + /** * Register a user-defined function with 1 arguments. 
* @since 1.3.0 From 6fd1d7504d3375aed7de4534bf3c092dd362e0d5 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Fri, 10 Mar 2017 08:28:12 +0800 Subject: [PATCH 2/4] add scala doc --- .../main/scala/org/apache/spark/sql/UDFRegistration.scala | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index c1647adde3e7..927648ba9dd4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -504,6 +504,12 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } + /** + * Register a Java UDAF class using reflection, for use from pyspark + * + * @param name UDAF name + * @param className fully qualified class name of UDAF + */ private[sql] def registerJavaUDAF(name: String, className: String): Unit = { try { val clazz = Utils.classForName(className) From e92a854f05711a8474a08525a3cb286fd91008c3 Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Thu, 6 Apr 2017 16:49:13 +0800 Subject: [PATCH 3/4] throw exception instead of logging error and add testcase --- python/pyspark/sql/context.py | 5 ++--- python/pyspark/sql/tests.py | 14 ++++++++++++++ .../org/apache/spark/sql/UDFRegistration.scala | 11 ++++++----- 3 files changed, 22 insertions(+), 8 deletions(-) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 5218d204b875..9d73cb719c43 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -237,17 +237,16 @@ def registerJavaFunction(self, name, javaClassName, returnType=None): def registerJavaUDAF(self, name, javaClassName): """Register a java UDAF so it can be used in SQL statements. - :param name: name of the UDF + :param name: name of the UDAF :param javaClassName: fully qualified name of java class >>> sqlContext.registerJavaUDAF("javaUDAF", ... 
"org.apache.spark.sql.hive.aggregate.MyDoubleAvg") >>> df = sqlContext.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"]) >>> df.registerTempTable("df") - >>> sqlContext.sql("SELECT name,javaUDAF(id) as avg from df group by name").collect() + >>> sqlContext.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect() [Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)] """ - self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName) # TODO(andrew): delete this once we refactor things to take in SparkSession diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 326e8548a617..714a35a76c64 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -490,6 +490,20 @@ def test_udf_registration_returns_udf(self): df.select(add_three("id").alias("plus_three")).collect() ) + def test_non_existed_udf(self): + try: + self.spark.udf.registerJavaFunction("udf1", "non_existed_udf") + self.fail("should fail due to can not load java udf class") + except py4j.protocol.Py4JError as e: + self.assertTrue("Can not load class non_existed_udf" in str(e)) + + def test_non_existed_udaf(self): + try: + self.spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf") + self.fail("should fail due to can not load java udaf class") + except py4j.protocol.Py4JError as e: + self.assertTrue("Can not load class non_existed_udaf" in str(e)) + def test_multiLine_json(self): people1 = self.spark.read.json("python/test_support/sql/people.json") people_array = self.spark.read.json("python/test_support/sql/people_array.json", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 927648ba9dd4..04f00fd68268 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -491,15 +491,16 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends case 21 => register(name, udf.asInstanceOf[UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) case 22 => register(name, udf.asInstanceOf[UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) case 23 => register(name, udf.asInstanceOf[UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) - case n => logError(s"UDF class with ${n} type arguments is not supported ") + case n => + throw new IOException(s"UDF class with ${n} type arguments is not supported.") } } catch { case e @ (_: InstantiationException | _: IllegalArgumentException) => - logError(s"Can not instantiate class ${className}, please make sure it has public non argument constructor") + throw new IOException(s"Can not instantiate class ${className}, please make sure it has public non argument constructor") } } } catch { - case e: ClassNotFoundException => logError(s"Can not load class ${className}, please make sure it is on the classpath") + case e: ClassNotFoundException => throw new IOException(s"Can not load class ${className}, please make sure it is on the classpath") } } @@ -519,9 +520,9 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val udaf = clazz.newInstance().asInstanceOf[UserDefinedAggregateFunction] register(name, udaf) } catch { - case e: ClassNotFoundException => logError(s"Can not load class ${className}, please make sure it is on the classpath") + case e: ClassNotFoundException => throw new 
IOException(s"Can not load class ${className}, please make sure it is on the classpath") case e @ (_: InstantiationException | _: IllegalArgumentException) => - logError(s"Can not instantiate class ${className}, please make sure it has public non argument constructor") + throw new IOException(s"Can not instantiate class ${className}, please make sure it has public non argument constructor") } } From ad5d2c99be23746c557264d51fcfcd480f2c848c Mon Sep 17 00:00:00 2001 From: Jeff Zhang Date: Wed, 12 Apr 2017 08:44:42 +0800 Subject: [PATCH 4/4] remove return --- python/pyspark/sql/context.py | 8 +-- python/pyspark/sql/tests.py | 16 ++---- .../apache/spark/sql/UDFRegistration.scala | 17 +++--- .../org/apache/spark/sql/JavaUDAFSuite.java | 55 +++++++++++++++++++ .../org/apache/spark/sql}/MyDoubleAvg.java | 2 +- .../org/apache/spark/sql}/MyDoubleSum.java | 8 +-- sql/hive/pom.xml | 7 +++ .../spark/sql/hive/JavaDataFrameSuite.java | 2 +- .../execution/AggregationQuerySuite.scala | 5 +- 9 files changed, 90 insertions(+), 30 deletions(-) create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java rename sql/{hive/src/test/java/org/apache/spark/sql/hive/aggregate => core/src/test/java/test/org/apache/spark/sql}/MyDoubleAvg.java (99%) rename sql/{hive/src/test/java/org/apache/spark/sql/hive/aggregate => core/src/test/java/test/org/apache/spark/sql}/MyDoubleSum.java (98%) diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index 9d73cb719c43..c44ab247fd3d 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -233,7 +233,7 @@ def registerJavaFunction(self, name, javaClassName, returnType=None): self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt) @ignore_unicode_prefix - @since(2.2) + @since(2.3) def registerJavaUDAF(self, name, javaClassName): """Register a java UDAF so it can be used in SQL statements. @@ -241,7 +241,7 @@ def registerJavaUDAF(self, name, javaClassName): :param javaClassName: fully qualified name of java class >>> sqlContext.registerJavaUDAF("javaUDAF", - ... "org.apache.spark.sql.hive.aggregate.MyDoubleAvg") + ... 
"test.org.apache.spark.sql.MyDoubleAvg") >>> df = sqlContext.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"]) >>> df.registerTempTable("df") >>> sqlContext.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect() @@ -569,10 +569,10 @@ def register(self, name, f, returnType=StringType()): return self.sqlContext.registerFunction(name, f, returnType) def registerJavaFunction(self, name, javaClassName, returnType=None): - return self.sqlContext.registerJavaFunction(name, javaClassName, returnType) + self.sqlContext.registerJavaFunction(name, javaClassName, returnType) def registerJavaUDAF(self, name, javaClassName): - return self.sqlContext.registerJavaUDAF(name, javaClassName) + self.sqlContext.registerJavaUDAF(name, javaClassName) register.__doc__ = SQLContext.registerFunction.__doc__ diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 714a35a76c64..a725827b0636 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -491,18 +491,14 @@ def test_udf_registration_returns_udf(self): ) def test_non_existed_udf(self): - try: - self.spark.udf.registerJavaFunction("udf1", "non_existed_udf") - self.fail("should fail due to can not load java udf class") - except py4j.protocol.Py4JError as e: - self.assertTrue("Can not load class non_existed_udf" in str(e)) + spark = self.spark + self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf", + lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf")) def test_non_existed_udaf(self): - try: - self.spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf") - self.fail("should fail due to can not load java udaf class") - except py4j.protocol.Py4JError as e: - self.assertTrue("Can not load class non_existed_udaf" in str(e)) + spark = self.spark + self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf", + lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf")) def test_multiLine_json(self): people1 = self.spark.read.json("python/test_support/sql/people.json") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 04f00fd68268..8bdc0221888d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql -import java.io.IOException import java.lang.reflect.{ParameterizedType, Type} import scala.reflect.runtime.universe.TypeTag @@ -456,9 +455,9 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends .map(_.asInstanceOf[ParameterizedType]) .filter(e => e.getRawType.isInstanceOf[Class[_]] && e.getRawType.asInstanceOf[Class[_]].getCanonicalName.startsWith("org.apache.spark.sql.api.java.UDF")) if (udfInterfaces.length == 0) { - throw new IOException(s"UDF class ${className} doesn't implement any UDF interface") + throw new AnalysisException(s"UDF class ${className} doesn't implement any UDF interface") } else if (udfInterfaces.length > 1) { - throw new IOException(s"It is invalid to implement multiple UDF interfaces, UDF class ${className}") + throw new AnalysisException(s"It is invalid to implement multiple UDF interfaces, UDF class ${className}") } else { try { val udf = clazz.newInstance() @@ -492,15 +491,15 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends case 22 => register(name, udf.asInstanceOf[UDF21[_, _, _, _, _, _, _, _, _, _, _, 
_, _, _, _, _, _, _, _, _, _, _]], returnType) case 23 => register(name, udf.asInstanceOf[UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType) case n => - throw new IOException(s"UDF class with ${n} type arguments is not supported.") + throw new AnalysisException(s"UDF class with ${n} type arguments is not supported.") } } catch { case e @ (_: InstantiationException | _: IllegalArgumentException) => - throw new IOException(s"Can not instantiate class ${className}, please make sure it has public non argument constructor") + throw new AnalysisException(s"Can not instantiate class ${className}, please make sure it has public non argument constructor") } } } catch { - case e: ClassNotFoundException => throw new IOException(s"Can not load class ${className}, please make sure it is on the classpath") + case e: ClassNotFoundException => throw new AnalysisException(s"Can not load class ${className}, please make sure it is on the classpath") } } @@ -515,14 +514,14 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends try { val clazz = Utils.classForName(className) if (!classOf[UserDefinedAggregateFunction].isAssignableFrom(clazz)) { - throw new IOException(s"class $className doesn't implement interface UserDefinedAggregateFunction") + throw new AnalysisException(s"class $className doesn't implement interface UserDefinedAggregateFunction") } val udaf = clazz.newInstance().asInstanceOf[UserDefinedAggregateFunction] register(name, udaf) } catch { - case e: ClassNotFoundException => throw new IOException(s"Can not load class ${className}, please make sure it is on the classpath") + case e: ClassNotFoundException => throw new AnalysisException(s"Can not load class ${className}, please make sure it is on the classpath") case e @ (_: InstantiationException | _: IllegalArgumentException) => - throw new IOException(s"Can not instantiate class ${className}, please make sure it has public non argument constructor") + throw new AnalysisException(s"Can not instantiate class ${className}, please make sure it has public non argument constructor") } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java new file mode 100644 index 000000000000..ddbaa45a483c --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.org.apache.spark.sql; + +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SparkSession; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + + +public class JavaUDAFSuite { + + private transient SparkSession spark; + + @Before + public void setUp() { + spark = SparkSession.builder() + .master("local[*]") + .appName("testing") + .getOrCreate(); + } + + @After + public void tearDown() { + spark.stop(); + spark = null; + } + + @SuppressWarnings("unchecked") + @Test + public void udf1Test() { + spark.range(1, 10).toDF("value").registerTempTable("df"); + spark.udf().registerJavaUDAF("myDoubleAvg", MyDoubleAvg.class.getName()); + Row result = spark.sql("SELECT myDoubleAvg(value) as my_avg from df").head(); + Assert.assertEquals(105.0, result.getDouble(0), 1.0e-6); + } + +} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java b/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java similarity index 99% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java rename to sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java index ae0c097c362a..447a71d284fb 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleAvg.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.hive.aggregate; +package test.org.apache.spark.sql; import java.util.ArrayList; import java.util.List; diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java b/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleSum.java similarity index 98% rename from sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java rename to sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleSum.java index d17fb3e5194f..93d20330c717 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/aggregate/MyDoubleSum.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleSum.java @@ -15,18 +15,18 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.hive.aggregate; +package test.org.apache.spark.sql; import java.util.ArrayList; import java.util.List; +import org.apache.spark.sql.Row; import org.apache.spark.sql.expressions.MutableAggregationBuffer; import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.Row; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; /** * An example {@link UserDefinedAggregateFunction} to calculate the sum of a diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 09dcc4055e00..f9462e79a69f 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -57,6 +57,13 @@ spark-sql_${scala.binary.version} ${project.version} + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-tags_${scala.binary.version} diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java index aefc9cc77da8..636ce10da373 100644 --- a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java +++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java @@ -31,7 +31,7 @@ import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; import static org.apache.spark.sql.functions.*; import org.apache.spark.sql.hive.test.TestHive$; -import org.apache.spark.sql.hive.aggregate.MyDoubleSum; +import test.org.apache.spark.sql.MyDoubleSum; public class JavaDataFrameSuite { private transient SQLContext hc; diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 84f915977bd8..f245a79f805a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -20,16 +20,19 @@ package org.apache.spark.sql.hive.execution import scala.collection.JavaConverters._ import scala.util.Random +import test.org.apache.spark.sql.MyDoubleAvg +import test.org.apache.spark.sql.MyDoubleSum + import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ + class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFunction { def inputSchema: StructType = schema
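
Usage sketch (not part of the patches above): a minimal PySpark session exercising the new registerJavaUDAF API, assuming the MyDoubleAvg test class (relocated to test.org.apache.spark.sql by PATCH 4/4) is packaged on the driver classpath. createOrReplaceTempView is used here in place of the deprecated registerTempTable from the doctest; the app name and data are illustrative.

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.appName("java-udaf-demo").getOrCreate()

    # Register the Java UDAF under the SQL name "javaUDAF". Per the Scala-side
    # check in UDFRegistration.registerJavaUDAF, the class must implement
    # UserDefinedAggregateFunction and expose a public no-argument constructor.
    spark.udf.registerJavaUDAF("javaUDAF", "test.org.apache.spark.sql.MyDoubleAvg")

    df = spark.createDataFrame([(1, "a"), (2, "b"), (3, "a")], ["id", "name"])
    df.createOrReplaceTempView("df")

    # Equivalent to the doctest: aggregates id per name with the Java UDAF.
    spark.sql("SELECT name, javaUDAF(id) AS avg FROM df GROUP BY name").show()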
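
Error-handling sketch: PATCH 3/4 and 4/4 replace logError with thrown exceptions (first IOException, then AnalysisException), so a bad class name now fails loudly on the Python side instead of silently registering nothing. The class name below is a deliberate placeholder; the expected message text comes from UDFRegistration.registerJavaUDAF.

    from pyspark.sql.utils import AnalysisException

    try:
        # Hypothetical, nonexistent class name used to trigger the failure path.
        spark.udf.registerJavaUDAF("badUDAF", "com.example.DoesNotExist")
    except AnalysisException as e:
        # Expected message: "Can not load class com.example.DoesNotExist,
        # please make sure it is on the classpath"
        print("registration failed: %s" % e)

AnalysisException is the exception type PySpark already translates from the JVM through py4j, which is why the reworked tests in PATCH 4/4 can assert on it directly with assertRaisesRegexp rather than catching a raw py4j.protocol.Py4JError.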