Fixed the signature of XGBoostModel.predict (#2476)
Prior to this commit, XGBoostModel.predict produced an RDD with a
single array of predictions per partition, effectively changing the
shape of the output with respect to the input RDD. A more natural
contract for the prediction API is that, given an RDD, it returns a
new RDD with the same number of elements. This allows users to easily
match inputs with predictions.

This commit removes one layer of nesting in the XGBoostModel.predict
output. Even though the change is clearly backward-incompatible, I
still think it is well justified. See the discussion in 06bd5dc for
motivation.
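
To illustrate the new contract, here is a minimal sketch (the helper
name zipWithPredictions is illustrative, not part of this change):
because predict now emits one Array[Float] per input row and is
implemented with mapPartitions, the prediction RDD keeps the
partitioning and per-partition element counts of the input, so the
two RDDs can be zipped directly.

    import org.apache.spark.rdd.RDD

    // Sketch only: pairs each input row with its prediction under the
    // new contract, where `predictions` has the same partitioning and
    // the same number of elements per partition as `inputs`.
    def zipWithPredictions[V](
        inputs: RDD[V],
        predictions: RDD[Array[Float]]): RDD[(V, Array[Float])] = {
      inputs.zip(predictions)
    }

With the old return type, RDD[Array[Array[Float]]], the same pairing
required collecting and flattening one per-partition array first,
which is what the updated tests below no longer have to do.
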
superbobry authored and CodingCat committed Jul 3, 2017
1 parent ed8bc45 commit 8ceeb32
Showing 2 changed files with 14 additions and 18 deletions.
@@ -59,19 +59,18 @@ abstract class XGBoostModel(protected var _booster: Booster)
*
* @param testSet test set represented as RDD
*/
- def predictLeaves(testSet: RDD[MLVector]): RDD[Array[Array[Float]]] = {
+ def predictLeaves(testSet: RDD[MLVector]): RDD[Array[Float]] = {
import DataUtils._
val broadcastBooster = testSet.sparkContext.broadcast(_booster)
testSet.mapPartitions { testSamples =>
val rabitEnv = Map("DMLC_TASK_ID" -> TaskContext.getPartitionId().toString)
Rabit.init(rabitEnv.asJava)
- if (testSamples.hasNext) {
+ if (testSamples.nonEmpty) {
val dMatrix = new DMatrix(new JDMatrix(testSamples, null))
try {
- val res = broadcastBooster.value.predictLeaf(dMatrix)
- Rabit.shutdown()
- Iterator(res)
+ broadcastBooster.value.predictLeaf(dMatrix).iterator
} finally {
+ Rabit.shutdown()
dMatrix.delete()
}
} else {
@@ -151,7 +150,7 @@ abstract class XGBoostModel(protected var _booster: Booster)
* @param testSet test set represented as RDD
* @param missingValue the specified value to represent the missing value
*/
- def predict(testSet: RDD[MLDenseVector], missingValue: Float): RDD[Array[Array[Float]]] = {
+ def predict(testSet: RDD[MLDenseVector], missingValue: Float): RDD[Array[Float]] = {
val broadcastBooster = testSet.sparkContext.broadcast(_booster)
testSet.mapPartitions { testSamples =>
val sampleArray = testSamples.toList
@@ -169,7 +168,7 @@ abstract class XGBoostModel(protected var _booster: Booster)
}
val dMatrix = new DMatrix(flatSampleArray, numRows, numColumns, missingValue)
try {
- Iterator(broadcastBooster.value.predict(dMatrix))
+ broadcastBooster.value.predict(dMatrix).iterator
} finally {
Rabit.shutdown()
dMatrix.delete()
@@ -188,7 +187,7 @@ abstract class XGBoostModel(protected var _booster: Booster)
def predict(
testSet: RDD[MLVector],
useExternalCache: Boolean = false,
- outputMargin: Boolean = false): RDD[Array[Array[Float]]] = {
+ outputMargin: Boolean = false): RDD[Array[Float]] = {
val broadcastBooster = testSet.sparkContext.broadcast(_booster)
val appName = testSet.context.appName
testSet.mapPartitions { testSamples =>
@@ -205,7 +204,7 @@ abstract class XGBoostModel(protected var _booster: Booster)
}
val dMatrix = new DMatrix(new JDMatrix(testSamples, cacheFileName))
try {
- Iterator(broadcastBooster.value.predict(dMatrix))
+ broadcastBooster.value.predict(dMatrix).iterator
} finally {
Rabit.shutdown()
dMatrix.delete()
@@ -252,7 +252,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
"objective" -> "binary:logistic")
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
val predRDD = xgBoostModel.predict(testRDD)
- val predResult1 = predRDD.collect()(0)
+ val predResult1 = predRDD.collect()
assert(testRDD.count() === predResult1.length)
import DataUtils._
val predResult2 = xgBoostModel.booster.predict(new DMatrix(testSet.iterator))
@@ -273,14 +273,11 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {

test("test prediction functionality with empty partition") {
def buildEmptyRDD(sparkContext: Option[SparkContext] = None): RDD[SparkVector] = {
- val sampleList = new ListBuffer[SparkVector]
- sparkContext.getOrElse(sc).parallelize(sampleList, numWorkers)
+ sparkContext.getOrElse(sc).parallelize(List[SparkVector](), numWorkers)
}

val trainingRDD = buildTrainingRDD(sc)
val testRDD = buildEmptyRDD()
- val tempDir = Files.createTempDirectory("xgboosttest-")
- val tempFile = Files.createTempFile(tempDir, "", "")
val paramMap = List("eta" -> "1", "max_depth" -> "2", "silent" -> "1",
"objective" -> "binary:logistic").toMap
val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, numWorkers)
@@ -358,7 +355,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {

val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, nWorkers = 1)
val predRDD = xgBoostModel.predict(testRDD)
- val predResult1: Array[Array[Float]] = predRDD.collect()(0)
+ val predResult1: Array[Array[Float]] = predRDD.collect()
assert(testRDD.count() === predResult1.length)

val avgMetric = xgBoostModel.eval(trainingRDD, "test", iter = 0, groupData = trainGroupData)
@@ -386,7 +383,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {

val xgBoostModel = XGBoost.trainWithRDD(trainingRDD, paramMap, 5, nWorkers = 2)
val predRDD = xgBoostModel.predict(testRDD)
- val predResult1: Array[Array[Float]] = predRDD.collect()(0)
+ val predResult1: Array[Array[Float]] = predRDD.collect()
assert(testRDD.count() === predResult1.length)
}

@@ -403,7 +400,7 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
val trainMargin = {
XGBoost.trainWithRDD(trainRDD, paramMap, round = 1, nWorkers = 2)
.predict(trainRDD.map(_.features), outputMargin = true)
- .flatMap { _.flatten.iterator }
+ .map { case Array(m) => m }
}

val xgBoostModel = XGBoost.trainWithRDD(
@@ -413,6 +410,6 @@ class XGBoostGeneralSuite extends SharedSparkContext with Utils {
nWorkers = 2,
baseMargin = trainMargin)

- assert(testRDD.count() === xgBoostModel.predict(testRDD).first().length)
+ assert(testRDD.count() === xgBoostModel.predict(testRDD).count())
}
}
