Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[jvm-packages] XGBoost Spark integration refactor #3387

Merged
merged 9 commits into from
Jun 18, 2018
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import scala.io.Source

import ml.dmlc.xgboost4j.scala.spark.{XGBoostEstimator, XGBoost}
import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.{VectorAssembler, StringIndexer}
Expand Down Expand Up @@ -160,10 +160,10 @@ object SparkModelTuningTool {
private def crossValidation(
xgboostParam: Map[String, Any],
trainingData: Dataset[_]): TrainValidationSplitModel = {
val xgbEstimator = new XGBoostEstimator(xgboostParam).setFeaturesCol("features").
val xgbEstimator = new XGBoostRegressor(xgboostParam).setFeaturesCol("features").
setLabelCol("logSales")
val paramGrid = new ParamGridBuilder()
.addGrid(xgbEstimator.round, Array(20, 50))
.addGrid(xgbEstimator.numRound, Array(20, 50))
.addGrid(xgbEstimator.eta, Array(0.1, 0.4))
.build()
val tv = new TrainValidationSplit()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
package ml.dmlc.xgboost4j.scala.example.spark

import ml.dmlc.xgboost4j.scala.Booster
import ml.dmlc.xgboost4j.scala.spark.XGBoost
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
import org.apache.spark.sql.SparkSession
import org.apache.spark.SparkConf

Expand Down Expand Up @@ -45,9 +45,10 @@ object SparkWithDataFrame {
val paramMap = List(
"eta" -> 0.1f,
"max_depth" -> 2,
"objective" -> "binary:logistic").toMap
val xgboostModel = XGBoost.trainWithDataFrame(
trainDF, paramMap, numRound, nWorkers = args(1).toInt, useExternalMemory = true)
"objective" -> "binary:logistic",
"num_round" -> numRound,
"nWorkers" -> args(1).toInt).toMap
val xgboostModel = new XGBoostClassifier(paramMap).fit(trainDF)
// xgboost-spark appends the column containing prediction results
xgboostModel.transform(testDF).show()
}
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package ml.dmlc.xgboost4j.scala.spark

import ml.dmlc.xgboost4j.scala.Booster
import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost}
import org.apache.commons.logging.LogFactory
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.SparkContext
Expand Down Expand Up @@ -63,9 +64,9 @@ private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String)
val version = versions.max
val fullPath = getPath(version)
logger.info(s"Start training from previous booster at $fullPath")
val model = XGBoost.loadModelFromHadoopFile(fullPath)(sc)
model.booster.booster.setVersion(version)
model.booster
val booster = SXGBoost.loadModel(fullPath)
booster.booster.setVersion(version)
booster
} else {
null
}
Expand All @@ -76,12 +77,12 @@ private[spark] class CheckpointManager(sc: SparkContext, checkpointPath: String)
*
* @param checkpoint the Booster to save as the newest checkpoint; its version determines the file path
*/
private[spark] def updateCheckpoint(checkpoint: XGBoostModel): Unit = {
private[spark] def updateCheckpoint(checkpoint: Booster): Unit = {
val fs = FileSystem.get(sc.hadoopConfiguration)
val prevModelPaths = getExistingVersions.map(version => new Path(getPath(version)))
val fullPath = getPath(checkpoint.version)
logger.info(s"Saving checkpoint model with version ${checkpoint.version} to $fullPath")
checkpoint.saveModelAsHadoopFile(fullPath)(sc)
val fullPath = getPath(checkpoint.getVersion)
logger.info(s"Saving checkpoint model with version ${checkpoint.getVersion} to $fullPath")
checkpoint.saveModel(fullPath)
prevModelPaths.foreach(path => fs.delete(path, true))
}

Expand Down
Loading