diff --git a/docs/algo/sona/feature_gbdt_sona.md b/docs/algo/sona/feature_gbdt_sona.md
index 9386da91c..0bae1be94 100644
--- a/docs/algo/sona/feature_gbdt_sona.md
+++ b/docs/algo/sona/feature_gbdt_sona.md
@@ -55,42 +55,75 @@ At the core of GBDT training is a data structure called the gradient histogram
 ### Parameters
 * **Algorithm parameters**
-  * ml.num.class: number of splits
+  * ml.gbdt.task.type: task type, either classification or regression
   * ml.gbdt.loss.func: loss function; supports binary classification (binary:logistic), multi-class classification (multi:logistic), and root mean square error (rmse)
   * ml.gbdt.eval.metric: evaluation metric; supports rmse, error, log-loss, cross-entropy, precision, and auc
+  * ml.num.class: number of classes; only used for classification tasks
   * ml.gbdt.feature.sample.ratio: feature sampling ratio (between 0 and 1)
   * ml.gbdt.tree.num: number of trees
   * ml.gbdt.tree.depth: maximum depth of each tree
   * ml.gbdt.split.num: number of candidate split points per feature
   * ml.learn.rate: learning rate
+  * ml.gbdt.min.node.instance: minimum number of instances in a leaf node
+  * ml.gbdt.min.split.gain: minimum gain required to split a node
+  * ml.gbdt.reg.lambda: regularization coefficient
+  * ml.gbdt.multi.class.strategy: strategy for multi-class tasks, either one tree per round (one-tree) or multiple trees per round (multi-tree)
 * **Input/output parameters**
   * angel.train.data.path: input path of the training data
   * angel.validate.data.path: input path of the validation data
+  * angel.predict.data.path: input path of the data to predict on
+  * angel.predict.out.path: output path for the prediction results
   * angel.save.model.path: path to save the model after training finishes
+  * angel.load.model.path: path to load the model from before prediction starts
 
 ### Example command for a training job
 Submit the job with spark
 
     ./spark-submit \
         --master yarn-cluster \
         --conf spark.ps.jars=$SONA_ANGEL_JARS \
-        --jars $SONA_SPARK_JARS \
-        --name "LR Adam on Spark-on-Angel" \
+        --conf spark.ps.cores=1 \
+        --conf spark.ps.memory=10g \
+        --conf spark.ps.log.level=INFO \
+        --queue $queue \
+        --jars $SONA_SPARK_JARS \
+        --name "GBDT on Spark-on-Angel" \
         --driver-memory 5g \
         --num-executors 10 \
         --executor-cores 1 \
         --executor-memory 10g \
         --class com.tencent.angel.spark.ml.tree.gbdt.trainer.GBDTTrainer \
         spark-on-angel-mllib-${ANGEL_VERSION}.jar \
+        ml.gbdt.task.type:classification \
         angel.train.data.path:XXX angel.validate.data.path:XXX angel.save.model.path:XXX \
         ml.gbdt.loss.func:binary:logistic ml.gbdt.eval.metric:error,log-loss \
-        ml.learn.rate:0.1 ml.gbdt.split.num:10 ml.gbdt.tree.num:20 ml.gbdt.tree.depth:7 ml.class.num:2 \
-        ml.feature.index.range:47237 ml.gbdt.feature.sample.ratio:1.0
+        ml.learn.rate:0.1 ml.gbdt.split.num:10 ml.gbdt.tree.num:20 ml.gbdt.tree.depth:7 ml.num.class:2 \
+        ml.feature.index.range:47237 ml.gbdt.feature.sample.ratio:1.0 ml.gbdt.multi.class.strategy:one-tree ml.gbdt.min.node.instance:100
+
+### Example command for a prediction job
+Submit the job with spark
+
+    ./spark-submit \
+        --master yarn-cluster \
+        --conf spark.ps.jars=$SONA_ANGEL_JARS \
+        --conf spark.ps.cores=1 \
+        --conf spark.ps.memory=10g \
+        --conf spark.ps.log.level=INFO \
+        --queue $queue \
+        --jars $SONA_SPARK_JARS \
+        --name "GBDT on Spark-on-Angel" \
+        --driver-memory 5g \
+        --num-executors 10 \
+        --executor-cores 1 \
+        --executor-memory 10g \
+        --class com.tencent.angel.spark.ml.tree.gbdt.predictor.GBDTPredictor \
+        spark-on-angel-mllib-${ANGEL_VERSION}.jar \
+        angel.load.model.path:XXX angel.predict.data.path:XXX angel.predict.out.path:XXX
+
 ## 5. Performance
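The `spark-submit` arguments above map one-to-one onto `GBDTParam` fields, so the trainer can also be driven programmatically. Below is a minimal sketch that mirrors the `GBDTRegression` example added in this patch, adapted to the binary-classification settings of the training command; the feature dimension, application name, and all paths are placeholders, not values taken from the patch.

    import com.tencent.angel.spark.ml.tree.gbdt.trainer.GBDTTrainer
    import com.tencent.angel.spark.ml.tree.param.GBDTParam
    import com.tencent.angel.spark.ml.tree.util.Maths
    import org.apache.spark.{SparkConf, SparkContext}

    object GBDTClassificationSketch {

      def main(args: Array[String]): Unit = {
        val param = new GBDTParam

        // dataset conf (feature dimension is a placeholder)
        param.taskType = "classification"
        param.numClass = 2
        param.numFeature = 47237

        // loss and metric (ml.gbdt.loss.func / ml.gbdt.eval.metric)
        param.lossFunc = "binary:logistic"
        param.evalMetrics = Array("error", "log-loss")

        // major algo conf (ml.learn.rate, ml.gbdt.split.num, ...)
        param.featSampleRatio = 1.0f
        param.learningRate = 0.1f
        param.numSplit = 10
        param.numTree = 20
        param.maxDepth = 7
        param.maxNodeNum = Maths.pow(2, param.maxDepth + 1) - 1
        param.minNodeInstance = 100
        param.multiStrategy = "one-tree"

        // less important algo conf; GBDTParam fields default to 0/false,
        // so set them explicitly as GBDTRegression does
        param.histSubtraction = true
        param.lighterChildFirst = true
        param.fullHessian = false
        param.minChildWeight = 0.0f
        param.minSplitGain = 0.0f
        param.regAlpha = 0.0f
        param.regLambda = 1.0f
        param.maxLeafWeight = 0.0f
        param.multiGradCache = false

        // parallelism
        param.numWorker = 10
        param.numThread = 1

        @transient implicit val sc = new SparkContext(
          new SparkConf().setAppName("GBDT on Spark-on-Angel"))

        // train on the placeholder paths, then persist the model
        val trainer = new GBDTTrainer(param)
        trainer.initialize("hdfs://path/to/train", "hdfs://path/to/valid")
        val model = trainer.train()
        trainer.save(model, "hdfs://path/to/model")
        sc.stop()
      }
    }

Unlike the CLI route, this bypasses the defaults that `GBDTTrainer.main` fills in via `ArgsUtil`, which is why the less common fields are set explicitly here.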
diff --git a/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/common/TreeConf.java b/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/common/TreeConf.java
index 5170a059d..1f8cce8d6 100644
--- a/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/common/TreeConf.java
+++ b/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/common/TreeConf.java
@@ -25,6 +25,10 @@ public class TreeConf {
   public static final String ML_TRAIN_PATH = "spark.ml.train.path";
   public static final String ML_VALID_PATH = "spark.ml.valid.path";
   public static final String ML_PREDICT_PATH = "spark.ml.predict.path";
+  public static final String ML_OUTPUT_PATH = "spark.ml.output.path";
+
+  public static final String ML_GBDT_TASK_TYPE = "ml.gbdt.task.type";
+  public static final String DEFAULT_ML_GBDT_TASK_TYPE = "classification";
   public static final String ML_VALID_DATA_RATIO = "spark.ml.valid.ratio";
   public static final double DEFAULT_ML_VALID_DATA_RATIO = 0.25;
   public static final String ML_NUM_CLASS = "spark.ml.class.num";
diff --git a/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/gbdt/examples/GBDTRegression.scala b/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/gbdt/examples/GBDTRegression.scala
new file mode 100644
index 000000000..289575649
--- /dev/null
+++ b/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/gbdt/examples/GBDTRegression.scala
@@ -0,0 +1,102 @@
+/*
+ * Tencent is pleased to support the open source community by making Angel available.
+ *
+ * Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
+ * compliance with the License. You may obtain a copy of the License at
+ *
+ * https://opensource.org/licenses/Apache-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License
+ * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ * or implied. See the License for the specific language governing permissions and limitations under
+ * the License.
+ * + */ + +package com.tencent.angel.spark.ml.tree.gbdt.examples + +import com.tencent.angel.spark.ml.core.ArgsUtil +import com.tencent.angel.spark.ml.tree.gbdt.trainer.GBDTTrainer +import com.tencent.angel.spark.ml.tree.param.GBDTParam +import com.tencent.angel.spark.ml.tree.util.Maths +import org.apache.spark.{SparkConf, SparkContext} + +object GBDTRegression { + + def main(args: Array[String]): Unit = { + + @transient val conf = new SparkConf().setMaster("local").setAppName("gbdt") + + val param = new GBDTParam + + // spark conf + val numExecutor = 1 + val numCores = 1 + param.numWorker = numExecutor + param.numThread = numCores + conf.set("spark.task.cpus", numCores.toString) + conf.set("spark.locality.wait", "0") + conf.set("spark.memory.fraction", "0.7") + conf.set("spark.memory.storageFraction", "0.8") + conf.set("spark.task.maxFailures", "1") + conf.set("spark.yarn.maxAppAttempts", "1") + conf.set("spark.network.timeout", "1000") + conf.set("spark.executor.heartbeatInterval", "500") + + val params = ArgsUtil.parse(args) + + //val trainPath = "data/dna/dna.scale" //dimension=181 + //val validPath = "data/dna/dna.scale.t" + val trainPath = "data/abalone/abalone_8d_train.libsvm" //dimension=8 + val validPath = "data/abalone/abalone_8d_train.libsvm" + val modelPath = "tmp/gbdt/abalone" + + // dataset conf + param.taskType = "regression" + param.numClass = 2 + param.numFeature = 8 + + // loss and metric + param.lossFunc = "rmse" + param.evalMetrics = Array("rmse") + param.multiGradCache = false + + // major algo conf + param.featSampleRatio = 1.0f + param.learningRate = 0.1f + param.numSplit = 10 + param.numTree = 10 + param.maxDepth = 7 + val maxNodeNum = Maths.pow(2, param.maxDepth + 1) - 1 + param.maxNodeNum = 4096 min maxNodeNum + + // less important algo conf + param.histSubtraction = true + param.lighterChildFirst = true + param.fullHessian = false + param.minChildWeight = 0.0f + param.minNodeInstance = 10 + param.minSplitGain = 0.0f + param.regAlpha = 0.0f + param.regLambda = 1.0f + param.maxLeafWeight = 0.0f + + println(s"Hyper-parameters:\n$param") + + @transient implicit val sc = new SparkContext(conf) + + try { + val trainer = new GBDTTrainer(param) + trainer.initialize(trainPath, validPath) + val model = trainer.train() + trainer.save(model, modelPath) + } catch { + case e: Exception => + e.printStackTrace() + } finally { + } + } + +} diff --git a/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/gbdt/histogram/Histogram.java b/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/gbdt/histogram/Histogram.java index b6aa657f6..f96515ee5 100644 --- a/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/gbdt/histogram/Histogram.java +++ b/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/gbdt/histogram/Histogram.java @@ -36,6 +36,7 @@ public Histogram(int numBin, int numClass, boolean fullHessian, boolean multiCla this.numClass = numClass; this.fullHessian = fullHessian; this.multiClassMultiTree = multiClassMultiTree; + if (numClass == 2 || multiClassMultiTree) { this.gradients = new double[numBin]; this.hessians = new double[numBin]; diff --git a/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/gbdt/learner/FPGBDTLearner.scala b/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/gbdt/learner/FPGBDTLearner.scala deleted file mode 100644 index 388868e69..000000000 --- a/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/gbdt/learner/FPGBDTLearner.scala +++ 
/dev/null @@ -1,378 +0,0 @@ -/* - * Tencent is pleased to support the open source community by making Angel available. - * - * Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in - * compliance with the License. You may obtain a copy of the License at - * - * https://opensource.org/licenses/Apache-2.0 - * - * Unless required by applicable law or agreed to in writing, software distributed under the License - * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express - * or implied. See the License for the specific language governing permissions and limitations under - * the License. - * - */ - - -package com.tencent.angel.spark.ml.tree.gbdt.learner - -import java.{util => ju} - -import com.tencent.angel.spark.ml.tree.data.DataSet -import com.tencent.angel.spark.ml.tree.gbdt.histogram._ -import com.tencent.angel.spark.ml.tree.gbdt.metadata.{DataInfo, FeatureInfo} -import com.tencent.angel.spark.ml.tree.gbdt.tree.{GBTNode, GBTSplit, GBTTree} -import com.tencent.angel.spark.ml.tree.objective.ObjectiveFactory -import com.tencent.angel.spark.ml.tree.objective.metric.EvalMetric -import com.tencent.angel.spark.ml.tree.param.GBDTParam -import com.tencent.angel.spark.ml.tree.split.SplitEntry -import com.tencent.angel.spark.ml.tree.util.{EvenPartitioner, Maths, RangeBitSet} -import org.apache.spark.ml.linalg.Vector - -import scala.collection.mutable.ArrayBuffer -import scala.collection.JavaConverters._ -import scala.util.Random - -class FPGBDTLearner(val learnerId: Int, - val param: GBDTParam, - _featureInfo: FeatureInfo, - _trainData: DataSet, - _labels: Array[Float], - _validData: Array[Vector], _validLabel: Array[Float]) { - - @transient private[learner] val forest = ArrayBuffer[GBTTree]() - - @transient private val trainData: DataSet = _trainData - @transient private val labels: Array[Float] = _labels - @transient private val validData: Array[Vector] = _validData - @transient private val validLabels: Array[Float] = _validLabel - @transient private val validPreds = { - if (param.numClass == 2) - new Array[Float](validData.length) - else - new Array[Float](validData.length * param.numClass) - } - - private[learner] val (featLo, featHi) = { - val featureEdges = new EvenPartitioner(param.numFeature, param.numWorker).partitionEdges() - (featureEdges(learnerId), featureEdges(learnerId + 1)) - } - private[learner] val numFeatUsed = Math.round((featHi - featLo) * param.featSampleRatio) - private[learner] val isFeatUsed = { - if (numFeatUsed == featHi - featLo) - (featLo until featHi).map(fid => _featureInfo.getNumBin(fid) > 0).toArray - else - new Array[Boolean](featHi - featLo) - } - private[learner] val featureInfo: FeatureInfo = _featureInfo - private[learner] val dataInfo = DataInfo(param, labels.length) - - private[learner] val loss = ObjectiveFactory.getLoss(param.lossFunc) - private[learner] val evalMetrics = ObjectiveFactory.getEvalMetricsOrDefault(param.evalMetrics, loss) - - // histograms and global best splits, one for each internal tree node - private[learner] val storedHists = new Array[Array[Histogram]](Maths.pow(2, param.maxDepth) - 1) - private[learner] val bestSplits = new Array[GBTSplit](Maths.pow(2, param.maxDepth) - 1) - private[learner] val histBuilder = new HistBuilder(param) - private[learner] val splitFinder = new SplitFinder(param) - - private[learner] val activeNodes = ArrayBuffer[Int]() - - 
private[learner] val buildHistTime = new Array[Long](Maths.pow(2, param.maxDepth) - 1) - private[learner] val histSubtractTime = new Array[Long](Maths.pow(2, param.maxDepth) - 1) - private[learner] val findSplitTime = new Array[Long](Maths.pow(2, param.maxDepth) - 1) - private[learner] val getSplitResultTime = new Array[Long](Maths.pow(2, param.maxDepth) - 1) - private[learner] val splitNodeTime = new Array[Long](Maths.pow(2, param.maxDepth) - 1) - - def timing[A](f: => A)(t: Long => Any): A = { - val t0 = System.currentTimeMillis() - val res = f - t(System.currentTimeMillis() - t0) - res - } - - def reportTime(): String = { - val sb = new StringBuilder - for (depth <- 0 until param.maxDepth) { - val from = Maths.pow(2, depth) - 1 - val until = Maths.pow(2, depth + 1) - 1 - if (from < Maths.pow(2, param.maxDepth) - 1) { - sb.append(s"Layer${depth + 1}:\n") - sb.append(s"|buildHistTime: [${buildHistTime.slice(from, until).mkString(", ")}], " + - s"sum[${buildHistTime.slice(from, until).sum}]\n") - sb.append(s"|histSubtractTime: [${histSubtractTime.slice(from, until).mkString(", ")}], " + - s"sum[${histSubtractTime.slice(from, until).sum}]\n") - sb.append(s"|findSplitTime: [${findSplitTime.slice(from, until).mkString(", ")}], " + - s"sum[${findSplitTime.slice(from, until).sum}]\n") - sb.append(s"|getSplitResultTime: [${getSplitResultTime.slice(from, until).mkString(", ")}], " + - s"sum[${getSplitResultTime.slice(from, until).sum}]\n") - sb.append(s"|splitNodeTime: [${splitNodeTime.slice(from, until).mkString(", ")}], " + - s"sum[${splitNodeTime.slice(from, until).sum}]\n") - } - } - val res = sb.toString() - println(res) - for (i <- buildHistTime.indices) { - buildHistTime(i) = 0 - histSubtractTime(i) = 0 - findSplitTime(i) = 0 - getSplitResultTime(i) = 0 - splitNodeTime(i) = 0 - } - res - } - - def createNewTree(): Unit = { - // 1. create new tree - val tree = new GBTTree(param) - this.forest += tree - // 2. sample features - if (numFeatUsed != featHi - featLo) { - ju.Arrays.fill(isFeatUsed, false) - for (_ <- 0 until numFeatUsed) { - val rand = Random.nextInt(featHi - featLo) - isFeatUsed(rand) = featureInfo.getNumBin(featLo + rand) > 0 - } - } - // 3. reset position info - dataInfo.resetPosInfo() - // 4. calc grads - val sumGradPair = dataInfo.calcGradPairs(0, labels, loss, param) - tree.getRoot.setSumGradPair(sumGradPair) - // 5. 
set root status - activeNodes += 0 - } - - def findSplits(): Seq[(Int, GBTSplit)] = { - val res = if (activeNodes.nonEmpty) { - buildHistAndFindSplit(activeNodes) - } else { - Seq.empty - } - activeNodes.clear() - res - } - - def getSplitResults(splits: Seq[(Int, GBTSplit)]): Seq[(Int, RangeBitSet)] = { - val tree = forest.last - splits.map { - case (nid, split) => - tree.getNode(nid).setSplitEntry(split.getSplitEntry) - bestSplits(nid) = split - (nid, getSplitResult(nid, split.getSplitEntry)) - }.filter(_._2 != null) - } - - def splitNodes(splitResults: Seq[(Int, RangeBitSet)]): Boolean = { - splitResults.foreach { - case (nid, result) => - splitNode(nid, result, bestSplits(nid)) - if (2 * nid + 1 < Maths.pow(2, param.maxDepth) - 1) { - activeNodes += 2 * nid + 1 - activeNodes += 2 * nid + 2 - } - } - activeNodes.nonEmpty - } - - def buildHistAndFindSplit(nids: Seq[Int]): Seq[(Int, GBTSplit)] = { - val nodes = nids.map(forest.last.getNode) - val sumGradPairs = nodes.map(_.getSumGradPair) - val canSplits = nodes.map(canSplitNode) - - val buildStart = System.currentTimeMillis() - var cur = 0 - while (cur < nids.length) { - val nid = nids(cur) - val sibNid = Maths.sibling(nid) - if (cur + 1 < nids.length && nids(cur + 1) == sibNid) { - if (canSplits(cur) || canSplits(cur + 1)) { - val curSize = dataInfo.getNodeSize(nid) - val sibSize = dataInfo.getNodeSize(sibNid) - val parNid = Maths.parent(nid) - val parHist = storedHists(parNid) - if (curSize < sibSize) { - timing { - storedHists(nid) = histBuilder.buildHistogramsFP( - isFeatUsed, featLo, trainData, featureInfo, dataInfo, - nid, sumGradPairs(cur), parHist - ) - storedHists(sibNid) = parHist - } { t => buildHistTime(nid) = t } - // timing(storedHists(nid) = histBuilder.buildHistogramsFP( - // isFeatUsed, featLo, trainData, featureInfo, dataInfo, - // nid, sumGradPairs(cur) - // )) {t => buildHistTime(nid) = t} - // timing(storedHists(sibNid) = histBuilder.histSubtraction( - // parHist, storedHists(nid), true - // )) {t => histSubtractTime(sibNid) = t} - } else { - timing { - storedHists(sibNid) = histBuilder.buildHistogramsFP( - isFeatUsed, featLo, trainData, featureInfo, dataInfo, - sibNid, sumGradPairs(cur + 1), parHist - ) - storedHists(nid) = parHist - } { t => buildHistTime(sibNid) = t } - // timing(storedHists(sibNid) = histBuilder.buildHistogramsFP( - // isFeatUsed, featLo, trainData, featureInfo, dataInfo, - // sibNid, sumGradPairs(cur + 1) - // )) {t => histSubtractTime(sibNid) = t} - // timing(storedHists(nid) = histBuilder.histSubtraction( - // parHist, storedHists(sibNid), true - // )) {t => buildHistTime(nid) = t} - } - storedHists(parNid) = null - } - cur += 2 - } else { - if (canSplits(cur)) { - timing(storedHists(nid) = histBuilder.buildHistogramsFP( - isFeatUsed, featLo, trainData, featureInfo, dataInfo, - nid, sumGradPairs(cur), null - )) { t => buildHistTime(nid) = t } - } - cur += 1 - } - } - println(s"Build histograms cost ${System.currentTimeMillis() - buildStart} ms") - - val findStart = System.currentTimeMillis() - val res = canSplits.zipWithIndex.map { - case (canSplit, i) => - val nid = nids(i) - timing(if (canSplit) { - val node = nodes(i) - val hist = storedHists(nid) - val sumGradPair = sumGradPairs(i) - val nodeGain = node.calcGain(param) - val split = splitFinder.findBestSplitFP(featLo, hist, - featureInfo, sumGradPair, nodeGain) - (nid, split) - } else { - (nid, new GBTSplit) - }) { t => findSplitTime(nid) = t } - }.filter(_._2.isValid(param.minSplitGain)) - println(s"Find splits cost 
${System.currentTimeMillis() - findStart} ms") - res - } - - def getSplitResult(nid: Int, splitEntry: SplitEntry): RangeBitSet = { - require(!splitEntry.isEmpty && splitEntry.getGain > param.minSplitGain) - //forest.last.getNode(nid).setSplitEntry(splitEntry) - val splitFid = splitEntry.getFid - if (featLo <= splitFid && splitFid < featHi) { - val splits = featureInfo.getSplits(splitFid) - timing(dataInfo.getSplitResult(nid, splitEntry, splits, trainData)) { t => getSplitResultTime(nid) = t } - } else { - null - } - } - - def splitNode(nid: Int, splitResult: RangeBitSet, split: GBTSplit = null): Unit = { - timing { - dataInfo.updatePos(nid, splitResult) - val tree = forest.last - val node = tree.getNode(nid) - val leftChild = new GBTNode(2 * nid + 1, node, param.numClass) - val rightChild = new GBTNode(2 * nid + 2, node, param.numClass) - node.setLeftChild(leftChild) - node.setRightChild(rightChild) - tree.setNode(2 * nid + 1, leftChild) - tree.setNode(2 * nid + 2, rightChild) - if (split == null) { - val leftSize = dataInfo.getNodeSize(2 * nid + 1) - val rightSize = dataInfo.getNodeSize(2 * nid + 2) - if (leftSize < rightSize) { - val leftSumGradPair = dataInfo.sumGradPair(2 * nid + 1) - val rightSumGradPair = node.getSumGradPair.subtract(leftSumGradPair) - leftChild.setSumGradPair(leftSumGradPair) - rightChild.setSumGradPair(rightSumGradPair) - } else { - val rightSumGradPair = dataInfo.sumGradPair(2 * nid + 2) - val leftSumGradPair = node.getSumGradPair.subtract(rightSumGradPair) - leftChild.setSumGradPair(leftSumGradPair) - rightChild.setSumGradPair(rightSumGradPair) - } - } else { - leftChild.setSumGradPair(split.getLeftGradPair) - rightChild.setSumGradPair(split.getRightGradPair) - } - } { t => splitNodeTime(nid) = t } - } - - def canSplitNode(node: GBTNode): Boolean = { - if (dataInfo.getNodeSize(node.getNid) > param.minNodeInstance) { - if (param.numClass == 2) { - val sumGradPair = node.getSumGradPair.asInstanceOf[BinaryGradPair] - param.satisfyWeight(sumGradPair.getGrad, sumGradPair.getHess) - } else { - val sumGradPair = node.getSumGradPair.asInstanceOf[MultiGradPair] - param.satisfyWeight(sumGradPair.getGrad, sumGradPair.getHess) - } - } else { - false - } - } - - def setAsLeaf(nid: Int): Unit = setAsLeaf(nid, forest.last.getNode(nid)) - - def setAsLeaf(nid: Int, node: GBTNode): Unit = { - node.chgToLeaf() - if (param.numClass == 2) { - val weight = node.calcWeight(param) - dataInfo.updatePreds(nid, weight, param.learningRate) - } else { - val weights = node.calcWeights(param) - dataInfo.updatePreds(nid, weights, param.learningRate) - } - } - - def finishTree(): Unit = { - forest.last.getNodes.asScala.foreach { - case (nid, node) => - if (node.getSplitEntry == null && !node.isLeaf) - setAsLeaf(nid, node) - } - for (i <- storedHists.indices) - storedHists(i) = null - } - - def evaluate(): Seq[(EvalMetric.Kind, Double, Double)] = { - for (i <- validData.indices) { - var node = forest.last.getRoot - while (!node.isLeaf) { - if (node.getSplitEntry.flowTo(validData(i)) == 0) - node = node.getLeftChild.asInstanceOf[GBTNode] - else - node = node.getRightChild.asInstanceOf[GBTNode] - } - if (param.numClass == 2) { - validPreds(i) += node.getWeight * param.learningRate - } else { - val weights = node.getWeights - for (k <- 0 until param.numClass) - validPreds(i * param.numClass + k) += weights(k) * param.learningRate - } - } - - val metrics = evalMetrics.map(evalMetric => - (evalMetric.getKind, evalMetric.eval(dataInfo.predictions, labels), - evalMetric.eval(validPreds, validLabels)) - 
) - - val evalTrainMsg = metrics.map(metric => s"${metric._1}[${metric._2}]").mkString(", ") - println(s"Evaluation on train data after ${forest.size} tree(s): $evalTrainMsg") - val evalValidMsg = metrics.map(metric => s"${metric._1}[${metric._3}]").mkString(", ") - println(s"Evaluation on valid data after ${forest.size} tree(s): $evalValidMsg") - metrics - } - - def finalizeModel(): Seq[GBTTree] = { - histBuilder.shutdown() - forest - } - -} diff --git a/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/gbdt/learner/SparkFPGBDTTrainer.scala b/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/gbdt/learner/SparkFPGBDTTrainer.scala deleted file mode 100644 index ad8fd4feb..000000000 --- a/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/gbdt/learner/SparkFPGBDTTrainer.scala +++ /dev/null @@ -1,217 +0,0 @@ -/* - * Tencent is pleased to support the open source community by making Angel available. - * - * Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in - * compliance with the License. You may obtain a copy of the License at - * - * https://opensource.org/licenses/Apache-2.0 - * - * Unless required by applicable law or agreed to in writing, software distributed under the License - * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express - * or implied. See the License for the specific language governing permissions and limitations under - * the License. - * - */ - - -package com.tencent.angel.spark.ml.tree.gbdt.learner - -import com.tencent.angel.spark.ml.tree.data.{Instance, VerticalPartition => VP} -import com.tencent.angel.spark.ml.tree.gbdt.metadata.FeatureInfo -import com.tencent.angel.spark.ml.tree.gbdt.tree.GBTSplit -import com.tencent.angel.spark.ml.tree.param.GBDTParam -import com.tencent.angel.spark.ml.tree.util.{DataLoader, Maths, Transposer} -import org.apache.spark.rdd.RDD -import org.apache.spark.storage.StorageLevel -import org.apache.spark.{SparkContext, TaskContext} - -class SparkFPGBDTTrainer(param: GBDTParam) extends Serializable { - @transient implicit val sc = SparkContext.getOrCreate() - - @transient private var workers: RDD[FPGBDTLearner] = _ - - def initialize(trainInput: String, validInput: String): Unit = { - val bcParam = sc.broadcast(param) - - val loadStart = System.currentTimeMillis() - val train = DataLoader.loadLibsvmFP(trainInput, - param.numFeature, param.numWorker) - .persist(StorageLevel.MEMORY_AND_DISK) - val valid = DataLoader.loadLibsvmDP(validInput, param.numFeature) - .repartition(param.numWorker) - .persist(StorageLevel.MEMORY_AND_DISK) - val numTrain = train.map(_.labels.length).reduce(_ + _) / param.numWorker - val numValid = valid.count() - println(s"load data cost ${System.currentTimeMillis() - loadStart} ms, " + - s"$numTrain train data, $numValid valid data") - - val createFIStart = System.currentTimeMillis() - val splits = new Array[Array[Float]](param.numFeature) - train.mapPartitions(iterator => - VP.getCandidateSplits(iterator.toSeq, - bcParam.value.numFeature, bcParam.value.numSplit).iterator - ).collect().foreach { - case (fid, fSplits) => splits(fid) = fSplits - } - val featureInfo = FeatureInfo(param.numFeature, splits) - val bcFeatureInfo = sc.broadcast(featureInfo) - println(s"Create feature info cost ${System.currentTimeMillis() - createFIStart} ms") - - val initStart = System.currentTimeMillis() - val 
workers = train.zipPartitions(valid)( - (vpIter, validIter) => { - val (trainLabels, trainData) = VP.discretize(vpIter.toSeq, bcFeatureInfo.value) - val valid = validIter.toArray - val validLabels = valid.map(_.label.toFloat) - val validData = valid.map(_.feature) - Instance.ensureLabel(trainLabels, bcParam.value.numClass) - Instance.ensureLabel(validLabels, bcParam.value.numClass) - val worker = new FPGBDTLearner(TaskContext.getPartitionId, - bcParam.value, bcFeatureInfo.value, - trainData, trainLabels, validData, validLabels) - Iterator(worker) - } - ).cache() - workers.foreach(worker => - println(s"Worker[${worker.learnerId}] initialization done. " + - s"Hyper-parameters:\n$param") - ) - println(s"Initialize workers cost ${System.currentTimeMillis() - initStart} ms") - - train.unpersist() - valid.unpersist() - this.workers = workers - } - - def loadData(input: String, validRatio: Double): Unit = { - val loadStart = System.currentTimeMillis() - val data = DataLoader.loadLibsvmDP(input, param.numFeature) - .repartition(param.numWorker) - .persist(StorageLevel.MEMORY_AND_DISK) - val splits = data.randomSplit(Array(1.0 - validRatio, validRatio)) - val train = splits(0).cache() - val valid = splits(1).cache() - - val numTrain = train.count() - val numValid = valid.count() - data.unpersist() - println(s"load data cost ${System.currentTimeMillis() - loadStart} ms, " + - s"$numTrain train data, $numValid valid data") - - val initStart = System.currentTimeMillis() - val transposer = new Transposer() - val (trainData, labels, bcFeatureInfo) = transposer.transpose2(train, - param.numFeature, param.numWorker, param.numSplit) - Instance.ensureLabel(labels, param.numClass) - val bcLabels = sc.broadcast(labels) - - val bcParam = sc.broadcast(param) - val workers = trainData.zipPartitions(valid)( - (trainIter, validIter) => { - val learnerId = TaskContext.getPartitionId - val valid = validIter.toArray - val trainData = trainIter.toArray - val trainLabels = bcLabels.value - val validData = valid.map(_.feature) - val validLabels = valid.map(_.label.toFloat) - Instance.ensureLabel(validLabels, bcParam.value.numClass) - val worker = new FPGBDTLearner(learnerId, bcParam.value, bcFeatureInfo.value, - null, trainLabels, validData, validLabels) - Iterator(worker) - } - ).cache() - workers.foreach(worker => - println(s"Worker[${worker.learnerId}] initialization done. " + - s"Hyper-parameters:\n$param") - ) - - train.unpersist() - valid.unpersist() - this.workers = workers - println(s"Transpose data and initialize workers cost ${System.currentTimeMillis() - initStart} ms") - } - - def train(): Unit = { - val trainStart = System.currentTimeMillis() - - for (treeId <- 0 until param.numTree) { - println(s"Start to train tree ${treeId + 1}") - - // 1. create new tree - val createStart = System.currentTimeMillis() - workers.foreach(_.createNewTree()) - val bestSplits = new Array[GBTSplit](Maths.pow(2, param.maxDepth) - 1) - println(s"Tree[${treeId + 1}] Create new tree cost ${System.currentTimeMillis() - createStart} ms") - - var hasActive = true - while (hasActive) { - // 2. 
build histograms and find local best splits - val findStart = System.currentTimeMillis() - val nids = collection.mutable.TreeSet[Int]() - workers.flatMap(_.findSplits().iterator) - .collect() - .foreach { - case (nid, split) => - nids += nid - if (bestSplits(nid) == null) - bestSplits(nid) = split - else - bestSplits(nid).update(split) - } - val validSplits = nids.toArray.map(nid => (nid, bestSplits(nid))) - println(s"Build histograms and find best splits cost " + - s"${System.currentTimeMillis() - findStart} ms, " + - s"${validSplits.length} node(s) to split") - if (validSplits.nonEmpty) { - // 3. get split results - val resultStart = System.currentTimeMillis() - val bcSplits = sc.broadcast(validSplits) - val splitResults = workers.flatMap( - _.getSplitResults(bcSplits.value).iterator - ).collect() - val bcSplitResults = sc.broadcast(splitResults) - println(s"Get split results cost ${System.currentTimeMillis() - resultStart} ms") - // 4. split nodes - val splitStart = System.currentTimeMillis() - hasActive = workers.map(_.splitNodes(bcSplitResults.value)).collect()(0) - println(s"Split nodes cost ${System.currentTimeMillis() - splitStart} ms") - } else { - // no active nodes - hasActive = false - } - } - - // 5. finish tree - val finishStart = System.currentTimeMillis() - val metrics = workers.map(worker => { - worker.finishTree() - worker.evaluate() - }).collect()(0) - val evalTrainMsg = metrics.map(metric => s"${metric._1}[${metric._2}]").mkString(", ") - println(s"Evaluation on train data after ${treeId + 1} tree(s): $evalTrainMsg") - val evalValidMsg = metrics.map(metric => s"${metric._1}[${metric._3}]").mkString(", ") - println(s"Evaluation on valid data after ${treeId + 1} tree(s): $evalValidMsg") - println(s"Tree[${treeId + 1}] Finish tree cost ${System.currentTimeMillis() - finishStart} ms") - - val currentTime = System.currentTimeMillis() - println(s"Train tree cost ${currentTime - createStart} ms, " + - s"${treeId + 1} tree(s) done, ${currentTime - trainStart} ms elapsed") - - workers.map(_.reportTime()).collect().zipWithIndex.foreach { - case (str, id) => - println(s"========Time cost summation of worker[$id]========") - println(str) - } - } - - // TODO: check equality - val forest = workers.map(_.finalizeModel()).collect()(0) - forest.zipWithIndex.foreach { - case (tree, treeId) => - println(s"Tree[${treeId + 1}] contains ${tree.size} nodes " + - s"(${(tree.size - 1) / 2 + 1} leaves)") - } - } -} diff --git a/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/gbdt/metadata/FeatureInfo.scala b/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/gbdt/metadata/FeatureInfo.scala index 0b039d28d..5b917ab01 100644 --- a/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/gbdt/metadata/FeatureInfo.scala +++ b/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/gbdt/metadata/FeatureInfo.scala @@ -22,7 +22,8 @@ import com.tencent.angel.spark.ml.tree.util.Maths object FeatureInfo { - val ENUM_THRESHOLD: Int = 32 + + val ENUM_THRESHOLD: Int = 10 def apply(numFeature: Int, splits: Array[Array[Float]]): FeatureInfo = { require(splits.length == numFeature) diff --git a/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/gbdt/predictor/GBDTPredictor.scala b/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/gbdt/predictor/GBDTPredictor.scala index 8d6635cce..676b4a48a 100644 --- a/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/gbdt/predictor/GBDTPredictor.scala +++ 
b/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/gbdt/predictor/GBDTPredictor.scala @@ -18,10 +18,10 @@ package com.tencent.angel.spark.ml.tree.gbdt.predictor -import com.tencent.angel.spark.ml.tree.common.TreeConf._ +import com.tencent.angel.conf.AngelConf +import com.tencent.angel.spark.ml.core.ArgsUtil import com.tencent.angel.spark.ml.tree.data.Instance import com.tencent.angel.spark.ml.tree.gbdt.tree.{GBTNode, GBTTree} -import com.tencent.angel.spark.ml.tree.objective.ObjectiveFactory import com.tencent.angel.spark.ml.tree.util.DataLoader import org.apache.hadoop.fs.Path import org.apache.spark.ml.linalg.{Vector, Vectors} @@ -37,16 +37,18 @@ class GBDTPredictor extends Serializable { println(s"Reading model from $modelPath") } - def predict(predictor: GBDTPredictor, instances: RDD[Instance]): RDD[(Long, Array[Float])] = { + def predict(predictor: GBDTPredictor, instances: RDD[Instance]): RDD[(Long, Int, Array[Float])] = { val bcPredictor = instances.sparkContext.broadcast(predictor) instances.map { instance => - (instance.label.toLong, bcPredictor.value.predictRaw(instance.feature)) + val predProbs = bcPredictor.value.predictRaw(instance.feature) + val predClass = bcPredictor.value.probToClass(predProbs) + (instance.label.toLong, predClass, predProbs) } } - def predict(implicit sc: SparkContext, validPath: String, predPath: String): Unit = { - println(s"Predicting dataset $validPath") - val instances: RDD[Instance] = DataLoader.loadLibsvmDP(validPath, forest.head.getParam.numFeature).cache() + def predict(implicit sc: SparkContext, predictPath: String, outputPath: String): Unit = { + println(s"Predicting dataset: $predictPath") + val instances: RDD[Instance] = DataLoader.loadLibsvmDP(predictPath, forest.head.getParam.numFeature).cache() //val labels = instances.map(_.label.toFloat).collect() val preds = predict(this, instances) instances.unpersist() @@ -62,12 +64,15 @@ class GBDTPredictor extends Serializable { // s"$kind[$metric]" // }).mkString(", ")) - val path = new Path(predPath) + val path = new Path(outputPath) val fs = path.getFileSystem(sc.hadoopConfiguration) if (fs.exists(path)) fs.delete(path, true) - preds.map(pred => s"${pred._1} ${pred._2.mkString(",")}").saveAsTextFile(predPath) - println(s"Writing predictions to $predPath") + if (forest.head.getParam.isClassification) + preds.map(pred => s"${pred._1} ${pred._2} ${pred._3.mkString(",")}").saveAsTextFile(outputPath) + else + preds.map(pred => s"${pred._1} ${pred._3.mkString(",")}").saveAsTextFile(outputPath) + println(s"Writing predictions to $outputPath") } def predictRaw(vec: Vector): Array[Float] = { @@ -82,7 +87,7 @@ class GBDTPredictor extends Serializable { else node = node.getRightChild.asInstanceOf[GBTNode] } - if (param.numClass == 2) { + if (param.isRegression || param.numClass == 2) { preds(0) += node.getWeight * param.learningRate } else { if (param.isMultiClassMultiTree) { @@ -101,7 +106,7 @@ class GBDTPredictor extends Serializable { def probToClass(preds: Array[Float]): Int = { preds.length match { case 1 => if (preds.head > 0.5) 1 else 0 - case _ => preds.zipWithIndex.maxBy(_._1)._2 + 1 + case _ => preds.zipWithIndex.maxBy(_._1)._2 } } @@ -154,13 +159,20 @@ object GBDTPredictor { @transient val conf = new SparkConf() @transient implicit val sc = SparkContext.getOrCreate(conf) - val modelPath = conf.get(ML_MODEL_PATH) - val validPath = conf.get(ML_VALID_PATH) - val predictPath = conf.get(ML_PREDICT_PATH) + val params = ArgsUtil.parse(args) + + //val modelPath = 
params.getOrElse(TreeConf.ML_MODEL_PATH, "xxx")
+    val modelPath = params.getOrElse(AngelConf.ANGEL_LOAD_MODEL_PATH, "xxx")
+
+    //val predictPath = params.getOrElse(TreeConf.ML_PREDICT_PATH, "xxx")
+    val predictPath = params.getOrElse(AngelConf.ANGEL_PREDICT_DATA_PATH, "xxx")
+
+    //val outputPath = params.getOrElse(TreeConf.ML_OUTPUT_PATH, "xxx")
+    val outputPath = params.getOrElse(AngelConf.ANGEL_PREDICT_PATH, "xxx")
 
     val predictor = new GBDTPredictor
     predictor.loadModel(sc, modelPath)
-    predictor.predict(sc, validPath, predictPath)
+    predictor.predict(sc, predictPath, outputPath)
 
     sc.stop
   }
diff --git a/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/gbdt/trainer/GBDTTrainer.scala b/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/gbdt/trainer/GBDTTrainer.scala
index 3a2c2f352..445711935 100644
--- a/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/gbdt/trainer/GBDTTrainer.scala
+++ b/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/gbdt/trainer/GBDTTrainer.scala
@@ -22,7 +22,6 @@ import com.tencent.angel.conf.AngelConf
 import com.tencent.angel.ml.core.conf.{MLConf, SharedConf}
 import com.tencent.angel.spark.ml.core.ArgsUtil
 import com.tencent.angel.spark.ml.tree.param.GBDTParam
-import com.tencent.angel.spark.ml.tree.common.TreeConf._
 import com.tencent.angel.spark.ml.tree.gbdt.dataset.Dataset
 import com.tencent.angel.spark.ml.tree.gbdt.dataset.Dataset._
 import com.tencent.angel.spark.ml.tree.data.Instance
@@ -66,6 +65,7 @@ object GBDTTrainer {
     val params = ArgsUtil.parse(args)
 
     // dataset conf
+    param.taskType = params.getOrElse(MLConf.ML_GBDT_TASK_TYPE, MLConf.DEFAULT_ML_GBDT_TASK_TYPE)
     param.numClass = params.getOrElse(MLConf.ML_NUM_CLASS, "2").toInt
     param.numFeature = params.getOrElse(MLConf.ML_FEATURE_INDEX_RANGE, "-1").toInt
     SharedConf.get().setInt(MLConf.ML_NUM_CLASS, param.numClass)
@@ -75,9 +75,18 @@ object GBDTTrainer {
     param.lossFunc = params.getOrElse(MLConf.ML_GBDT_LOSS_FUNCTION, "binary:logistic")
     param.evalMetrics = params.getOrElse(MLConf.ML_GBDT_EVAL_METRIC, "error").split(",").map(_.trim).filter(_.nonEmpty)
     SharedConf.get().set(MLConf.ML_GBDT_LOSS_FUNCTION, param.lossFunc)
-    param.multiStrategy = params.getOrElse("ml.gbdt.multi.class.strategy", "one-tree")
-    if (param.isMultiClassMultiTree) param.lossFunc = "binary:logistic"
-    param.multiGradCache = params.getOrElse("ml.gbdt.multi.class.grad.cache", "true").toBoolean
+
+    param.taskType match {
+      case "regression" =>
+        require(param.lossFunc.equals("rmse") && param.evalMetrics(0).equals("rmse"),
+          "loss function and metric of regression task should be rmse")
+        param.numClass = 2
+      case "classification" =>
+        require(param.numClass >= 2, "number of classes should be at least 2")
+        param.multiStrategy = params.getOrElse("ml.gbdt.multi.class.strategy", "one-tree")
+        if (param.isMultiClassMultiTree) param.lossFunc = "binary:logistic"
+        param.multiGradCache = params.getOrElse("ml.gbdt.multi.class.grad.cache", "true").toBoolean
+    }
 
     // major algo conf
     param.featSampleRatio = params.getOrElse(MLConf.ML_GBDT_FEATURE_SAMPLE_RATIO, "1.0").toFloat
@@ -243,7 +252,7 @@ class GBDTTrainer(param: GBDTParam) extends Serializable {
       Array.copy(partLabel, 0, labels, offset, partLabel.length)
       offset += partLabel.length
     })
-    val changeLabel = Instance.ensureLabel(labels, param.numClass)
+    val changeLabel = if (param.isClassification) Instance.ensureLabel(labels, param.numClass) else false
     val bcLabels = sc.broadcast(labels)
    val bcChangeLabel = sc.broadcast(changeLabel)
println(s"Collect labels cost ${System.currentTimeMillis() - labelStart} ms") @@ -304,10 +313,19 @@ class GBDTTrainer(param: GBDTParam) extends Serializable { }) if (merged(i) != null && !merged(i).isEmpty) { val distinct = merged(i).tryDistinct(FeatureInfo.ENUM_THRESHOLD) - if (distinct == null) - (false, Maths.unique(merged(i).getQuantiles(numSplit)), merged(i).getN.toInt) - else + if (distinct == null) { + val tmpSplits = Maths.unique(merged(i).getQuantiles(numSplit)) + if (tmpSplits.length == 1 && tmpSplits(0) > 0) { + (false, Array(0, tmpSplits(0)), merged(i).getN.toInt) + } else if (tmpSplits.length == 1 && tmpSplits(0) < 0) { + (false, Array(tmpSplits(0), 0), merged(i).getN.toInt) + } else { + (false, tmpSplits, merged(i).getN.toInt) + } + } + else { (true, distinct, merged(i).getN.toInt) + } } else { (false, null, 0) } @@ -411,7 +429,7 @@ class GBDTTrainer(param: GBDTParam) extends Serializable { val loss = ObjectiveFactory.getLoss(param.lossFunc) val evalMetrics = ObjectiveFactory.getEvalMetricsOrDefault(param.evalMetrics, loss) - val multiStrategy = ObjectiveFactory.getMultiStrategy(param.multiStrategy) + //val multiStrategy = ObjectiveFactory.getMultiStrategy(param.multiStrategy) LogHelper.setLogLevel("info") diff --git a/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/param/GBDTParam.java b/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/param/GBDTParam.java index 958bd845b..b20464e82 100644 --- a/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/param/GBDTParam.java +++ b/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/param/GBDTParam.java @@ -23,6 +23,7 @@ public class GBDTParam extends RegTParam { + public String taskType; // classification or regression public int numClass; // number of classes/labels public int numTree; // number of trees public int numThread; // parallelism @@ -43,6 +44,14 @@ public class GBDTParam extends RegTParam { public String multiStrategy; // strategy of multi-class classification (one-tree or multi-tree) public boolean multiGradCache; // use grad cache for multiclass-multitree, or calc grad for every tree + public boolean isClassification() { + return taskType.equalsIgnoreCase("classification"); + } + + public boolean isRegression() { + return taskType.equalsIgnoreCase("regression"); + } + public int numClassPerTree() { if (numClass > 2 && multiStrategy.equalsIgnoreCase(MultiStrategy.ONE_TREE.toString())) { return numClass; diff --git a/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/sketch/HeapQuantileSketch.java b/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/sketch/HeapQuantileSketch.java index 735ed2b26..692c53299 100644 --- a/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/sketch/HeapQuantileSketch.java +++ b/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/sketch/HeapQuantileSketch.java @@ -398,11 +398,11 @@ public float[] tryDistinct(int maxItemNums) { } if (samplesArr[i] != samplesArr[i - 1]) { cnt++; - if (cnt++ > maxItemNums) { - return null; - } } } + if (cnt > maxItemNums) { + return null; + } if (cnt != samplesArr.length) { float[] res = new float[cnt]; res[0] = samplesArr[0]; diff --git a/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/util/DatasetAnalysis.scala b/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/util/DatasetAnalysis.scala index 2c68f076e..f41e31c14 100644 --- 
a/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/util/DatasetAnalysis.scala +++ b/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/util/DatasetAnalysis.scala @@ -17,7 +17,7 @@ package com.tencent.angel.spark.ml.tree.util -import com.tencent.angel.spark.ml.tree.common.TreeConf._ +import com.tencent.angel.spark.ml.tree.common.TreeConf import com.tencent.angel.spark.ml.tree.gbdt.metadata.FeatureInfo import com.tencent.angel.spark.ml.tree.sketch.HeapQuantileSketch import org.apache.spark.{SparkConf, SparkContext} @@ -42,10 +42,10 @@ object DatasetAnalysis { } def analysis(conf: SparkConf)(implicit sc: SparkContext): Unit = { - val input = conf.get(ML_TRAIN_PATH) - val dim = conf.get(ML_NUM_FEATURE).toInt - val numWorker = conf.get(ML_NUM_WORKER).toInt - val numSplit = conf.getInt(ML_GBDT_SPLIT_NUM, DEFAULT_ML_GBDT_SPLIT_NUM) + val input = conf.get(TreeConf.ML_TRAIN_PATH) + val dim = conf.get(TreeConf.ML_NUM_FEATURE).toInt + val numWorker = conf.get(TreeConf.ML_NUM_WORKER).toInt + val numSplit = conf.getInt(TreeConf.ML_GBDT_SPLIT_NUM, TreeConf.DEFAULT_ML_GBDT_SPLIT_NUM) val loadStart = System.currentTimeMillis() @@ -167,7 +167,7 @@ object DatasetAnalysis { } def change_label(conf: SparkConf)(implicit sc: SparkContext): Unit = { - val input = conf.get(ML_TRAIN_PATH) + val input = conf.get(TreeConf.ML_TRAIN_PATH) val output = conf.get("spark.ml.output.path") sc.textFile(input) @@ -183,9 +183,9 @@ object DatasetAnalysis { } def coalesce_label(conf: SparkConf)(implicit sc: SparkContext): Unit = { - val input = conf.get(ML_TRAIN_PATH) + val input = conf.get(TreeConf.ML_TRAIN_PATH) val output = conf.get("spark.ml.output.path") - val numClass = conf.get(ML_NUM_CLASS).toInt + val numClass = conf.get(TreeConf.ML_NUM_CLASS).toInt val coalescedNumClass = conf.get("spark.ml.coalesced.class.num").toInt val avg = if (numClass % coalescedNumClass > numClass / 2) { @@ -208,9 +208,9 @@ object DatasetAnalysis { } def shuffle_feature(conf: SparkConf)(implicit sc: SparkContext): Unit = { - val input = conf.get(ML_TRAIN_PATH) + val input = conf.get(TreeConf.ML_TRAIN_PATH) val output = conf.get("spark.ml.output.path") - val numFeature = conf.get(ML_NUM_FEATURE).toInt + val numFeature = conf.get(TreeConf.ML_NUM_FEATURE).toInt val shuffle = (0 until numFeature).toArray Maths.shuffle(shuffle) val bcShuffle = sc.broadcast(shuffle) diff --git a/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/util/LogHelper.scala b/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/util/LogHelper.scala index e918dce53..df6f90cda 100644 --- a/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/util/LogHelper.scala +++ b/spark-on-angel/mllib/src/main/scala/com/tencent/angel/spark/ml/tree/util/LogHelper.scala @@ -19,7 +19,7 @@ package com.tencent.angel.spark.ml.tree.util object LogHelper { - var LOG_LEVEL: String = "info" + var LOG_LEVEL: String = "debug" def setLogLevel(level: String): Unit = { assert(level.equalsIgnoreCase("debug") || level.equalsIgnoreCase("info")) diff --git a/spark-on-angel/mllib/src/test/scala/com/tencent/angel/spark/ml/GBDTTest.scala b/spark-on-angel/mllib/src/test/scala/com/tencent/angel/spark/ml/GBDTTest.scala index cc986cf0f..876a29150 100644 --- a/spark-on-angel/mllib/src/test/scala/com/tencent/angel/spark/ml/GBDTTest.scala +++ b/spark-on-angel/mllib/src/test/scala/com/tencent/angel/spark/ml/GBDTTest.scala @@ -1,6 +1,5 @@ package com.tencent.angel.spark.ml -import com.tencent.angel.spark.ml.tree.common.TreeConf._ import 
com.tencent.angel.spark.ml.tree.gbdt.predictor.GBDTPredictor import com.tencent.angel.spark.ml.tree.gbdt.trainer.GBDTTrainer import com.tencent.angel.spark.ml.tree.param.GBDTParam @@ -17,48 +16,37 @@ class GBDTTest extends PSFunSuite with SharedPSContext { override def beforeAll(): Unit = { + super.beforeAll() + trainPath = "../../data/agaricus/agaricus_127d_train.libsvm" - conf.set(ML_TRAIN_PATH, trainPath) testPath = "../../data/agaricus/agaricus_127d_train.libsvm" - conf.set(ML_VALID_PATH, testPath) modelPath = "../../tmp/model" - conf.set(ML_MODEL_PATH, modelPath) predPath = "../../tmp/pred" - conf.set(ML_MODEL_PATH, modelPath) - - conf.set(ML_NUM_CLASS, "2") - conf.set(ML_NUM_FEATURE, "149") - conf.set(ML_NUM_WORKER, "1") - conf.set(ML_LOSS_FUNCTION, "binary:logistic") - conf.set(ML_EVAL_METRIC, "error,auc") - conf.set(ML_LEARN_RATE, "0.1") - conf.set(ML_GBDT_MAX_DEPTH, "3") - - super.beforeAll() val param = new GBDTParam - param.numClass = conf.getInt(ML_NUM_CLASS, DEFAULT_ML_NUM_CLASS) - param.numFeature = conf.get(ML_NUM_FEATURE).toInt - param.featSampleRatio = conf.getDouble(ML_FEATURE_SAMPLE_RATIO, DEFAULT_ML_FEATURE_SAMPLE_RATIO).toFloat - param.numWorker = conf.get(ML_NUM_WORKER).toInt - param.numThread = conf.getInt(ML_NUM_THREAD, DEFAULT_ML_NUM_THREAD) - param.lossFunc = conf.get(ML_LOSS_FUNCTION) - param.evalMetrics = conf.get(ML_EVAL_METRIC, DEFAULT_ML_EVAL_METRIC).split(",").map(_.trim).filter(_.nonEmpty) - param.learningRate = conf.getDouble(ML_LEARN_RATE, DEFAULT_ML_LEARN_RATE).toFloat - param.histSubtraction = conf.getBoolean(ML_GBDT_HIST_SUBTRACTION, DEFAULT_ML_GBDT_HIST_SUBTRACTION) - param.lighterChildFirst = conf.getBoolean(ML_GBDT_LIGHTER_CHILD_FIRST, DEFAULT_ML_GBDT_LIGHTER_CHILD_FIRST) - param.fullHessian = conf.getBoolean(ML_GBDT_FULL_HESSIAN, DEFAULT_ML_GBDT_FULL_HESSIAN) - param.numSplit = conf.getInt(ML_GBDT_SPLIT_NUM, DEFAULT_ML_GBDT_SPLIT_NUM) - param.numTree = conf.getInt(ML_GBDT_TREE_NUM, DEFAULT_ML_GBDT_TREE_NUM) - param.maxDepth = conf.getInt(ML_GBDT_MAX_DEPTH, DEFAULT_ML_GBDT_MAX_DEPTH) + param.taskType = "classification" + param.numClass = 2 + param.numFeature = 149 + param.featSampleRatio = 1.0f + param.numWorker = 1 + param.numThread = 1 + param.lossFunc = "binary:logistic" + param.evalMetrics = "error,auc".split(",").map(_.trim).filter(_.nonEmpty) + param.learningRate = 0.1f + param.histSubtraction = true + param.lighterChildFirst = true + param.fullHessian = false + param.numSplit = 10 + param.numTree = 20 + param.maxDepth = 4 val maxNodeNum = Maths.pow(2, param.maxDepth + 1) - 1 - param.maxNodeNum = conf.getInt(ML_GBDT_MAX_NODE_NUM, maxNodeNum) min maxNodeNum - param.minChildWeight = conf.getDouble(ML_GBDT_MIN_CHILD_WEIGHT, DEFAULT_ML_GBDT_MIN_CHILD_WEIGHT).toFloat - param.minNodeInstance = conf.getInt(ML_GBDT_MIN_NODE_INSTANCE, DEFAULT_ML_GBDT_MIN_NODE_INSTANCE) - param.minSplitGain = conf.getDouble(ML_GBDT_MIN_SPLIT_GAIN, DEFAULT_ML_GBDT_MIN_SPLIT_GAIN).toFloat - param.regAlpha = conf.getDouble(ML_GBDT_REG_ALPHA, DEFAULT_ML_GBDT_REG_ALPHA).toFloat - param.regLambda = conf.getDouble(ML_GBDT_REG_LAMBDA, DEFAULT_ML_GBDT_REG_LAMBDA).toFloat max 1.0f - param.maxLeafWeight = conf.getDouble(ML_GBDT_MAX_LEAF_WEIGHT, DEFAULT_ML_GBDT_MAX_LEAF_WEIGHT).toFloat + param.maxNodeNum = maxNodeNum + param.minChildWeight = 0.01f + param.minNodeInstance = 10 + param.minSplitGain = 0.0f + param.regAlpha = 0.0f + param.regLambda = 0.1f + param.maxLeafWeight = 0.0f println(s"Hyper-parameters:\n$param") trainer = new GBDTTrainer(param)
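Together with the trainer changes, the reworked `GBDTPredictor` reads its paths from `ArgsUtil`-parsed arguments (`angel.load.model.path`, `angel.predict.data.path`, `angel.predict.out.path`) and, for classification models, writes the predicted class alongside the raw scores; `probToClass` now returns 0-based class indices for multi-class models. A minimal sketch of driving the predictor programmatically, with placeholder paths and application name:

    import com.tencent.angel.spark.ml.tree.gbdt.predictor.GBDTPredictor
    import org.apache.spark.{SparkConf, SparkContext}

    object GBDTPredictSketch {

      def main(args: Array[String]): Unit = {
        @transient implicit val sc = new SparkContext(
          new SparkConf().setAppName("GBDT predictor sketch"))

        val predictor = new GBDTPredictor
        // corresponds to angel.load.model.path
        predictor.loadModel(sc, "hdfs://path/to/model")
        // corresponds to angel.predict.data.path and angel.predict.out.path;
        // classification output lines are "label predictedClass prob[,prob...]",
        // regression output lines are "label prediction"
        predictor.predict(sc, "hdfs://path/to/predict/data", "hdfs://path/to/output")

        sc.stop()
      }
    }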