From 69b0377ed29689a68345ca470ab4dd2b4b7b0114 Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Mon, 14 Aug 2023 10:42:55 +0800 Subject: [PATCH 1/3] [jvm-packages] throw exception when tree_method=approx and device=cuda --- .../dmlc/xgboost4j/scala/spark/XGBoost.scala | 62 ++++++++++--------- .../spark/params/LearningTaskParams.scala | 2 + .../scala/spark/ParameterSuite.scala | 10 +++ 3 files changed, 46 insertions(+), 28 deletions(-) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 7bb245035c83..b99344ce3e15 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -93,12 +93,14 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s private val overridedParams = overrideParams(rawParams, sc) + validateSparkSslConf() + /** * Check to see if Spark expects SSL encryption (`spark.ssl.enabled` set to true). * If so, throw an exception unless this safety measure has been explicitly overridden * via conf `xgboost.spark.ignoreSsl`. */ - private def validateSparkSslConf: Unit = { + private def validateSparkSslConf(): Unit = { val (sparkSslEnabled: Boolean, xgboostSparkIgnoreSsl: Boolean) = SparkSession.getActiveSession match { case Some(ss) => @@ -148,55 +150,59 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s overridedParams } + /** + * The Map parameters accepted by estimator's constructor may have string type, + * Eg, Map("num_workers" -> "6", "num_round" -> 5), we need to convert these + * kind of parameters into the correct type in the function. + * + * @return XGBoostExecutionParams + */ def buildXGBRuntimeParams: XGBoostExecutionParams = { + + val obj = overridedParams.getOrElse("custom_obj", null).asInstanceOf[ObjectiveTrait] + val eval = overridedParams.getOrElse("custom_eval", null).asInstanceOf[EvalTrait] + if (obj != null) { + require(overridedParams.get("objective_type").isDefined, "parameter \"objective_type\" " + + "is not defined, you have to specify the objective type as classification or regression" + + " with a customized objective function") + } + + var trainTestRatio = 1.0 + if (overridedParams.contains("train_test_ratio")) { + logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" + + " pass a training and multiple evaluation datasets by passing 'eval_sets' and " + + "'eval_set_names'") + trainTestRatio = overridedParams.get("train_test_ratio").get.asInstanceOf[Double] + } + val nWorkers = overridedParams("num_workers").asInstanceOf[Int] val round = overridedParams("num_round").asInstanceOf[Int] val useExternalMemory = overridedParams .getOrElse("use_external_memory", false).asInstanceOf[Boolean] - val obj = overridedParams.getOrElse("custom_obj", null).asInstanceOf[ObjectiveTrait] - val eval = overridedParams.getOrElse("custom_eval", null).asInstanceOf[EvalTrait] + val missing = overridedParams.getOrElse("missing", Float.NaN).asInstanceOf[Float] val allowNonZeroForMissing = overridedParams .getOrElse("allow_non_zero_for_missing", false) .asInstanceOf[Boolean] - validateSparkSslConf - var treeMethod: Option[String] = None - if (overridedParams.contains("tree_method")) { - require(overridedParams("tree_method") == "hist" || - overridedParams("tree_method") == "approx" || - overridedParams("tree_method") == "auto" || - overridedParams("tree_method") == "gpu_hist", "xgboost4j-spark only supports tree_method" + - " as 'hist', 'approx', 'gpu_hist', and 'auto'") - treeMethod = Some(overridedParams("tree_method").asInstanceOf[String]) - } + val treeMethod: Option[String] = overridedParams.get("tree_method").map(_.toString) // back-compatible with "gpu_hist" val device: Option[String] = if (treeMethod.exists(_ == "gpu_hist")) { Some("cuda") } else overridedParams.get("device").map(_.toString) - if (overridedParams.contains("train_test_ratio")) { - logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" + - " pass a training and multiple evaluation datasets by passing 'eval_sets' and " + - "'eval_set_names'") - } - require(nWorkers > 0, "you must specify more than 0 workers") - if (obj != null) { - require(overridedParams.get("objective_type").isDefined, "parameter \"objective_type\" " + - "is not defined, you have to specify the objective type as classification or regression" + - " with a customized objective function") - } + require(!(treeMethod.exists(_ == "approx") && device.exists(_ == "cuda")), + "Tree method \"approx\" can't be used for GPU train") + val trackerConf = overridedParams.get("tracker_conf") match { case None => TrackerConf() case Some(conf: TrackerConf) => conf case _ => throw new IllegalArgumentException("parameter \"tracker_conf\" must be an " + "instance of TrackerConf.") } - val checkpointParam = - ExternalCheckpointParams.extractParams(overridedParams) - val trainTestRatio = overridedParams.getOrElse("train_test_ratio", 1.0) - .asInstanceOf[Double] + val checkpointParam = ExternalCheckpointParams.extractParams(overridedParams) + val seed = overridedParams.getOrElse("seed", System.nanoTime()).asInstanceOf[Long] val inputParams = XGBoostExecutionInputParams(trainTestRatio, seed) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala index bcbd7548f644..b73e6cbaa844 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala @@ -68,11 +68,13 @@ private[spark] trait LearningTaskParams extends Params { /** * Fraction of training points to use for testing. */ + @Deprecated final val trainTestRatio = new DoubleParam(this, "trainTestRatio", "fraction of training points to use for testing", ParamValidators.inRange(0, 1)) setDefault(trainTestRatio, 1.0) + @Deprecated final def getTrainTestRatio: Double = $(trainTestRatio) /** diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala index 11b60e74d4ea..2b8d63024367 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala @@ -92,4 +92,14 @@ class ParameterSuite extends AnyFunSuite with PerTest with BeforeAndAfterAll { classifier.getBaseScore } } + + test("approx can't be used for gpu train") { + val paramMap = Map("tree_method" -> "approx", "device" -> "cuda") + val trainingDF = buildDataFrame(MultiClassification.train) + val xgb = new XGBoostClassifier(paramMap) + val thrown = intercept[IllegalArgumentException] { + xgb.fit(trainingDF) + } + assert(thrown.getMessage.contains("Tree method \"approx\" can't be used for GPU train")) + } } From 3a246fdfadd2d008ec56901aa8cad1aff903271f Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Mon, 14 Aug 2023 16:02:21 +0800 Subject: [PATCH 2/3] Update jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala Co-authored-by: Jiaming Yuan --- .../src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index b99344ce3e15..d1243147980c 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -192,7 +192,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s } else overridedParams.get("device").map(_.toString) require(!(treeMethod.exists(_ == "approx") && device.exists(_ == "cuda")), - "Tree method \"approx\" can't be used for GPU train") + "The tree method \"approx\" is not yet supported for Spark GPU cluster") val trackerConf = overridedParams.get("tracker_conf") match { case None => TrackerConf() From 17772523b246f19f5ed448aada8cd007574534d8 Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Mon, 14 Aug 2023 16:08:41 +0800 Subject: [PATCH 3/3] fixes --- .../scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala index 2b8d63024367..f187f7394ffa 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala @@ -100,6 +100,7 @@ class ParameterSuite extends AnyFunSuite with PerTest with BeforeAndAfterAll { val thrown = intercept[IllegalArgumentException] { xgb.fit(trainingDF) } - assert(thrown.getMessage.contains("Tree method \"approx\" can't be used for GPU train")) + assert(thrown.getMessage.contains("The tree method \"approx\" is not yet supported " + + "for Spark GPU cluster")) } }