Commit 12ee9d6

Merge branch 'SPARK-28158' of github.com:uncleGen/spark into SPARK-28158
2 parents a8e2fb3 + 9cadbe4

File tree

4 files changed: +53, -3 lines

mllib/pom.xml

Lines changed: 13 additions & 0 deletions
@@ -74,6 +74,19 @@
       <type>test-jar</type>
       <scope>test</scope>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-hive_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-hive_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-graphx_${scala.binary.version}</artifactId>
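
Both additions are test-scoped: the regular spark-hive artifact puts the Hive integration classes on mllib's test classpath, and the test-jar is what supplies org.apache.spark.sql.hive.test.TestHive, which the updated suite below depends on. As a minimal sketch of what this unlocks, assuming a suite under mllib's test sources (illustrative only; the real usage is in VectorUDTSuite below):

  import org.apache.spark.sql.SparkSession
  import org.apache.spark.sql.hive.test.TestHive

  // With the spark-hive test-jar on the test classpath, an mllib suite can
  // reuse the shared Hive-enabled session instead of bootstrapping its own
  // metastore and warehouse.
  val spark: SparkSession = TestHive.sparkSession
  spark.sql("SHOW FUNCTIONS").show()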
TestLogRegUDF.jar (the binary referenced by the new test below)

1.75 KB
Binary file not shown.
test-data/libsvm/sample_libsvm_data.txt

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+0 128:51 129:159 130:253 131:159 132:50 155:48 156:238 157:252 158:252 159:252 160:237 182:54 183:227 184:253 185:252 186:239 187:233 188:252 189:57 190:6 208:10 209:60 210:224 211:252 212:253 213:252 214:202 215:84 216:252 217:253 218:122 236:163 237:252 238:252 239:252 240:253 241:252 242:252 243:96 244:189 245:253 246:167 263:51 264:238 265:253 266:253 267:190 268:114 269:253 270:228 271:47 272:79 273:255 274:168 290:48 291:238 292:252 293:252 294:179 295:12 296:75 297:121 298:21 301:253 302:243 303:50 317:38 318:165 319:253 320:233 321:208 322:84 329:253 330:252 331:165 344:7 345:178 346:252 347:240 348:71 349:19 350:28 357:253 358:252 359:195 372:57 373:252 374:252 375:63 385:253 386:252 387:195 400:198 401:253 402:190 413:255 414:253 415:196 427:76 428:246 429:252 430:112 441:253 442:252 443:148 455:85 456:252 457:230 458:25 467:7 468:135 469:253 470:186 471:12 483:85 484:252 485:223 494:7 495:131 496:252 497:225 498:71 511:85 512:252 513:145 521:48 522:165 523:252 524:173 539:86 540:253 541:225 548:114 549:238 550:253 551:162 567:85 568:252 569:249 570:146 571:48 572:29 573:85 574:178 575:225 576:253 577:223 578:167 579:56 595:85 596:252 597:252 598:252 599:229 600:215 601:252 602:252 603:252 604:196 605:130 623:28 624:199 625:252 626:252 627:253 628:252 629:252 630:233 631:145 652:25 653:128 654:252 655:253 656:252 657:141 658:37
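
The single added record is in libsvm format: a label (0) followed by sparse index:value feature pairs; the indices here look like pixel positions of a rasterized digit image. As a minimal sketch, assuming a SparkSession named spark is in scope and with an illustrative path, this is how Spark's built-in libsvm data source parses such a file, using the same vectorType option as the new test:

  // Each libsvm line is parsed as `label index:value index:value ...`.
  // With "vectorType" -> "dense", the features column holds DenseVector
  // values (typed through VectorUDT).
  val df = spark.read
    .format("libsvm")
    .option("vectorType", "dense")
    .load("/path/to/sample_libsvm_data.txt")  // illustrative path

  df.printSchema()  // label: double, features: vector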

mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala

Lines changed: 39 additions & 3 deletions
@@ -17,12 +17,13 @@
 
 package org.apache.spark.ml.linalg
 
-import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.feature.LabeledPoint
-import org.apache.spark.sql.catalyst.JavaTypeInference
+import org.apache.spark.sql.{QueryTest, Row, SparkSession}
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, JavaTypeInference}
+import org.apache.spark.sql.hive.test.TestHive
 import org.apache.spark.sql.types._
 
-class VectorUDTSuite extends SparkFunSuite {
+class VectorUDTSuite extends QueryTest {
 
   test("preloaded VectorUDT") {
     val dv1 = Vectors.dense(Array.empty[Double])
@@ -44,4 +45,39 @@ class VectorUDTSuite extends SparkFunSuite {
     assert(dataType.asInstanceOf[StructType].fields.map(_.dataType)
       === Seq(new VectorUDT, DoubleType))
   }
+
+  test("SPARK-28158 Hive UDFs supports UDT type") {
+    val functionName = "Logistic_Regression"
+    val sql = spark.sql _
+    try {
+      val df = spark.read.format("libsvm").options(Map("vectorType" -> "dense"))
+        .load(TestHive.getHiveFile("test-data/libsvm/sample_libsvm_data.txt").getPath)
+      df.createOrReplaceTempView("src")
+
+      // `Logistic_Regression` accepts features (with Vector type), and returns the
+      // prediction value. To simplify the UDF implementation, the `Logistic_Regression`
+      // will return 0.95d directly.
+      sql(
+        s"""
+           |CREATE FUNCTION Logistic_Regression
+           |AS 'org.apache.spark.sql.hive.LogisticRegressionUDF'
+           |USING JAR '${TestHive.getHiveFile("TestLogRegUDF.jar").toURI}'
+         """.stripMargin)
+
+      checkAnswer(
+        sql("SELECT Logistic_Regression(features) FROM src"),
+        Row(0.95) :: Nil)
+    } catch {
+      case cause: Throwable => throw cause
+    } finally {
+      // If the test failed part way, we don't want to mask the failure by failing to remove
+      // temp tables that never got created.
+      spark.sql(s"DROP FUNCTION IF EXISTS $functionName")
+      assert(
+        !spark.sessionState.catalog.functionExists(FunctionIdentifier(functionName)),
+        s"Function $functionName should have been dropped. But, it still exists.")
+    }
+  }
+
+  override protected val spark: SparkSession = TestHive.sparkSession
 }
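
The class named in CREATE FUNCTION ... USING JAR ships inside TestLogRegUDF.jar, whose source is not part of this diff. A hypothetical sketch of what org.apache.spark.sql.hive.LogisticRegressionUDF could look like, assuming Hive's reflection-based simple UDF API (the signature in the actual jar may differ):

  package org.apache.spark.sql.hive

  import org.apache.hadoop.hive.ql.exec.UDF

  // Hypothetical reconstruction of the packaged test UDF: per the comment in
  // the test above, it ignores the VectorUDT-typed features argument and
  // returns a fixed "prediction" of 0.95 instead of running a real model.
  class LogisticRegressionUDF extends UDF {
    def evaluate(features: AnyRef): Double = 0.95
  }

Whatever its exact shape, the point the test exercises is that a UDT-typed column (VectorUDT here) can be passed through a Hive UDF call end to end, which is what SPARK-28158 adds.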
