diff --git a/build.sbt b/build.sbt
index 825dd034..cdc7a2bd 100644
--- a/build.sbt
+++ b/build.sbt
@@ -14,7 +14,7 @@ sparkPackageName := "databricks/spark-sql-perf"
 // All Spark Packages need a license
 licenses := Seq("Apache-2.0" -> url("http://opensource.org/licenses/Apache-2.0"))
 
-sparkVersion := "2.2.0"
+sparkVersion := "2.3.0"
 
 sparkComponents ++= Seq("sql", "hive", "mllib")
 
diff --git a/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/OneHotEncoderEstimator.scala b/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/OneHotEncoderEstimator.scala
new file mode 100644
index 00000000..312b5ab7
--- /dev/null
+++ b/src/main/scala/com/databricks/spark/sql/perf/mllib/feature/OneHotEncoderEstimator.scala
@@ -0,0 +1,34 @@
+package com.databricks.spark.sql.perf.mllib.feature
+
+import org.apache.spark.ml
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.PipelineStage
+import org.apache.spark.sql._
+
+import com.databricks.spark.sql.perf.mllib.OptionImplicits._
+import com.databricks.spark.sql.perf.mllib.data.DataGenerator
+import com.databricks.spark.sql.perf.mllib.{BenchmarkAlgorithm, MLBenchContext, TestFromTraining}
+
+/** Object for testing OneHotEncoderEstimator performance */
+object OneHotEncoderEstimator extends BenchmarkAlgorithm with TestFromTraining with UnaryTransformer {
+
+  override def trainingDataSet(ctx: MLBenchContext): DataFrame = {
+    import ctx.params._
+    import ctx.sqlContext.implicits._
+
+    DataGenerator.generateMixedFeatures(
+      ctx.sqlContext,
+      numExamples,
+      ctx.seed(),
+      numPartitions,
+      Array.fill(1)(featureArity.get)
+    ).rdd.map { case Row(vec: Vector) =>
+      vec(0) // extract the single generated double value for each row
+    }.toDF(inputCol)
+  }
+
+  override def getPipelineStage(ctx: MLBenchContext): PipelineStage = {
+    new ml.feature.OneHotEncoderEstimator()
+      .setInputCols(Array(inputCol))
+  }
+}
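
For context (not part of the diff): below is a minimal sketch of how the Spark 2.3 OneHotEncoderEstimator exercised by the new benchmark is typically driven. The SparkSession setup, column names, and sample values are illustrative assumptions, not code from this change.

// Hypothetical standalone example, assuming Spark 2.3.x on the classpath.
import org.apache.spark.ml.feature.OneHotEncoderEstimator
import org.apache.spark.sql.SparkSession

object OneHotEncoderEstimatorExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .appName("OneHotEncoderEstimatorExample")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // A tiny categorical column, analogous to the single generated
    // feature column that trainingDataSet produces in the benchmark.
    val df = Seq(0.0, 1.0, 2.0, 1.0, 0.0).toDF("inputCol")

    val encoder = new OneHotEncoderEstimator()
      .setInputCols(Array("inputCol"))
      .setOutputCols(Array("outputCol"))

    // Unlike the older OneHotEncoder transformer, the estimator is fit
    // first so it can learn the category sizes from the data.
    val model = encoder.fit(df)
    model.transform(df).show(false)

    spark.stop()
  }
}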