From a329ebf7443ef66e76d21c7eeb5e052f95aa21ac Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 26 Jun 2017 11:58:43 +0200 Subject: [PATCH] Fixed a signature of XGBoostModel.predict Prior to this commit XGBoostModel.predict produced an RDD with an array of predictions for each partition, effectively changing the shape wrt the input RDD. A more natural contract for prediction API is that given an RDD it returns a new RDD with the same number of elements. This allows the users to easily match inputs with predictions. This commit removes one layer of nesting in XGBoostModel.predict output. Even though the change is clearly non-backward compatible, I still think it is well justified. See discussion in 06bd5dca for motivation. --- .../xgboost4j/scala/spark/XGBoostModel.scala | 17 ++++++++--------- .../scala/spark/XGBoostGeneralSuite.scala | 15 ++++++--------- 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala index a2ea4444378f..95f7fc9ea8ee 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostModel.scala @@ -59,19 +59,18 @@ abstract class XGBoostModel(protected var _booster: Booster) * * @param testSet test set represented as RDD */ - def predictLeaves(testSet: RDD[MLVector]): RDD[Array[Array[Float]]] = { + def predictLeaves(testSet: RDD[MLVector]): RDD[Array[Float]] = { import DataUtils._ val broadcastBooster = testSet.sparkContext.broadcast(_booster) testSet.mapPartitions { testSamples => val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString) Rabit.init(rabitEnv.asJava) - if (testSamples.hasNext) { + if (testSamples.nonEmpty) { val dMatrix = new DMatrix(new JDMatrix(testSamples, null)) try { - val res = broadcastBooster.value.predictLeaf(dMatrix) - Rabit.shutdown() - Iterator(res) + broadcastBooster.value.predictLeaf(dMatrix).iterator } finally { + Rabit.shutdown() dMatrix.delete() } } else { @@ -151,7 +150,7 @@ abstract class XGBoostModel(protected var _booster: Booster) * @param testSet test set represented as RDD * @param missingValue the specified value to represent the missing value */ - def predict(testSet: RDD[MLDenseVector], missingValue: Float): RDD[Array[Array[Float]]] = { + def predict(testSet: RDD[MLDenseVector], missingValue: Float): RDD[Array[Float]] = { val broadcastBooster = testSet.sparkContext.broadcast(_booster) testSet.mapPartitions { testSamples => val sampleArray = testSamples.toList @@ -169,7 +168,7 @@ abstract class XGBoostModel(protected var _booster: Booster) } val dMatrix = new DMatrix(flatSampleArray, numRows, numColumns, missingValue) try { - Iterator(broadcastBooster.value.predict(dMatrix)) + broadcastBooster.value.predict(dMatrix).iterator } finally { Rabit.shutdown() dMatrix.delete() @@ -188,7 +187,7 @@ abstract class XGBoostModel(protected var _booster: Booster) def predict( testSet: RDD[MLVector], useExternalCache: Boolean = false, - outputMargin: Boolean = false): RDD[Array[Array[Float]]] = { + outputMargin: Boolean = false): RDD[Array[Float]] = { val broadcastBooster = testSet.sparkContext.broadcast(_booster) val appName = testSet.context.appName testSet.mapPartitions { testSamples => @@ -205,7 +204,7 @@ abstract class XGBoostModel(protected var _booster: Booster) } val dMatrix = new DMatrix(new JDMatrix(testSamples, cacheFileName)) try { - Iterator(broadcastBooster.value.predict(dMatrix)) + broadcastBooster.value.predict(dMatrix).iterator } finally { Rabit.shutdown() dMatrix.delete() diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala index d4007401bf1d..83ee6da9af21 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/XGBoostGeneralSuite.scala @@ -252,7 +252,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils { "objective" -> "binary:logistic") val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers) val predRDD = xgBoostModel.predict(testRDD) - val predResult1 = predRDD.collect()(0) + val predResult1 = predRDD.collect() assert(testRDD.count() === predResult1.length) import DataUtils._ val predResult2 = xgBoostModel.booster.predict(new DMatrix(testSet.iterator)) @@ -273,14 +273,11 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils { test("test prediction functionality with empty partition") { def buildEmptyRDD(sparkContext: Option[SparkContext] = None): RDD[SparkVector] = { - val sampleList = new ListBuffer[SparkVector] - sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers) + sparkContext.getOrElse(sc).parallelize(List[SparkVector](), numWorkers) } val trainingRDD = buildTrainingRDD(sc) val testRDD = buildEmptyRDD() - val tempDir = Files.createTempDirectory("xgboosttest-") - val tempFile = Files.createTempFile(tempDir, "", "") val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1", "objective" -> "binary:logistic").toMap val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers) @@ -358,7 +355,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils { val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, nWorkers = 1) val predRDD = xgBoostModel.predict(testRDD) - val predResult1: Array[Array[Float]] = predRDD.collect()(0) + val predResult1: Array[Array[Float]] = predRDD.collect() assert(testRDD.count() === predResult1.length) val avgMetric = xgBoostModel.eval(trainingRDD, "test", iter = 0, groupData = trainGroupData) @@ -386,7 +383,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils { val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, nWorkers = 2) val predRDD = xgBoostModel.predict(testRDD) - val predResult1: Array[Array[Float]] = predRDD.collect()(0) + val predResult1: Array[Array[Float]] = predRDD.collect() assert(testRDD.count() === predResult1.length) } @@ -403,7 +400,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils { val trainMargin = { XGBoost.trainWithRDD(trainRDD, paramMap, round = 1, nWorkers = 2) .predict(trainRDD.map(_.features), outputMargin = true) - .flatMap { _.flatten.iterator } + .map { case Array(m) => m } } val xgBoostModel = XGBoost.trainWithRDD( @@ -413,6 +410,6 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils { nWorkers = 2, baseMargin = trainMargin) - assert(testRDD.count() === xgBoostModel.predict(testRDD).first().length) + assert(testRDD.count() === xgBoostModel.predict(testRDD).count()) } }