[jvm-packages] throw exception when tree_method=approx and device=cuda #9478

Merged 3 commits on Aug 14, 2023
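For context, the user-facing effect of this change: requesting the CPU-only `approx` tree method together with GPU training now fails fast with an `IllegalArgumentException` instead of proceeding. A minimal sketch of the rejected configuration (the parameter map and the commented-out `trainingDF` are illustrative, not taken from the PR):

```scala
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier

// Hypothetical parameter map; "approx" is CPU-only, so combining it with
// device=cuda is exactly the case this PR rejects up front.
val params = Map(
  "tree_method" -> "approx",
  "device"      -> "cuda",
  "num_workers" -> 1,
  "num_round"   -> 10
)
val classifier = new XGBoostClassifier(params)

// classifier.fit(trainingDF) now throws:
//   IllegalArgumentException: The tree method "approx" is not yet supported
//   for Spark GPU cluster
```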
@@ -93,12 +93,14 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s

private val overridedParams = overrideParams(rawParams, sc)

validateSparkSslConf()

/**
* Check to see if Spark expects SSL encryption (`spark.ssl.enabled` set to true).
* If so, throw an exception unless this safety measure has been explicitly overridden
* via conf `xgboost.spark.ignoreSsl`.
*/
private def validateSparkSslConf: Unit = {
private def validateSparkSslConf(): Unit = {
val (sparkSslEnabled: Boolean, xgboostSparkIgnoreSsl: Boolean) =
SparkSession.getActiveSession match {
case Some(ss) =>
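Besides adding the empty parentheses to mark the side-effecting method, the hunk above also moves the `validateSparkSslConf()` call into the factory's initialization. A standalone sketch of the check that the doc comment describes, assuming the same `spark.ssl.enabled` and `xgboost.spark.ignoreSsl` keys (illustrative, not a copy of the real implementation):

```scala
import org.apache.spark.SparkContext
import org.apache.spark.sql.SparkSession

object SslCheckSketch {
  // Refuse to run when Spark expects SSL encryption, unless the user has
  // explicitly acknowledged unencrypted XGBoost traffic via xgboost.spark.ignoreSsl.
  def validate(sc: SparkContext): Unit = {
    val (sslEnabled, ignoreSsl) = SparkSession.getActiveSession match {
      case Some(ss) =>
        (ss.conf.getOption("spark.ssl.enabled").exists(_.toBoolean),
          ss.conf.getOption("xgboost.spark.ignoreSsl").exists(_.toBoolean))
      case None =>
        (sc.getConf.getBoolean("spark.ssl.enabled", false),
          sc.getConf.getBoolean("xgboost.spark.ignoreSsl", false))
    }
    if (sslEnabled && !ignoreSsl) {
      throw new IllegalArgumentException(
        "spark.ssl.enabled is true; set xgboost.spark.ignoreSsl=true to allow " +
          "unencrypted communication between XGBoost workers")
    }
  }
}
```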
@@ -148,55 +150,59 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s
overridedParams
}

/**
* The Map parameters accepted by estimator's constructor may have string type,
* Eg, Map("num_workers" -> "6", "num_round" -> 5), we need to convert these
* kind of parameters into the correct type in the function.
*
* @return XGBoostExecutionParams
*/
def buildXGBRuntimeParams: XGBoostExecutionParams = {

val obj = overridedParams.getOrElse("custom_obj", null).asInstanceOf[ObjectiveTrait]
val eval = overridedParams.getOrElse("custom_eval", null).asInstanceOf[EvalTrait]
if (obj != null) {
require(overridedParams.get("objective_type").isDefined, "parameter \"objective_type\" " +
"is not defined, you have to specify the objective type as classification or regression" +
" with a customized objective function")
}

var trainTestRatio = 1.0
if (overridedParams.contains("train_test_ratio")) {
logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
" pass a training and multiple evaluation datasets by passing 'eval_sets' and " +
"'eval_set_names'")
trainTestRatio = overridedParams.get("train_test_ratio").get.asInstanceOf[Double]
}

val nWorkers = overridedParams("num_workers").asInstanceOf[Int]
val round = overridedParams("num_round").asInstanceOf[Int]
val useExternalMemory = overridedParams
.getOrElse("use_external_memory", false).asInstanceOf[Boolean]
val obj = overridedParams.getOrElse("custom_obj", null).asInstanceOf[ObjectiveTrait]
val eval = overridedParams.getOrElse("custom_eval", null).asInstanceOf[EvalTrait]

val missing = overridedParams.getOrElse("missing", Float.NaN).asInstanceOf[Float]
val allowNonZeroForMissing = overridedParams
.getOrElse("allow_non_zero_for_missing", false)
.asInstanceOf[Boolean]
validateSparkSslConf
var treeMethod: Option[String] = None
if (overridedParams.contains("tree_method")) {
require(overridedParams("tree_method") == "hist" ||
overridedParams("tree_method") == "approx" ||
overridedParams("tree_method") == "auto" ||
overridedParams("tree_method") == "gpu_hist", "xgboost4j-spark only supports tree_method" +
" as 'hist', 'approx', 'gpu_hist', and 'auto'")
treeMethod = Some(overridedParams("tree_method").asInstanceOf[String])
}

val treeMethod: Option[String] = overridedParams.get("tree_method").map(_.toString)
// back-compatible with "gpu_hist"
val device: Option[String] = if (treeMethod.exists(_ == "gpu_hist")) {
Some("cuda")
} else overridedParams.get("device").map(_.toString)

if (overridedParams.contains("train_test_ratio")) {
logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" +
" pass a training and multiple evaluation datasets by passing 'eval_sets' and " +
"'eval_set_names'")
}
require(nWorkers > 0, "you must specify more than 0 workers")
if (obj != null) {
require(overridedParams.get("objective_type").isDefined, "parameter \"objective_type\" " +
"is not defined, you have to specify the objective type as classification or regression" +
" with a customized objective function")
}
require(!(treeMethod.exists(_ == "approx") && device.exists(_ == "cuda")),
"The tree method \"approx\" is not yet supported for Spark GPU cluster")

val trackerConf = overridedParams.get("tracker_conf") match {
case None => TrackerConf()
case Some(conf: TrackerConf) => conf
case _ => throw new IllegalArgumentException("parameter \"tracker_conf\" must be an " +
"instance of TrackerConf.")
}
val checkpointParam =
ExternalCheckpointParams.extractParams(overridedParams)

val trainTestRatio = overridedParams.getOrElse("train_test_ratio", 1.0)
.asInstanceOf[Double]
val checkpointParam = ExternalCheckpointParams.extractParams(overridedParams)

val seed = overridedParams.getOrElse("seed", System.nanoTime()).asInstanceOf[Long]
val inputParams = XGBoostExecutionInputParams(trainTestRatio, seed)

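The core of the new `buildXGBRuntimeParams` logic above: `tree_method` and `device` are read as `Option[String]`, the legacy `gpu_hist` value is mapped to `device=cuda` for backward compatibility, and `approx` combined with CUDA is rejected. A condensed sketch of just that resolution step (the helper object and method are made up for readability; the real code is inlined in the factory):

```scala
object DeviceResolutionSketch {
  // Mirrors the validation added in this PR, isolated from the rest of the factory.
  def resolve(params: Map[String, Any]): (Option[String], Option[String]) = {
    val treeMethod: Option[String] = params.get("tree_method").map(_.toString)

    // Back-compatibility: the legacy "gpu_hist" tree method implies device=cuda.
    val device: Option[String] =
      if (treeMethod.contains("gpu_hist")) Some("cuda")
      else params.get("device").map(_.toString)

    require(!(treeMethod.contains("approx") && device.contains("cuda")),
      "The tree method \"approx\" is not yet supported for Spark GPU cluster")

    (treeMethod, device)
  }
}

// DeviceResolutionSketch.resolve(Map("tree_method" -> "gpu_hist"))
//   => (Some("gpu_hist"), Some("cuda"))
// DeviceResolutionSketch.resolve(Map("tree_method" -> "approx", "device" -> "cuda"))
//   => throws IllegalArgumentException
```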
@@ -68,11 +68,13 @@ private[spark] trait LearningTaskParams extends Params {
/**
* Fraction of training points to use for testing.
*/
@Deprecated
final val trainTestRatio = new DoubleParam(this, "trainTestRatio",
"fraction of training points to use for testing",
ParamValidators.inRange(0, 1))
setDefault(trainTestRatio, 1.0)

@Deprecated
final def getTrainTestRatio: Double = $(trainTestRatio)

/**
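With `trainTestRatio` now marked `@Deprecated`, the warning in the first file points users toward explicit evaluation sets instead. A hedged example of that recommended pattern, assuming xgboost4j-spark's `setEvalSets` setter and user-supplied `trainDF`/`validDF` DataFrames (check the setters available in your version):

```scala
import ml.dmlc.xgboost4j.scala.spark.XGBoostClassifier
import org.apache.spark.sql.DataFrame

// Instead of carving a test fraction out of the training data via trainTestRatio,
// split the data yourself and pass a named evaluation set.
def fitWithExplicitEvalSet(trainDF: DataFrame, validDF: DataFrame) = {
  val classifier = new XGBoostClassifier(Map(
    "objective"   -> "binary:logistic",
    "num_round"   -> 100,
    "num_workers" -> 2
  ))
  classifier.setEvalSets(Map("validation" -> validDF))
  classifier.fit(trainDF)
}
```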
@@ -92,4 +92,15 @@ class ParameterSuite extends AnyFunSuite with PerTest with BeforeAndAfterAll {
classifier.getBaseScore
}
}

test("approx can't be used for gpu train") {
val paramMap = Map("tree_method" -> "approx", "device" -> "cuda")
val trainingDF = buildDataFrame(MultiClassification.train)
val xgb = new XGBoostClassifier(paramMap)
val thrown = intercept[IllegalArgumentException] {
xgb.fit(trainingDF)
}
assert(thrown.getMessage.contains("The tree method \"approx\" is not yet supported " +
"for Spark GPU cluster"))
}
}