Improve GBDT: regression, predict, docs #736

Merged 6 commits on Apr 23, 2019
49 changes: 41 additions & 8 deletions docs/algo/sona/feature_gbdt_sona.md
@@ -55,42 +55,75 @@ At the core of GBDT's training method is a data structure called the gradient histogram, which needs
### Parameters

* **Algorithm parameters**
  * ml.gbdt.task.type: task type, classification or regression
  * ml.gbdt.loss.func: loss function; supports binary classification (binary:logistic), multi-class classification (multi:logistic), and root mean square error (rmse)
  * ml.gbdt.eval.metric: evaluation metrics; supports rmse, error, log-loss, cross-entropy, precision, and auc
  * ml.num.class: number of classes; only used for classification tasks
  * ml.gbdt.feature.sample.ratio: feature sampling ratio (between 0 and 1)
  * ml.gbdt.tree.num: number of trees
  * ml.gbdt.tree.depth: maximum depth of each tree
  * ml.gbdt.split.num: number of candidate split points per feature
  * ml.learn.rate: learning rate
  * ml.gbdt.min.node.instance: minimum number of instances on a leaf node
  * ml.gbdt.min.split.gain: minimum gain required to split a node
  * ml.gbdt.reg.lambda: regularization coefficient
  * ml.gbdt.multi.class.strategy: strategy for multi-class tasks, one tree per round (one-tree) or multiple trees per round (multi-tree)
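For the last parameter above, the two strategies grow different numbers of trees: one-tree fits a single multi-class tree per boosting round, while multi-tree fits one tree per class per round. A minimal sketch of that bookkeeping (`GbdtTreeCount` and `totalTrees` are hypothetical names for illustration, not part of Angel):

```java
public class GbdtTreeCount {
    // one-tree: 1 tree per round; multi-tree: numClass trees per round.
    // Binary classification and regression always grow 1 tree per round.
    public static int totalTrees(int numRound, int numClass, String strategy) {
        if (numClass <= 2 || "one-tree".equals(strategy)) {
            return numRound;
        }
        return numRound * numClass;
    }
}
```

So with ml.gbdt.tree.num:20 and three classes, multi-tree produces 60 trees in total while one-tree produces 20.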

* **Input/output parameters**
  * angel.train.data.path: input path of the training data
  * angel.validate.data.path: input path of the validation data
  * angel.predict.data.path: input path of the prediction data
  * angel.predict.out.path: save path of the prediction results
  * angel.save.model.path: save path of the model after training completes
  * angel.load.model.path: load path of the model before prediction starts

### Example command for launching a training job

Submit the job with Spark:

```
./spark-submit \
    --master yarn-cluster \
    --conf spark.ps.jars=$SONA_ANGEL_JARS \
    --conf spark.ps.cores=1 \
    --conf spark.ps.memory=10g \
    --conf spark.ps.log.level=INFO \
    --queue $queue \
    --jars $SONA_SPARK_JARS \
    --name "GBDT on Spark-on-Angel" \
    --driver-memory 5g \
    --num-executors 10 \
    --executor-cores 1 \
    --executor-memory 10g \
    --class com.tencent.angel.spark.ml.tree.gbdt.trainer.GBDTTrainer \
    spark-on-angel-mllib-${ANGEL_VERSION}.jar \
    ml.gbdt.task.type:classification \
    angel.train.data.path:XXX angel.validate.data.path:XXX angel.save.model.path:XXX \
    ml.gbdt.loss.func:binary:logistic ml.gbdt.eval.metric:error,log-loss \
    ml.learn.rate:0.1 ml.gbdt.split.num:10 ml.gbdt.tree.num:20 ml.gbdt.tree.depth:7 ml.num.class:2 \
    ml.feature.index.range:47237 ml.gbdt.feature.sample.ratio:1.0 ml.gbdt.multi.class.strategy:one-tree ml.gbdt.min.node.instance:100
```
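The trainer receives its configuration as space-separated key:value pairs after the jar name. The sketch below shows one way such pairs could be split; `KvArgs` is a hypothetical illustration, not Angel's actual `ArgsUtil`. Note it splits on the first colon only, so values that themselves contain colons (e.g. binary:logistic) survive intact:

```java
import java.util.HashMap;
import java.util.Map;

public class KvArgs {
    // Split each "key:value" argument on the FIRST colon only.
    public static Map<String, String> parse(String[] args) {
        Map<String, String> conf = new HashMap<>();
        for (String arg : args) {
            int sep = arg.indexOf(':');
            if (sep > 0) {
                conf.put(arg.substring(0, sep), arg.substring(sep + 1));
            }
        }
        return conf;
    }
}
```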

### Example command for launching a prediction job

Submit the job with Spark:

```
./spark-submit \
    --master yarn-cluster \
    --conf spark.ps.jars=$SONA_ANGEL_JARS \
    --conf spark.ps.cores=1 \
    --conf spark.ps.memory=10g \
    --conf spark.ps.log.level=INFO \
    --queue $queue \
    --jars $SONA_SPARK_JARS \
    --name "GBDT on Spark-on-Angel" \
    --driver-memory 5g \
    --num-executors 10 \
    --executor-cores 1 \
    --executor-memory 10g \
    --class com.tencent.angel.spark.ml.tree.gbdt.predictor.GBDTPredictor \
    spark-on-angel-mllib-${ANGEL_VERSION}.jar \
    angel.load.model.path:XXX angel.predict.data.path:XXX angel.predict.out.path:XXX
```

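For a model trained with binary:logistic loss, the raw prediction is a margin (the sum of leaf scores across all trees), and the probability of the positive class is the sigmoid of that margin. That post-processing step can be sketched as follows (`LogisticPredict` is a hypothetical helper; the predictor's actual output format may differ):

```java
public class LogisticPredict {
    // Convert a summed leaf-score margin into a class probability.
    public static double sigmoid(double margin) {
        return 1.0 / (1.0 + Math.exp(-margin));
    }

    // Threshold the probability at 0.5 to obtain a hard label.
    public static int label(double margin) {
        return sigmoid(margin) >= 0.5 ? 1 : 0;
    }
}
```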


## 5. Performance

@@ -25,6 +25,10 @@ public class TreeConf {
  public static final String ML_TRAIN_PATH = "spark.ml.train.path";
  public static final String ML_VALID_PATH = "spark.ml.valid.path";
  public static final String ML_PREDICT_PATH = "spark.ml.predict.path";
  public static final String ML_OUTPUT_PATH = "spark.ml.output.path";

  public static final String ML_GBDT_TASK_TYPE = "ml.gbdt.task.type";
  public static final String DEFAULT_ML_GBDT_TASK_TYPE = "classification";
  public static final String ML_VALID_DATA_RATIO = "spark.ml.valid.ratio";
  public static final double DEFAULT_ML_VALID_DATA_RATIO = 0.25;
  public static final String ML_NUM_CLASS = "spark.ml.class.num";
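The key/DEFAULT_* constant pairs above suggest a lookup-with-fallback pattern when the configuration is read. A hedged sketch of how such a default might be consumed (`ConfLookup` and `taskType` are illustrative names, not the real TreeConf accessors):

```java
import java.util.Map;

public class ConfLookup {
    public static final String ML_GBDT_TASK_TYPE = "ml.gbdt.task.type";
    public static final String DEFAULT_ML_GBDT_TASK_TYPE = "classification";

    // Fall back to the compiled-in default when the key is unset.
    public static String taskType(Map<String, String> conf) {
        return conf.getOrDefault(ML_GBDT_TASK_TYPE, DEFAULT_ML_GBDT_TASK_TYPE);
    }
}
```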
@@ -0,0 +1,102 @@
/*
 * Tencent is pleased to support the open source community by making Angel available.
 *
 * Copyright (C) 2017-2018 THL A29 Limited, a Tencent company. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in
 * compliance with the License. You may obtain a copy of the License at
 *
 * https://opensource.org/licenses/Apache-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 *
 */

package com.tencent.angel.spark.ml.tree.gbdt.examples

import com.tencent.angel.spark.ml.core.ArgsUtil
import com.tencent.angel.spark.ml.tree.gbdt.trainer.GBDTTrainer
import com.tencent.angel.spark.ml.tree.param.GBDTParam
import com.tencent.angel.spark.ml.tree.util.Maths
import org.apache.spark.{SparkConf, SparkContext}

object GBDTRegression {

  def main(args: Array[String]): Unit = {

    @transient val conf = new SparkConf().setMaster("local").setAppName("gbdt")

    val param = new GBDTParam

    // spark conf
    val numExecutor = 1
    val numCores = 1
    param.numWorker = numExecutor
    param.numThread = numCores
    conf.set("spark.task.cpus", numCores.toString)
    conf.set("spark.locality.wait", "0")
    conf.set("spark.memory.fraction", "0.7")
    conf.set("spark.memory.storageFraction", "0.8")
    conf.set("spark.task.maxFailures", "1")
    conf.set("spark.yarn.maxAppAttempts", "1")
    conf.set("spark.network.timeout", "1000")
    conf.set("spark.executor.heartbeatInterval", "500")

    val params = ArgsUtil.parse(args)

    //val trainPath = "data/dna/dna.scale" //dimension=181
    //val validPath = "data/dna/dna.scale.t"
    val trainPath = "data/abalone/abalone_8d_train.libsvm" //dimension=8
    val validPath = "data/abalone/abalone_8d_train.libsvm"
    val modelPath = "tmp/gbdt/abalone"

    // dataset conf
    param.taskType = "regression"
    param.numClass = 2
    param.numFeature = 8

    // loss and metric
    param.lossFunc = "rmse"
    param.evalMetrics = Array("rmse")
    param.multiGradCache = false

    // major algo conf
    param.featSampleRatio = 1.0f
    param.learningRate = 0.1f
    param.numSplit = 10
    param.numTree = 10
    param.maxDepth = 7
    val maxNodeNum = Maths.pow(2, param.maxDepth + 1) - 1
    param.maxNodeNum = 4096 min maxNodeNum

    // less important algo conf
    param.histSubtraction = true
    param.lighterChildFirst = true
    param.fullHessian = false
    param.minChildWeight = 0.0f
    param.minNodeInstance = 10
    param.minSplitGain = 0.0f
    param.regAlpha = 0.0f
    param.regLambda = 1.0f
    param.maxLeafWeight = 0.0f

    println(s"Hyper-parameters:\n$param")

    @transient implicit val sc = new SparkContext(conf)

    try {
      val trainer = new GBDTTrainer(param)
      trainer.initialize(trainPath, validPath)
      val model = trainer.train()
      trainer.save(model, modelPath)
    } catch {
      case e: Exception =>
        e.printStackTrace()
    }
  }

}
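The example above caps the node budget with maxNodeNum = min(4096, 2^(maxDepth + 1) - 1), the node count of a full binary tree whose root sits at depth 0. Restated as a small standalone check (`NodeCount` is a hypothetical helper name):

```java
public class NodeCount {
    // A full binary tree of depth d (root at depth 0) has 2^(d+1) - 1 nodes.
    public static int maxNodeNum(int maxDepth, int cap) {
        int full = (1 << (maxDepth + 1)) - 1;
        return Math.min(cap, full);
    }
}
```

With maxDepth = 7 this yields min(4096, 255) = 255, so the 4096 cap only bites for deeper trees.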
@@ -36,6 +36,7 @@ public Histogram(int numBin, int numClass, boolean fullHessian, boolean multiClassMultiTree) {
    this.numClass = numClass;
    this.fullHessian = fullHessian;
    this.multiClassMultiTree = multiClassMultiTree;

    if (numClass == 2 || multiClassMultiTree) {
      this.gradients = new double[numBin];
      this.hessians = new double[numBin];
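The branch shown above allocates one gradient and one hessian slot per bin when the task is binary or uses the multi-tree strategy; in the one-tree multi-class case the histogram presumably needs numClass gradient slots per bin instead. A sketch of that sizing rule (`HistogramSize` is hypothetical, and the multi-class layout is inferred from the visible branch rather than taken from Angel's source):

```java
public class HistogramSize {
    // Length of the per-node gradient array: one slot per bin in the
    // binary / multi-tree case, numClass slots per bin otherwise.
    public static int gradientLength(int numBin, int numClass, boolean multiClassMultiTree) {
        if (numClass == 2 || multiClassMultiTree) {
            return numBin;
        }
        return numBin * numClass;
    }
}
```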