From 8f0efb4ab3ab21e527ba1f32c03cac7cd1ac8bd2 Mon Sep 17 00:00:00 2001 From: Bobby Wang Date: Thu, 27 Jul 2023 11:09:55 +0800 Subject: [PATCH] [jvm-packages] automatically set the max/min direction for best score (#9404) --- .../dmlc/xgboost4j/scala/spark/XGBoost.scala | 23 +--- .../spark/params/LearningTaskParams.scala | 4 - .../java/ml/dmlc/xgboost4j/java/XGBoost.java | 84 +++++++++--- .../ml/dmlc/xgboost4j/java/XGBoostTest.java | 121 ++++++++++++++++++ 4 files changed, 192 insertions(+), 40 deletions(-) create mode 100644 jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/XGBoostTest.java diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala index f514eaa68b20..2f1f261fb77e 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/XGBoost.scala @@ -23,7 +23,6 @@ import scala.util.Random import scala.collection.JavaConverters._ import ml.dmlc.xgboost4j.java.{Communicator, IRabitTracker, XGBoostError, RabitTracker => PyRabitTracker} -import ml.dmlc.xgboost4j.scala.spark.params.LearningTaskParams import ml.dmlc.xgboost4j.scala.ExternalCheckpointManager import ml.dmlc.xgboost4j.scala.{XGBoost => SXGBoost, _} import ml.dmlc.xgboost4j.{LabeledPoint => XGBLabeledPoint} @@ -55,9 +54,6 @@ object TrackerConf { def apply(): TrackerConf = TrackerConf(0L) } -private[scala] case class XGBoostExecutionEarlyStoppingParams(numEarlyStoppingRounds: Int, - maximizeEvalMetrics: Boolean) - private[scala] case class XGBoostExecutionInputParams(trainTestRatio: Double, seed: Long) private[scala] case class XGBoostExecutionParams( @@ -71,7 +67,7 @@ private[scala] case class XGBoostExecutionParams( trackerConf: TrackerConf, checkpointParam: Option[ExternalCheckpointParams], xgbInputParams: XGBoostExecutionInputParams, - earlyStoppingParams: XGBoostExecutionEarlyStoppingParams, + earlyStoppingRounds: Int, cacheTrainingSet: Boolean, device: Option[String], isLocal: Boolean, @@ -146,15 +142,8 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s val numEarlyStoppingRounds = overridedParams.getOrElse( "num_early_stopping_rounds", 0).asInstanceOf[Int] overridedParams += "num_early_stopping_rounds" -> numEarlyStoppingRounds - if (numEarlyStoppingRounds > 0 && - !overridedParams.contains("maximize_evaluation_metrics")) { - if (overridedParams.getOrElse("custom_eval", null) != null) { + if (numEarlyStoppingRounds > 0 && overridedParams.getOrElse("custom_eval", null) != null) { throw new IllegalArgumentException("custom_eval does not support early stopping") - } - val eval_metric = overridedParams("eval_metric").toString - val maximize = LearningTaskParams.evalMetricsToMaximize contains eval_metric - logger.info("parameter \"maximize_evaluation_metrics\" is set to " + maximize) - overridedParams += ("maximize_evaluation_metrics" -> maximize) } overridedParams } @@ -213,10 +202,6 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s val earlyStoppingRounds = overridedParams.getOrElse( "num_early_stopping_rounds", 0).asInstanceOf[Int] - val maximizeEvalMetrics = overridedParams.getOrElse( - "maximize_evaluation_metrics", true).asInstanceOf[Boolean] - val xgbExecEarlyStoppingParams = XGBoostExecutionEarlyStoppingParams(earlyStoppingRounds, - maximizeEvalMetrics) val cacheTrainingSet = overridedParams.getOrElse("cache_training_set", false) .asInstanceOf[Boolean] @@ -232,7 +217,7 @@ private[this] class XGBoostExecutionParamsFactory(rawParams: Map[String, Any], s missing, allowNonZeroForMissing, trackerConf, checkpointParam, inputParams, - xgbExecEarlyStoppingParams, + earlyStoppingRounds, cacheTrainingSet, device, isLocal, @@ -319,7 +304,7 @@ object XGBoost extends Serializable { watches = buildWatchesAndCheck(buildWatches) - val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingParams.numEarlyStoppingRounds + val numEarlyStoppingRounds = xgbExecutionParam.earlyStoppingRounds val metrics = Array.tabulate(watches.size)(_ => Array.ofDim[Float](numRounds)) val externalCheckpointParams = xgbExecutionParam.checkpointParam diff --git a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala index 6aec4d36ed6f..bcbd7548f644 100644 --- a/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala +++ b/jvm-packages/xgboost4j-spark/src/main/scala/ml/dmlc/xgboost4j/scala/spark/params/LearningTaskParams.scala @@ -112,8 +112,4 @@ private[spark] object LearningTaskParams { val supportedObjectiveType = HashSet("regression", "classification") - val evalMetricsToMaximize = HashSet("auc", "aucpr", "ndcg", "map") - - val evalMetricsToMinimize = HashSet("rmse", "rmsle", "mae", "mape", "logloss", "error", "merror", - "mlogloss", "gamma-deviance") } diff --git a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java index d765a3cab21e..bcd0b1b11d2f 100644 --- a/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java +++ b/jvm-packages/xgboost4j/src/main/java/ml/dmlc/xgboost4j/java/XGBoost.java @@ -17,6 +17,8 @@ import java.io.*; import java.util.*; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -30,6 +32,11 @@ public class XGBoost { private static final Log logger = LogFactory.getLog(XGBoost.class); + public static final String[] MAXIMIZ_METRICES = { + "auc", "aucpr", "pre", "pre@", "map", "ndcg", + "auc@", "aucpr@", "map@", "ndcg@", + }; + /** * load model from modelPath * @@ -158,7 +165,7 @@ public static Booster trainAndSaveCheckpoint( //collect eval matrixs String[] evalNames; DMatrix[] evalMats; - float bestScore; + float bestScore = 1; int bestIteration; List names = new ArrayList(); List mats = new ArrayList(); @@ -175,11 +182,7 @@ public static Booster trainAndSaveCheckpoint( evalNames = names.toArray(new String[names.size()]); evalMats = mats.toArray(new DMatrix[mats.size()]); - if (isMaximizeEvaluation(params)) { - bestScore = -Float.MAX_VALUE; - } else { - bestScore = Float.MAX_VALUE; - } + bestIteration = 0; metrics = metrics == null ? new float[evalNames.length][numRounds] : metrics; @@ -210,6 +213,9 @@ public static Booster trainAndSaveCheckpoint( checkpointIterations = new HashSet<>(ecm.getCheckpointRounds(checkpointInterval, numRounds)); } + boolean initial_best_score_flag = false; + boolean max_direction = false; + // begin to train for (int iter = booster.getVersion() / 2; iter < numRounds; iter++) { if (booster.getVersion() % 2 == 0) { @@ -231,6 +237,18 @@ public static Booster trainAndSaveCheckpoint( } else { evalInfo = booster.evalSet(evalMats, evalNames, iter, metricsOut); } + + if (!initial_best_score_flag) { + if (isMaximizeEvaluation(evalInfo, evalNames, params)) { + max_direction = true; + bestScore = -Float.MAX_VALUE; + } else { + max_direction = false; + bestScore = Float.MAX_VALUE; + } + initial_best_score_flag = true; + } + for (int i = 0; i < metricsOut.length; i++) { metrics[i][iter] = metricsOut[i]; } @@ -238,7 +256,7 @@ public static Booster trainAndSaveCheckpoint( // If there is more than one evaluation datasets, the last one would be used // to determinate early stop. float score = metricsOut[metricsOut.length - 1]; - if (isMaximizeEvaluation(params)) { + if (max_direction) { // Update best score if the current score is better (no update when equal) if (score > bestScore) { bestScore = score; @@ -264,9 +282,7 @@ public static Booster trainAndSaveCheckpoint( break; } if (Communicator.getRank() == 0 && shouldPrint(params, iter)) { - if (shouldPrint(params, iter)){ - Communicator.communicatorPrint(evalInfo + '\n'); - } + Communicator.communicatorPrint(evalInfo + '\n'); } } booster.saveRabitCheckpoint(); @@ -360,16 +376,50 @@ static boolean shouldEarlyStop(int earlyStoppingRounds, int iter, int bestIterat return iter - bestIteration >= earlyStoppingRounds; } - private static boolean isMaximizeEvaluation(Map params) { - try { + private static String getMetricNameFromlog(String evalInfo, String[] evalNames) { + String regexPattern = Pattern.quote(evalNames[0]) + "-(.*):"; + Pattern pattern = Pattern.compile(regexPattern); + Matcher matcher = pattern.matcher(evalInfo); + + String metricName = null; + if (matcher.find()) { + metricName = matcher.group(1); + logger.debug("Got the metric name: " + metricName); + } + return metricName; + } + + // visiable for testing + public static boolean isMaximizeEvaluation(String evalInfo, + String[] evalNames, + Map params) { + + String metricName; + + if (params.get("maximize_evaluation_metrics") != null) { + // user has forced the direction no matter what is the metric name. String maximize = String.valueOf(params.get("maximize_evaluation_metrics")); - assert(maximize != null); return Boolean.valueOf(maximize); - } catch (Exception ex) { - logger.error("maximize_evaluation_metrics has to be specified for enabling early stop," + - " allowed value: true/false", ex); - throw ex; } + + if (params.get("eval_metric") != null) { + // user has special metric name + metricName = String.valueOf(params.get("eval_metric")); + } else { + // infer the metric name from log + metricName = getMetricNameFromlog(evalInfo, evalNames); + } + + assert metricName != null; + + if (!"mape".equals(metricName)) { + for (String x : MAXIMIZ_METRICES) { + if (metricName.startsWith(x)) { + return true; + } + } + } + return false; } /** diff --git a/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/XGBoostTest.java b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/XGBoostTest.java new file mode 100644 index 000000000000..190405c68d87 --- /dev/null +++ b/jvm-packages/xgboost4j/src/test/java/ml/dmlc/xgboost4j/java/XGBoostTest.java @@ -0,0 +1,121 @@ +/* + Copyright (c) 2023 by Contributors + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + */ + +package ml.dmlc.xgboost4j.java; + +import junit.framework.TestCase; +import ml.dmlc.xgboost4j.LabeledPoint; +import org.junit.Test; + +import java.util.HashMap; +import java.util.Map; +import java.util.Random; + +public class XGBoostTest { + + private String composeEvalInfo(String metric, String evalName) { + return "[0]\t" + evalName + "-" + metric + ":" + "\ttest"; + } + + @Test + public void testIsMaximizeEvaluation() { + String[] minimum_metrics = {"mape", "logloss", "error", "others"}; + String[] evalNames = {"set-abc"}; + + HashMap params = new HashMap<>(); + + // test1, infer the metric from faked log + for (String x : XGBoost.MAXIMIZ_METRICES) { + String evalInfo = composeEvalInfo(x, evalNames[0]); + TestCase.assertTrue(XGBoost.isMaximizeEvaluation(evalInfo, evalNames, params)); + } + + // test2, the direction for mape should be minimum + String evalInfo = composeEvalInfo("mape", evalNames[0]); + TestCase.assertFalse(XGBoost.isMaximizeEvaluation(evalInfo, evalNames, params)); + + // test3, force maximize_evaluation_metrics + params.clear(); + params.put("maximize_evaluation_metrics", true); + // auc should be max, + evalInfo = composeEvalInfo("auc", evalNames[0]); + TestCase.assertTrue(XGBoost.isMaximizeEvaluation(evalInfo, evalNames, params)); + + params.clear(); + params.put("maximize_evaluation_metrics", false); + // auc should be min, + evalInfo = composeEvalInfo("auc", evalNames[0]); + TestCase.assertFalse(XGBoost.isMaximizeEvaluation(evalInfo, evalNames, params)); + + // test4, set the metric manually + for (String x : XGBoost.MAXIMIZ_METRICES) { + params.clear(); + params.put("eval_metric", x); + evalInfo = composeEvalInfo(x, evalNames[0]); + TestCase.assertTrue(XGBoost.isMaximizeEvaluation(evalInfo, evalNames, params)); + } + + // test5, set the metric manually + for (String x : minimum_metrics) { + params.clear(); + params.put("eval_metric", x); + evalInfo = composeEvalInfo(x, evalNames[0]); + TestCase.assertFalse(XGBoost.isMaximizeEvaluation(evalInfo, evalNames, params)); + } + + } + + @Test + public void testEarlyStop() throws XGBoostError { + Random random = new Random(1); + + java.util.ArrayList labelall = new java.util.ArrayList(); + int nrep = 3000; + java.util.List blist = new java.util.LinkedList(); + for (int i = 0; i < nrep; ++i) { + LabeledPoint p = new LabeledPoint( + i % 2, 4, + new int[]{0, 1, 2, 3}, + new float[]{random.nextFloat(), random.nextFloat(), random.nextFloat(), random.nextFloat()}); + blist.add(p); + labelall.add(p.label()); + } + + DMatrix dmat = new DMatrix(blist.iterator(), null); + + int round = 50; + int earlyStop = 2; + + HashMap mapParams = new HashMap<>(); + mapParams.put("eta", 0.1); + mapParams.put("objective", "binary:logistic"); + mapParams.put("max_depth", 3); + mapParams.put("eval_metric", "auc"); + mapParams.put("silent", 0); + + HashMap mapWatches = new HashMap<>(); + mapWatches.put("selTrain-*", dmat); + + try { + Booster booster = XGBoost.train(dmat, mapParams, round, mapWatches, null, null, null, earlyStop); + Map attrs = booster.getAttrs(); + TestCase.assertTrue(Integer.valueOf(attrs.get("best_iteration")) < round - 1); + } catch (Exception e) { + TestCase.assertFalse(false); + } + + } +}