From ded13096f2213ec561f22478ed2a3b4f00d541ed Mon Sep 17 00:00:00 2001 From: uncleGen Date: Tue, 25 Jun 2019 14:18:19 +0800 Subject: [PATCH 1/6] Hive UDFs supports UDT type --- .../org/apache/spark/sql/hive/HiveInspectors.scala | 5 +++++ .../apache/spark/sql/hive/HiveInspectorSuite.scala | 11 ++++++++++- 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 33b5bcefd853..5b627b816413 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -787,6 +787,9 @@ private[hive] trait HiveInspectors { ObjectInspectorFactory.getStandardStructObjectInspector( java.util.Arrays.asList(fields.map(f => f.name) : _*), java.util.Arrays.asList(fields.map(f => toInspector(f.dataType)) : _*)) + case _: UserDefinedType[_] => + val sqlType = dataType.asInstanceOf[UserDefinedType[_]].sqlType + toInspector(sqlType) } /** @@ -849,6 +852,8 @@ private[hive] trait HiveInspectors { } case Literal(_, dt: StructType) => toInspector(dt) + case Literal(_, dt: UserDefinedType[_]) => + toInspector(dt.sqlType) // We will enumerate all of the possible constant expressions, throw exception if we missed case Literal(_, dt) => sys.error(s"Hive doesn't support the constant type [$dt].") // ideally, we don't test the foldable here(but in optimizer), however, some of the diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index c300660458fd..4184d6c8580a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectIn import org.apache.hadoop.io.LongWritable import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.Row +import org.apache.spark.sql.{Row, TestUserClassUDT} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} @@ -214,6 +214,15 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { }) } + test("wrap / unwrap UDT Type") { + case object TestUserClassUDT extends TestUserClassUDT + + val dt = TestUserClassUDT + val d = 1 + checkValue(d, unwrap(wrap(d, toInspector(dt), dt), toInspector(dt))) + checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt))) + } + test("wrap / unwrap Struct Type") { val dt = StructType(dataTypes.zipWithIndex.map { case (t, idx) => StructField(s"c_$idx", t) From 8bf8d43366418ef7ee0484ac6d6cf133033bdc87 Mon Sep 17 00:00:00 2001 From: uncleGen Date: Mon, 15 Jul 2019 19:50:41 +0800 Subject: [PATCH 2/6] add a vector udt test --- mllib/pom.xml | 13 ++++++ mllib/src/test/resources/TestLogRegUDF.jar | Bin 0 -> 1794 bytes .../test-data/libsvm/sample_libsvm_data.txt | 1 + .../spark/ml/linalg/VectorUDTSuite.scala | 42 ++++++++++++++++-- 4 files changed, 53 insertions(+), 3 deletions(-) create mode 100644 mllib/src/test/resources/TestLogRegUDF.jar create mode 100644 mllib/src/test/resources/test-data/libsvm/sample_libsvm_data.txt diff --git a/mllib/pom.xml b/mllib/pom.xml index 11769ef548d7..a1ecc15d2117 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -74,6 +74,19 @@ test-jar test + + org.apache.spark + spark-hive_${scala.binary.version} + ${project.version} + test + + + org.apache.spark + spark-hive_${scala.binary.version} + ${project.version} + test-jar + test + org.apache.spark spark-graphx_${scala.binary.version} diff --git a/mllib/src/test/resources/TestLogRegUDF.jar b/mllib/src/test/resources/TestLogRegUDF.jar new file mode 100644 index 0000000000000000000000000000000000000000..d0a44c5aa0de1a6663d414701e3110a84c7593a5 GIT binary patch literal 1794 zcmWIWW@Zs#;Nak3xIO2+9|IELWME3R9;c=}!}UB$u~ z@j~*5^ejouj345vYR{KwSv(gPeIl+pjp1tcr=pKVj9^!+oxWkC3(yW%AO^XDi-7^? zvPo!OU;)bJ7p3dtR>2KbkywzJoRNx0p&+Wl;)2AYY&^**!Jh#4YS)5k)_vb(612*rLOgK~_b5WM1as4pnicmw5@B^Ncf9o|a_uJ7E8hE3CoWz8K;q7=2VTE^pOE-|>Qd^A z%qx$=c2=$4{eDe=-$BEruRR~P8cZ#SeR*SjQDxWcPlvtoPJTb~S@H9-qb0M1XJ%ZU zuXu85$S!B9$pZM=wlN+?#dV zX>G)vQdZ6z7W))JZbWTdGP~)+r3>=y#-gQf4AiHl%zj^R_t7@Vm&%d+x1V1%IDR3S zOJDu&s-xdmY+jOU)HC^6)sYKT9xf@1q9jr?MW=3%{MG+q;iI+Irl}o!zRb>fazbp@ zr@NOdTGyHQDV^Q$N&a1P+fS3cN7LF`O~eXke|gJ(e~o+kh3GfsZ7;)CpED@m(|tp< z&2IsF;Ay!pQiqFXAOG#YQuN_RrSCibweavd#%~Y%y=s!^u0Jdl`}Sv9b=Z7AZrhgl ze@3xi@4VS+9EYpAr*Dz__Ofuvm1^6sbv%kZv+7HaMs+3{OwNAuLSREnVAYP`FB6PQ zew_7KZLxNxdFGpCca?5Ro?ZDSe1p#Krau2IuO0gH7vGeB{N$$h>XHr@`CrV(7^2VFWC9lh?ji|c!eJm2Qb6I>0V<*p;3<#^(g829aBBn=SqQ+#h@ufz zfZ^5-D#8%J7MpfRfri@%P?3fJ1-OiW7I26d0~K+|u?Z^T5MTvfbMO>+0p6@^AY(Xy MP!yONu7lG70KH67c>n+a literal 0 HcmV?d00001 diff --git a/mllib/src/test/resources/test-data/libsvm/sample_libsvm_data.txt b/mllib/src/test/resources/test-data/libsvm/sample_libsvm_data.txt new file mode 100644 index 000000000000..8d325e27fbb3 --- /dev/null +++ b/mllib/src/test/resources/test-data/libsvm/sample_libsvm_data.txt @@ -0,0 +1 @@ +0 128:51 129:159 130:253 131:159 132:50 155:48 156:238 157:252 158:252 159:252 160:237 182:54 183:227 184:253 185:252 186:239 187:233 188:252 189:57 190:6 208:10 209:60 210:224 211:252 212:253 213:252 214:202 215:84 216:252 217:253 218:122 236:163 237:252 238:252 239:252 240:253 241:252 242:252 243:96 244:189 245:253 246:167 263:51 264:238 265:253 266:253 267:190 268:114 269:253 270:228 271:47 272:79 273:255 274:168 290:48 291:238 292:252 293:252 294:179 295:12 296:75 297:121 298:21 301:253 302:243 303:50 317:38 318:165 319:253 320:233 321:208 322:84 329:253 330:252 331:165 344:7 345:178 346:252 347:240 348:71 349:19 350:28 357:253 358:252 359:195 372:57 373:252 374:252 375:63 385:253 386:252 387:195 400:198 401:253 402:190 413:255 414:253 415:196 427:76 428:246 429:252 430:112 441:253 442:252 443:148 455:85 456:252 457:230 458:25 467:7 468:135 469:253 470:186 471:12 483:85 484:252 485:223 494:7 495:131 496:252 497:225 498:71 511:85 512:252 513:145 521:48 522:165 523:252 524:173 539:86 540:253 541:225 548:114 549:238 550:253 551:162 567:85 568:252 569:249 570:146 571:48 572:29 573:85 574:178 575:225 576:253 577:223 578:167 579:56 595:85 596:252 597:252 598:252 599:229 600:215 601:252 602:252 603:252 604:196 605:130 623:28 624:199 625:252 626:252 627:253 628:252 629:252 630:233 631:145 652:25 653:128 654:252 655:253 656:252 657:141 658:37 \ No newline at end of file diff --git a/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala index 67c64f762b25..09af0c621958 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala @@ -17,12 +17,13 @@ package org.apache.spark.ml.linalg -import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint -import org.apache.spark.sql.catalyst.JavaTypeInference +import org.apache.spark.sql.{QueryTest, Row, SparkSession} +import org.apache.spark.sql.catalyst.{FunctionIdentifier, JavaTypeInference} +import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.types._ -class VectorUDTSuite extends SparkFunSuite { +class VectorUDTSuite extends QueryTest { test("preloaded VectorUDT") { val dv1 = Vectors.dense(Array.empty[Double]) @@ -44,4 +45,39 @@ class VectorUDTSuite extends SparkFunSuite { assert(dataType.asInstanceOf[StructType].fields.map(_.dataType) === Seq(new VectorUDT, DoubleType)) } + + test("SPARK-28158 Hive UDFs supports UDT type") { + val functionName = "Logistic_Regression" + val sql = spark.sql _ + try { + val df = spark.read.format("libsvm").options(Map("vectorType" -> "dense")) + .load(TestHive.getHiveFile("test-data/libsvm/sample_libsvm_data.txt").getPath) + df.createOrReplaceTempView("src") + + // `Logistic_Regression` accepts features (with Vector type), and returns the + // prediction value. To simplify the UDF implementation, the `Logistic_Regression` + // will return 0.95d directly. + sql( + s""" + |CREATE FUNCTION Logistic_Regression + |AS 'org.apache.spark.sql.hive.LogisticRegressionUDF' + |USING JAR '${TestHive.getHiveFile("TestLogRegUDF.jar").toURI}' + """.stripMargin) + + checkAnswer( + sql("SELECT Logistic_Regression(features) FROM src"), + Row(0.95) :: Nil) + } catch { + case cause: Throwable => throw cause + } finally { + // If the test failed part way, we don't want to mask the failure by failing to remove + // temp tables that never got created. + spark.sql(s"DROP FUNCTION IF EXISTS $functionName") + assert( + !spark.sessionState.catalog.functionExists(FunctionIdentifier(functionName)), + s"Function $functionName should have been dropped. But, it still exists.") + } + } + + override protected val spark: SparkSession = TestHive.sparkSession } From ba69b45c93dc757f35d2d922d822e31b6ef9de2d Mon Sep 17 00:00:00 2001 From: uncleGen Date: Wed, 31 Jul 2019 19:06:52 +0800 Subject: [PATCH 3/6] save --- .../org/apache/spark/ml/linalg/VectorUDTSuite.scala | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala index 09af0c621958..dd0203b5ceac 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala @@ -25,6 +25,11 @@ import org.apache.spark.sql.types._ class VectorUDTSuite extends QueryTest { + override def afterAll(): Unit = { + this.afterAll() + spark.stop() + } + test("preloaded VectorUDT") { val dv1 = Vectors.dense(Array.empty[Double]) val dv2 = Vectors.dense(1.0, 2.0) @@ -48,7 +53,6 @@ class VectorUDTSuite extends QueryTest { test("SPARK-28158 Hive UDFs supports UDT type") { val functionName = "Logistic_Regression" - val sql = spark.sql _ try { val df = spark.read.format("libsvm").options(Map("vectorType" -> "dense")) .load(TestHive.getHiveFile("test-data/libsvm/sample_libsvm_data.txt").getPath) @@ -57,7 +61,7 @@ class VectorUDTSuite extends QueryTest { // `Logistic_Regression` accepts features (with Vector type), and returns the // prediction value. To simplify the UDF implementation, the `Logistic_Regression` // will return 0.95d directly. - sql( + spark.sql( s""" |CREATE FUNCTION Logistic_Regression |AS 'org.apache.spark.sql.hive.LogisticRegressionUDF' @@ -65,7 +69,7 @@ class VectorUDTSuite extends QueryTest { """.stripMargin) checkAnswer( - sql("SELECT Logistic_Regression(features) FROM src"), + spark.sql("SELECT Logistic_Regression(features) FROM src"), Row(0.95) :: Nil) } catch { case cause: Throwable => throw cause @@ -76,6 +80,7 @@ class VectorUDTSuite extends QueryTest { assert( !spark.sessionState.catalog.functionExists(FunctionIdentifier(functionName)), s"Function $functionName should have been dropped. But, it still exists.") + spark.stop() } } From 7138d6aee16884a627171d57c6552d0bbbeb262f Mon Sep 17 00:00:00 2001 From: uncleGen Date: Wed, 25 Sep 2019 14:46:43 +0800 Subject: [PATCH 4/6] fix comments --- mllib/pom.xml | 13 --- mllib/src/test/resources/TestLogRegUDF.jar | Bin 1794 -> 0 bytes .../test-data/libsvm/sample_libsvm_data.txt | 1 - .../spark/ml/linalg/VectorUDTSuite.scala | 47 +--------- .../spark/sql/hive/HiveInspectorSuite.scala | 7 +- .../sql/hive/HiveUserDefinedTypeSuite.scala | 84 ++++++++++++++++++ 6 files changed, 89 insertions(+), 63 deletions(-) delete mode 100644 mllib/src/test/resources/TestLogRegUDF.jar delete mode 100644 mllib/src/test/resources/test-data/libsvm/sample_libsvm_data.txt create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUserDefinedTypeSuite.scala diff --git a/mllib/pom.xml b/mllib/pom.xml index a1ecc15d2117..11769ef548d7 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -74,19 +74,6 @@ test-jar test - - org.apache.spark - spark-hive_${scala.binary.version} - ${project.version} - test - - - org.apache.spark - spark-hive_${scala.binary.version} - ${project.version} - test-jar - test - org.apache.spark spark-graphx_${scala.binary.version} diff --git a/mllib/src/test/resources/TestLogRegUDF.jar b/mllib/src/test/resources/TestLogRegUDF.jar deleted file mode 100644 index d0a44c5aa0de1a6663d414701e3110a84c7593a5..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1794 zcmWIWW@Zs#;Nak3xIO2+9|IELWME3R9;c=}!}UB$u~ z@j~*5^ejouj345vYR{KwSv(gPeIl+pjp1tcr=pKVj9^!+oxWkC3(yW%AO^XDi-7^? zvPo!OU;)bJ7p3dtR>2KbkywzJoRNx0p&+Wl;)2AYY&^**!Jh#4YS)5k)_vb(612*rLOgK~_b5WM1as4pnicmw5@B^Ncf9o|a_uJ7E8hE3CoWz8K;q7=2VTE^pOE-|>Qd^A z%qx$=c2=$4{eDe=-$BEruRR~P8cZ#SeR*SjQDxWcPlvtoPJTb~S@H9-qb0M1XJ%ZU zuXu85$S!B9$pZM=wlN+?#dV zX>G)vQdZ6z7W))JZbWTdGP~)+r3>=y#-gQf4AiHl%zj^R_t7@Vm&%d+x1V1%IDR3S zOJDu&s-xdmY+jOU)HC^6)sYKT9xf@1q9jr?MW=3%{MG+q;iI+Irl}o!zRb>fazbp@ zr@NOdTGyHQDV^Q$N&a1P+fS3cN7LF`O~eXke|gJ(e~o+kh3GfsZ7;)CpED@m(|tp< z&2IsF;Ay!pQiqFXAOG#YQuN_RrSCibweavd#%~Y%y=s!^u0Jdl`}Sv9b=Z7AZrhgl ze@3xi@4VS+9EYpAr*Dz__Ofuvm1^6sbv%kZv+7HaMs+3{OwNAuLSREnVAYP`FB6PQ zew_7KZLxNxdFGpCca?5Ro?ZDSe1p#Krau2IuO0gH7vGeB{N$$h>XHr@`CrV(7^2VFWC9lh?ji|c!eJm2Qb6I>0V<*p;3<#^(g829aBBn=SqQ+#h@ufz zfZ^5-D#8%J7MpfRfri@%P?3fJ1-OiW7I26d0~K+|u?Z^T5MTvfbMO>+0p6@^AY(Xy MP!yONu7lG70KH67c>n+a diff --git a/mllib/src/test/resources/test-data/libsvm/sample_libsvm_data.txt b/mllib/src/test/resources/test-data/libsvm/sample_libsvm_data.txt deleted file mode 100644 index 8d325e27fbb3..000000000000 --- a/mllib/src/test/resources/test-data/libsvm/sample_libsvm_data.txt +++ /dev/null @@ -1 +0,0 @@ -0 128:51 129:159 130:253 131:159 132:50 155:48 156:238 157:252 158:252 159:252 160:237 182:54 183:227 184:253 185:252 186:239 187:233 188:252 189:57 190:6 208:10 209:60 210:224 211:252 212:253 213:252 214:202 215:84 216:252 217:253 218:122 236:163 237:252 238:252 239:252 240:253 241:252 242:252 243:96 244:189 245:253 246:167 263:51 264:238 265:253 266:253 267:190 268:114 269:253 270:228 271:47 272:79 273:255 274:168 290:48 291:238 292:252 293:252 294:179 295:12 296:75 297:121 298:21 301:253 302:243 303:50 317:38 318:165 319:253 320:233 321:208 322:84 329:253 330:252 331:165 344:7 345:178 346:252 347:240 348:71 349:19 350:28 357:253 358:252 359:195 372:57 373:252 374:252 375:63 385:253 386:252 387:195 400:198 401:253 402:190 413:255 414:253 415:196 427:76 428:246 429:252 430:112 441:253 442:252 443:148 455:85 456:252 457:230 458:25 467:7 468:135 469:253 470:186 471:12 483:85 484:252 485:223 494:7 495:131 496:252 497:225 498:71 511:85 512:252 513:145 521:48 522:165 523:252 524:173 539:86 540:253 541:225 548:114 549:238 550:253 551:162 567:85 568:252 569:249 570:146 571:48 572:29 573:85 574:178 575:225 576:253 577:223 578:167 579:56 595:85 596:252 597:252 598:252 599:229 600:215 601:252 602:252 603:252 604:196 605:130 623:28 624:199 625:252 626:252 627:253 628:252 629:252 630:233 631:145 652:25 653:128 654:252 655:253 656:252 657:141 658:37 \ No newline at end of file diff --git a/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala index dd0203b5ceac..67c64f762b25 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala @@ -17,18 +17,12 @@ package org.apache.spark.ml.linalg +import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.LabeledPoint -import org.apache.spark.sql.{QueryTest, Row, SparkSession} -import org.apache.spark.sql.catalyst.{FunctionIdentifier, JavaTypeInference} -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.catalyst.JavaTypeInference import org.apache.spark.sql.types._ -class VectorUDTSuite extends QueryTest { - - override def afterAll(): Unit = { - this.afterAll() - spark.stop() - } +class VectorUDTSuite extends SparkFunSuite { test("preloaded VectorUDT") { val dv1 = Vectors.dense(Array.empty[Double]) @@ -50,39 +44,4 @@ class VectorUDTSuite extends QueryTest { assert(dataType.asInstanceOf[StructType].fields.map(_.dataType) === Seq(new VectorUDT, DoubleType)) } - - test("SPARK-28158 Hive UDFs supports UDT type") { - val functionName = "Logistic_Regression" - try { - val df = spark.read.format("libsvm").options(Map("vectorType" -> "dense")) - .load(TestHive.getHiveFile("test-data/libsvm/sample_libsvm_data.txt").getPath) - df.createOrReplaceTempView("src") - - // `Logistic_Regression` accepts features (with Vector type), and returns the - // prediction value. To simplify the UDF implementation, the `Logistic_Regression` - // will return 0.95d directly. - spark.sql( - s""" - |CREATE FUNCTION Logistic_Regression - |AS 'org.apache.spark.sql.hive.LogisticRegressionUDF' - |USING JAR '${TestHive.getHiveFile("TestLogRegUDF.jar").toURI}' - """.stripMargin) - - checkAnswer( - spark.sql("SELECT Logistic_Regression(features) FROM src"), - Row(0.95) :: Nil) - } catch { - case cause: Throwable => throw cause - } finally { - // If the test failed part way, we don't want to mask the failure by failing to remove - // temp tables that never got created. - spark.sql(s"DROP FUNCTION IF EXISTS $functionName") - assert( - !spark.sessionState.catalog.functionExists(FunctionIdentifier(functionName)), - s"Function $functionName should have been dropped. But, it still exists.") - spark.stop() - } - } - - override protected val spark: SparkSession = TestHive.sparkSession } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 4184d6c8580a..5912992694e8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -215,11 +215,8 @@ class HiveInspectorSuite extends SparkFunSuite with HiveInspectors { } test("wrap / unwrap UDT Type") { - case object TestUserClassUDT extends TestUserClassUDT - - val dt = TestUserClassUDT - val d = 1 - checkValue(d, unwrap(wrap(d, toInspector(dt), dt), toInspector(dt))) + val dt = new TestUserClassUDT + checkValue(1, unwrap(wrap(1, toInspector(dt), dt), toInspector(dt))) checkValue(null, unwrap(wrap(null, toInspector(dt), dt), toInspector(dt))) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUserDefinedTypeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUserDefinedTypeSuite.scala new file mode 100644 index 000000000000..422363663657 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUserDefinedTypeSuite.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.lang + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, StandardListObjectInspector} +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory + +import org.apache.spark.sql.{QueryTest, RandomDataGenerator, Row, SparkSession} +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT} +import org.apache.spark.sql.types._ + +class HiveUserDefinedTypeSuite extends QueryTest { + private[this] val functionClass = classOf[org.apache.spark.sql.hive.TestUDF].getCanonicalName + + override def afterAll(): Unit = { + spark.stop() + } + + test("Support UDT in Hive UDF") { + val functionName = "get_point_x" + val sql = spark.sql _ + try { + val schema = new StructType().add("point", new ExamplePointUDT) + val inputGenerator = RandomDataGenerator.forType(schema, nullable = false).get + val input = inputGenerator.apply().asInstanceOf[Row] + val df = spark.createDataFrame(Array(input).toList.asJava, schema) + df.createOrReplaceTempView("src") + sql(s"CREATE FUNCTION $functionName AS '$functionClass'") + + checkAnswer( + sql(s"SELECT $functionName(point) FROM src"), + Row(input.getAs[ExamplePoint](0).x)) + } catch { + case cause: Throwable => throw cause + } finally { + // If the test failed part way, we don't want to mask the failure by failing to remove + // temp tables that never got created. + spark.sql(s"DROP FUNCTION IF EXISTS $functionName") + assert( + !spark.sessionState.catalog.functionExists(FunctionIdentifier(functionName)), + s"Function $functionName should have been dropped. But, it still exists.") + } + } + + override protected def spark: SparkSession = TestHive.sparkSession +} + +class TestUDF extends GenericUDF { + var data: StandardListObjectInspector = _ + + override def getDisplayString(children: Array[String]): String = "get_point_x" + + override def initialize(arguments: Array[ObjectInspector]): ObjectInspector = { + data = arguments(0).asInstanceOf[StandardListObjectInspector] + PrimitiveObjectInspectorFactory.javaDoubleObjectInspector + } + + override def evaluate(arguments: Array[GenericUDF.DeferredObject]): AnyRef = { + val point = data.getList(arguments(0).get()) + new lang.Double(point.get(0).asInstanceOf[Double]) + } +} From 4a4e75cd7f7ce1ae36c1e458f7e72f5079fd2557 Mon Sep 17 00:00:00 2001 From: uncleGen Date: Wed, 25 Sep 2019 20:09:26 +0800 Subject: [PATCH 5/6] fix ut failure --- .../org/apache/spark/sql/hive/HiveUserDefinedTypeSuite.scala | 4 ---- 1 file changed, 4 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUserDefinedTypeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUserDefinedTypeSuite.scala index 422363663657..1ce93bc30886 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUserDefinedTypeSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUserDefinedTypeSuite.scala @@ -34,10 +34,6 @@ import org.apache.spark.sql.types._ class HiveUserDefinedTypeSuite extends QueryTest { private[this] val functionClass = classOf[org.apache.spark.sql.hive.TestUDF].getCanonicalName - override def afterAll(): Unit = { - spark.stop() - } - test("Support UDT in Hive UDF") { val functionName = "get_point_x" val sql = spark.sql _ From ab0282eb4e29afed5009a043774809f612d475b0 Mon Sep 17 00:00:00 2001 From: uncleGen Date: Thu, 24 Oct 2019 19:21:40 +0800 Subject: [PATCH 6/6] fix comments --- .../sql/hive/HiveUserDefinedTypeSuite.scala | 25 +++++++------------ 1 file changed, 9 insertions(+), 16 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUserDefinedTypeSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUserDefinedTypeSuite.scala index 1ce93bc30886..bddb7688fe96 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUserDefinedTypeSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUserDefinedTypeSuite.scala @@ -17,39 +17,34 @@ package org.apache.spark.sql.hive -import java.lang - import scala.collection.JavaConverters._ import org.apache.hadoop.hive.ql.udf.generic.GenericUDF import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, StandardListObjectInspector} import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory -import org.apache.spark.sql.{QueryTest, RandomDataGenerator, Row, SparkSession} +import org.apache.spark.sql.{QueryTest, RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.FunctionIdentifier -import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT} -import org.apache.spark.sql.types._ +import org.apache.spark.sql.types.StructType -class HiveUserDefinedTypeSuite extends QueryTest { - private[this] val functionClass = classOf[org.apache.spark.sql.hive.TestUDF].getCanonicalName +class HiveUserDefinedTypeSuite extends QueryTest with TestHiveSingleton { + private val functionClass = classOf[org.apache.spark.sql.hive.TestUDF].getCanonicalName test("Support UDT in Hive UDF") { val functionName = "get_point_x" - val sql = spark.sql _ try { val schema = new StructType().add("point", new ExamplePointUDT) val inputGenerator = RandomDataGenerator.forType(schema, nullable = false).get val input = inputGenerator.apply().asInstanceOf[Row] val df = spark.createDataFrame(Array(input).toList.asJava, schema) df.createOrReplaceTempView("src") - sql(s"CREATE FUNCTION $functionName AS '$functionClass'") + spark.sql(s"CREATE FUNCTION $functionName AS '$functionClass'") checkAnswer( - sql(s"SELECT $functionName(point) FROM src"), + spark.sql(s"SELECT $functionName(point) FROM src"), Row(input.getAs[ExamplePoint](0).x)) - } catch { - case cause: Throwable => throw cause } finally { // If the test failed part way, we don't want to mask the failure by failing to remove // temp tables that never got created. @@ -59,12 +54,10 @@ class HiveUserDefinedTypeSuite extends QueryTest { s"Function $functionName should have been dropped. But, it still exists.") } } - - override protected def spark: SparkSession = TestHive.sparkSession } class TestUDF extends GenericUDF { - var data: StandardListObjectInspector = _ + private var data: StandardListObjectInspector = _ override def getDisplayString(children: Array[String]): String = "get_point_x" @@ -75,6 +68,6 @@ class TestUDF extends GenericUDF { override def evaluate(arguments: Array[GenericUDF.DeferredObject]): AnyRef = { val point = data.getList(arguments(0).get()) - new lang.Double(point.get(0).asInstanceOf[Double]) + new java.lang.Double(point.get(0).asInstanceOf[Double]) } }