23 changes: 23 additions & 0 deletions python/pyspark/sql/context.py
@@ -232,6 +232,23 @@ def registerJavaFunction(self, name, javaClassName, returnType=None):
jdt = self.sparkSession._jsparkSession.parseDataType(returnType.json())
self.sparkSession._jsparkSession.udf().registerJava(name, javaClassName, jdt)

@ignore_unicode_prefix
@since(2.3)
def registerJavaUDAF(self, name, javaClassName):
"""Register a java UDAF so it can be used in SQL statements.

:param name: name of the UDAF
:param javaClassName: fully qualified name of java class

>>> sqlContext.registerJavaUDAF("javaUDAF",
... "test.org.apache.spark.sql.MyDoubleAvg")
>>> df = sqlContext.createDataFrame([(1, "a"),(2, "b"), (3, "a")],["id", "name"])
>>> df.registerTempTable("df")
>>> sqlContext.sql("SELECT name, javaUDAF(id) as avg from df group by name").collect()
[Row(name=u'b', avg=102.0), Row(name=u'a', avg=102.0)]
"""
self.sparkSession._jsparkSession.udf().registerJavaUDAF(name, javaClassName)

# TODO(andrew): delete this once we refactor things to take in SparkSession
def _inferSchema(self, rdd, samplingRatio=None):
"""
@@ -551,6 +568,12 @@ def __init__(self, sqlContext):
def register(self, name, f, returnType=StringType()):
return self.sqlContext.registerFunction(name, f, returnType)

def registerJavaFunction(self, name, javaClassName, returnType=None):
self.sqlContext.registerJavaFunction(name, javaClassName, returnType)

def registerJavaUDAF(self, name, javaClassName):
self.sqlContext.registerJavaUDAF(name, javaClassName)

register.__doc__ = SQLContext.registerFunction.__doc__


10 changes: 10 additions & 0 deletions python/pyspark/sql/tests.py
@@ -490,6 +490,16 @@ def test_udf_registration_returns_udf(self):
df.select(add_three("id").alias("plus_three")).collect()
)

def test_non_existed_udf(self):
spark = self.spark
self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udf",
lambda: spark.udf.registerJavaFunction("udf1", "non_existed_udf"))

def test_non_existed_udaf(self):
spark = self.spark
self.assertRaisesRegexp(AnalysisException, "Can not load class non_existed_udaf",
lambda: spark.udf.registerJavaUDAF("udaf1", "non_existed_udaf"))

def test_multiLine_json(self):
people1 = self.spark.read.json("python/test_support/sql/people.json")
people_array = self.spark.read.json("python/test_support/sql/people_array.json",
33 changes: 27 additions & 6 deletions sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -17,7 +17,6 @@

package org.apache.spark.sql

-import java.io.IOException
import java.lang.reflect.{ParameterizedType, Type}

import scala.reflect.runtime.universe.TypeTag
@@ -456,9 +455,9 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
.map(_.asInstanceOf[ParameterizedType])
.filter(e => e.getRawType.isInstanceOf[Class[_]] && e.getRawType.asInstanceOf[Class[_]].getCanonicalName.startsWith("org.apache.spark.sql.api.java.UDF"))
if (udfInterfaces.length == 0) {
-  throw new IOException(s"UDF class ${className} doesn't implement any UDF interface")
throw new AnalysisException(s"UDF class ${className} doesn't implement any UDF interface")
} else if (udfInterfaces.length > 1) {
-  throw new IOException(s"It is invalid to implement multiple UDF interfaces, UDF class ${className}")
throw new AnalysisException(s"It is invalid to implement multiple UDF interfaces, UDF class ${className}")
} else {
try {
val udf = clazz.newInstance()
@@ -491,19 +490,41 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
case 21 => register(name, udf.asInstanceOf[UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
case 22 => register(name, udf.asInstanceOf[UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
case 23 => register(name, udf.asInstanceOf[UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _]], returnType)
-  case n => logError(s"UDF class with ${n} type arguments is not supported ")
case n =>
throw new AnalysisException(s"UDF class with ${n} type arguments is not supported.")
}
} catch {
case e @ (_: InstantiationException | _: IllegalArgumentException) =>
logError(s"Can not instantiate class ${className}, please make sure it has public non argument constructor")
throw new AnalysisException(s"Can not instantiate class ${className}, please make sure it has public non argument constructor")
}
}
} catch {
-  case e: ClassNotFoundException => logError(s"Can not load class ${className}, please make sure it is on the classpath")
case e: ClassNotFoundException => throw new AnalysisException(s"Can not load class ${className}, please make sure it is on the classpath")
}

}

/**
 * Register a Java UDAF class using reflection, for use from PySpark.
 *
 * @param name UDAF name
 * @param className fully qualified class name of the UDAF
 */

[Member] Missing @since.
[Contributor Author] Is @since needed for a private function?
[Member] OK. I did not notice it.
private[sql] def registerJavaUDAF(name: String, className: String): Unit = {
[Contributor] Do we need returnType?
[Contributor Author] The PySpark side doesn't need returnType, so I didn't use it here; and since this is a private function, it stays open to adding returnType in the future if necessary.
[Contributor] On the Python side, why do we need returnType in registerJavaFunction?
[Contributor Author (zjffdu), May 5, 2017] Because on the Scala side, registerJava in UDFRegistration needs returnType. Yeah, it does look a little weird for the Python side to require it.
[Contributor] registerJava accepts an optional return type; if not given, Spark will try to infer it via reflection. Do we really not need to do this for a UDAF?
[Member (viirya), May 8, 2017] UserDefinedAggregateFunction already defines its return type, so we don't need to specify it when registering a UDAF.

(A minimal sketch illustrating this point follows the function body below.)
try {
val clazz = Utils.classForName(className)
if (!classOf[UserDefinedAggregateFunction].isAssignableFrom(clazz)) {
throw new AnalysisException(s"class $className doesn't implement interface UserDefinedAggregateFunction")
}
val udaf = clazz.newInstance().asInstanceOf[UserDefinedAggregateFunction]
register(name, udaf)
} catch {
case e: ClassNotFoundException => throw new AnalysisException(s"Can not load class ${className}, please make sure it is on the classpath")
case e @ (_: InstantiationException | _: IllegalArgumentException) =>
throw new AnalysisException(s"Can not instantiate class ${className}, please make sure it has public non argument constructor")
}
}
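
To make the discussion above concrete: the return type of a UDAF lives on the class itself, so there is nothing for registerJavaUDAF to infer or accept. Below is a minimal, hypothetical sketch of the UserDefinedAggregateFunction contract (a trivial MySum class that is not part of this PR), showing where the return type comes from:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}

// Hypothetical example, not part of this PR: a trivial sum UDAF.
class MySum extends UserDefinedAggregateFunction {
  override def inputSchema: StructType = new StructType().add("value", DoubleType)
  override def bufferSchema: StructType = new StructType().add("sum", DoubleType)
  // The result type is declared by the class itself, which is why
  // registerJavaUDAF(name, className) needs no returnType parameter.
  override def dataType: DataType = DoubleType
  override def deterministic: Boolean = true
  override def initialize(buffer: MutableAggregationBuffer): Unit = buffer(0) = 0.0
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) buffer(0) = buffer.getDouble(0) + input.getDouble(0)
  }
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
  }
  override def evaluate(buffer: Row): Any = buffer.getDouble(0)
}

Once such a class is on the classpath, registering it by class name is enough: Spark reflectively constructs an instance and reads dataType from it.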

/**
* Register a user-defined function with 1 arguments.
* @since 1.3.0
Expand Down
55 changes: 55 additions & 0 deletions sql/core/src/test/java/test/org/apache/spark/sql/JavaUDAFSuite.java
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package test.org.apache.spark.sql;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;


public class JavaUDAFSuite {

private transient SparkSession spark;

@Before
public void setUp() {
spark = SparkSession.builder()
.master("local[*]")
.appName("testing")
.getOrCreate();
}

@After
public void tearDown() {
spark.stop();
spark = null;
}

@SuppressWarnings("unchecked")
@Test
public void udf1Test() {
spark.range(1, 10).toDF("value").registerTempTable("df");
spark.udf().registerJavaUDAF("myDoubleAvg", MyDoubleAvg.class.getName());
Row result = spark.sql("SELECT myDoubleAvg(value) as my_avg from df").head();
Assert.assertEquals(105.0, result.getDouble(0), 1.0e-6);
}

}
sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleAvg.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/

-package org.apache.spark.sql.hive.aggregate;
package test.org.apache.spark.sql;

import java.util.ArrayList;
import java.util.List;
sql/core/src/test/java/test/org/apache/spark/sql/MyDoubleSum.java
@@ -15,18 +15,18 @@
* limitations under the License.
*/

-package org.apache.spark.sql.hive.aggregate;
package test.org.apache.spark.sql;

import java.util.ArrayList;
import java.util.List;

-import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
-import org.apache.spark.sql.types.StructField;
-import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

/**
* An example {@link UserDefinedAggregateFunction} to calculate the sum of a
Expand Down
7 changes: 7 additions & 0 deletions sql/hive/pom.xml
@@ -57,6 +57,13 @@
<artifactId>spark-sql_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_${scala.binary.version}</artifactId>
<version>${project.version}</version>
<type>test-jar</type>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-tags_${scala.binary.version}</artifactId>
sql/hive/src/test/java/org/apache/spark/sql/hive/JavaDataFrameSuite.java
@@ -31,7 +31,7 @@
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import static org.apache.spark.sql.functions.*;
import org.apache.spark.sql.hive.test.TestHive$;
-import org.apache.spark.sql.hive.aggregate.MyDoubleSum;
import test.org.apache.spark.sql.MyDoubleSum;

public class JavaDataFrameSuite {
[Contributor] Shall we merge this suite with JavaDataFrameSuite in sql/core?
[Contributor Author] Do you mean move JavaDataFrameSuite to sql/core?
[Contributor] Yea.
private transient SQLContext hc;
sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -20,16 +20,19 @@ package org.apache.spark.sql.hive.execution
import scala.collection.JavaConverters._
import scala.util.Random

import test.org.apache.spark.sql.MyDoubleAvg
import test.org.apache.spark.sql.MyDoubleSum

import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum}
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types._


[Contributor] Shall we move this test suite to sql/core?
[Contributor Author] I didn't add any test in this file. Or do you mean moving AggregationQuerySuite.scala to sql/core?
[Contributor] Yea, there seems to be no reason to leave this suite in sql/hive.
[Contributor Author] It depends on some Hive stuff (TestHiveSingleton), so I guess it is intended to be in sql/hive.
[Contributor] When we move it to sql/core, we can make it extend SharedSQLContext.

(A sketch of that suggested migration follows.)
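
A minimal sketch of what the suggested move might look like, assuming the suite lands in sql/core and swaps TestHiveSingleton for the SharedSQLContext test harness (hypothetical, not part of this PR):

package org.apache.spark.sql.execution

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.test.SharedSQLContext

import test.org.apache.spark.sql.MyDoubleAvg

// Hypothetical sketch: the suite rebased onto sql/core's shared session,
// with no Hive dependency.
class AggregationQuerySuite extends QueryTest with SharedSQLContext {
  import testImplicits._

  test("register a Java UDAF by class name") {
    spark.udf.registerJavaUDAF("myDoubleAvg", classOf[MyDoubleAvg].getName)
    Seq(1.0, 2.0, 3.0).toDF("value").createOrReplaceTempView("df")
    // MyDoubleAvg returns the average plus 100: avg(1, 2, 3) + 100 = 102.
    checkAnswer(spark.sql("SELECT myDoubleAvg(value) FROM df"), Row(102.0))
  }
}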

class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFunction {

def inputSchema: StructType = schema