diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index 7bb245035c83..d1243147980c 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -93,12 +93,14 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s private val overridedParams = overrideParams(rawParams, sc) + validateSparkSslConf() + /** * Check to see if Spark expects SSL encryption (`spark.ssl.enabled` set to true). * If so, throw an exception unless this safety measure has been explicitly overridden * via conf `xgboost.spark.ignoreSsl`. */ - private def validateSparkSslConf: Unit = { + private def validateSparkSslConf(): Unit = { val (sparkSslEnabled: Boolean, xgboostSparkIgnoreSsl: Boolean) = SparkSession.getActiveSession match { case Some(ss) => @@ -148,55 +150,59 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s overridedParams } + /** + * The Map parameters accepted by the estimator's constructor may have string type, + * e.g. Map("num_workers" -> "6", "num_round" -> 5), so we need to convert these + * kinds of parameters into the correct type in this function. 
+ * + * @return XGBoostExecutionParams + */ def buildXGBRuntimeParams: XGBoostExecutionParams = { + + val obj = overridedParams.getOrElse("custom_obj", null).asInstanceOf[ObjectiveTrait] + val eval = overridedParams.getOrElse("custom_eval", null).asInstanceOf[EvalTrait] + if (obj != null) { + require(overridedParams.get("objective_type").isDefined, "parameter \"objective_type\" " + + "is not defined, you have to specify the objective type as classification or regression" + + " with a customized objective function") + } + + var trainTestRatio = 1.0 + if (overridedParams.contains("train_test_ratio")) { + logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" + + " pass a training and multiple evaluation datasets by passing 'eval_sets' and " + + "'eval_set_names'") + trainTestRatio = overridedParams.get("train_test_ratio").get.asInstanceOf[Double] + } + val nWorkers = overridedParams("num_workers").asInstanceOf[Int] val round = overridedParams("num_round").asInstanceOf[Int] val useExternalMemory = overridedParams .getOrElse("use_external_memory", false).asInstanceOf[Boolean] - val obj = overridedParams.getOrElse("custom_obj", null).asInstanceOf[ObjectiveTrait] - val eval = overridedParams.getOrElse("custom_eval", null).asInstanceOf[EvalTrait] + val missing = overridedParams.getOrElse("missing", Float.NaN).asInstanceOf[Float] val allowNonZeroForMissing = overridedParams .getOrElse("allow_non_zero_for_missing", false) .asInstanceOf[Boolean] - validateSparkSslConf - var treeMethod: Option[String] = None - if (overridedParams.contains("tree_method")) { - require(overridedParams("tree_method") == "hist" || - overridedParams("tree_method") == "approx" || - overridedParams("tree_method") == "auto" || - overridedParams("tree_method") == "gpu_hist", "xgboost4j-spark only supports tree_method" + - " as 'hist', 'approx', 'gpu_hist', and 'auto'") - treeMethod = Some(overridedParams("tree_method").asInstanceOf[String]) - } + val treeMethod: 
Option[String] = overridedParams.get("tree_method").map(_.toString) // back-compatible with "gpu_hist" val device: Option[String] = if (treeMethod.exists(_ == "gpu_hist")) { Some("cuda") } else overridedParams.get("device").map(_.toString) - if (overridedParams.contains("train_test_ratio")) { - logger.warn("train_test_ratio is deprecated since XGBoost 0.82, we recommend to explicitly" + - " pass a training and multiple evaluation datasets by passing 'eval_sets' and " + - "'eval_set_names'") - } - require(nWorkers > 0, "you must specify more than 0 workers") - if (obj != null) { - require(overridedParams.get("objective_type").isDefined, "parameter \"objective_type\" " + - "is not defined, you have to specify the objective type as classification or regression" + - " with a customized objective function") - } + require(!(treeMethod.exists(_ == "approx") && device.exists(_ == "cuda")), + "The tree method \"approx\" is not yet supported for Spark GPU cluster") + val trackerConf = overridedParams.get("tracker_conf") match { case None => TrackerConf() case Some(conf: TrackerConf) => conf case _ => throw new IllegalArgumentException("parameter \"tracker_conf\" must be an " + "instance of TrackerConf.") } - val checkpointParam = - ExternalCheckpointParams.extractParams(overridedParams) - val trainTestRatio = overridedParams.getOrElse("train_test_ratio", 1.0) - .asInstanceOf[Double] + val checkpointParam = ExternalCheckpointParams.extractParams(overridedParams) + val seed = overridedParams.getOrElse("seed", System.nanoTime()).asInstanceOf[Long] val inputParams = XGBoostExecutionInputParams(trainTestRatio, seed) diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala index bcbd7548f644..b73e6cbaa844 100644 --- 
a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala @@ -68,11 +68,13 @@ private[spark] trait LearningTaskParams extends Params { /** * Fraction of training points to use for testing. */ + @Deprecated final val trainTestRatio = new DoubleParam(this, "trainTestRatio", "fraction of training points to use for testing", ParamValidators.inRange(0, 1)) setDefault(trainTestRatio, 1.0) + @Deprecated final def getTrainTestRatio: Double = $(trainTestRatio) /** diff --git a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala index 11b60e74d4ea..f187f7394ffa 100644 --- a/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala +++ b/jvm-packages/xgboost4j-spark/src/test/scala/ml/dmlc/xgboost4j/scala/spark/ParameterSuite.scala @@ -92,4 +92,15 @@ class ParameterSuite extends AnyFunSuite with PerTest with BeforeAndAfterAll { classifier.getBaseScore } } + + test("approx can't be used for gpu train") { + val paramMap = Map("tree_method" -> "approx", "device" -> "cuda") + val trainingDF = buildDataFrame(MultiClassification.train) + val xgb = new XGBoostClassifier(paramMap) + val thrown = intercept[IllegalArgumentException] { + xgb.fit(trainingDF) + } + assert(thrown.getMessage.contains("The tree method \"approx\" is not yet supported " + + "for Spark GPU cluster")) + } }