1717
1818package org .apache .spark .ml .linalg
1919
20- import org .apache .spark .SparkFunSuite
2120import org .apache .spark .ml .feature .LabeledPoint
22- import org .apache .spark .sql .catalyst .JavaTypeInference
21+ import org .apache .spark .sql .{QueryTest , Row , SparkSession }
22+ import org .apache .spark .sql .catalyst .{FunctionIdentifier , JavaTypeInference }
23+ import org .apache .spark .sql .hive .test .TestHive
2324import org .apache .spark .sql .types ._
2425
25- class VectorUDTSuite extends SparkFunSuite {
26+ class VectorUDTSuite extends QueryTest {
2627
2728 test(" preloaded VectorUDT" ) {
2829 val dv1 = Vectors .dense(Array .empty[Double ])
@@ -44,4 +45,39 @@ class VectorUDTSuite extends SparkFunSuite {
4445 assert(dataType.asInstanceOf [StructType ].fields.map(_.dataType)
4546 === Seq (new VectorUDT , DoubleType ))
4647 }
48+
49+ test(" SPARK-28158 Hive UDFs supports UDT type" ) {
50+ val functionName = " Logistic_Regression"
51+ val sql = spark.sql _
52+ try {
53+ val df = spark.read.format(" libsvm" ).options(Map (" vectorType" -> " dense" ))
54+ .load(TestHive .getHiveFile(" test-data/libsvm/sample_libsvm_data.txt" ).getPath)
55+ df.createOrReplaceTempView(" src" )
56+
57+ // `Logistic_Regression` accepts features (with Vector type), and returns the
58+ // prediction value. To simplify the UDF implementation, the `Logistic_Regression`
59+ // will return 0.95d directly.
60+ sql(
61+ s """
62+ |CREATE FUNCTION Logistic_Regression
63+ |AS 'org.apache.spark.sql.hive.LogisticRegressionUDF'
64+ |USING JAR ' ${TestHive .getHiveFile(" TestLogRegUDF.jar" ).toURI}'
65+ """ .stripMargin)
66+
67+ checkAnswer(
68+ sql(" SELECT Logistic_Regression(features) FROM src" ),
69+ Row (0.95 ) :: Nil )
70+ } catch {
71+ case cause : Throwable => throw cause
72+ } finally {
73+ // If the test failed part way, we don't want to mask the failure by failing to remove
74+ // temp tables that never got created.
75+ spark.sql(s " DROP FUNCTION IF EXISTS $functionName" )
76+ assert(
77+ ! spark.sessionState.catalog.functionExists(FunctionIdentifier (functionName)),
78+ s " Function $functionName should have been dropped. But, it still exists. " )
79+ }
80+ }
81+
82+ override protected val spark : SparkSession = TestHive .sparkSession
4783}
0 commit comments