Skip to content

Commit

Permalink
add seed parameters to lightgbm
Browse files Browse the repository at this point in the history
  • Loading branch information
imatiach-msft committed Feb 8, 2022
1 parent e78f582 commit 427a2ee
Show file tree
Hide file tree
Showing 6 changed files with 135 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,11 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
ExecutionParams(getChunkSize, getMatrixType, execNumThreads, getUseSingleDatasetMode)
}

/**
  * Bundles the column-related settings of this estimator into a ColumnParams value.
  *
  * The label, features and group columns are read through their getters, while the
  * weight and initial-score columns are read with `get(...)` and forwarded as returned.
  *
  * @return ColumnParams object containing the parameters related to LightGBM columns.
  */
protected def getColumnParams: ColumnParams = {
  val labelColumn = getLabelCol
  val featuresColumn = getFeaturesCol
  val weightColumn = get(weightCol)
  val initScoreColumn = get(initScoreCol)
  val groupColumn = getOptGroupCol
  ColumnParams(labelColumn, featuresColumn, weightColumn, initScoreColumn, groupColumn)
}
Expand All @@ -268,6 +273,16 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
ObjectiveParams(getObjective, if (isDefined(fobj)) Some(getFObj) else None)
}

/**
  * Constructs the SeedParams.
  *
  * Each seed-related param is read with `get(...)` and forwarded as returned. The
  * boosting type and objective are forwarded as plain strings so that SeedParams can
  * decide whether the dart-only drop seed and the rank_xendcg-only objective seed
  * should be emitted.
  *
  * @return SeedParams object containing the parameters related to LightGBM seeds and determinism.
  */
protected def getSeedParams: SeedParams = {
  // SeedParams takes the boosting type and the objective as its last two String arguments;
  // use the getters so the resolved values (not the Param definitions) are passed.
  SeedParams(get(seed), get(deterministic), get(baggingSeed), get(featureFractionSeed),
    get(extraSeed), get(dropSeed), get(dataRandomSeed), get(objectiveSeed),
    getBoostingType, getObjective)
}

def getDatasetParams(categoricalIndexes: Array[Int], numThreads: Int): String = {
val datasetParams = s"max_bin=$getMaxBin is_pre_partition=True " +
s"bin_construct_sample_cnt=$getBinSampleCount " +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class LightGBMClassifier(override val uid: String)
getIsUnbalance, getVerbosity, categoricalIndexes, actualNumClasses, getBoostFromAverage,
getBoostingType, get(lambdaL1), get(lambdaL2), get(isProvideTrainingMetric),
get(metric), get(minGainToSplit), get(maxDeltaStep), getMaxBinByFeature, get(minDataInLeaf), getSlotNames,
getDelegate, getDartParams, getExecutionParams(numTasksPerExec), getObjectiveParams)
getDelegate, getDartParams, getExecutionParams(numTasksPerExec), getObjectiveParams, getSeedParams)
}

def getModel(trainParams: TrainParams, lightGBMBooster: LightGBMBooster): LightGBMClassificationModel = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class LightGBMRanker(override val uid: String)
getVerbosity, categoricalIndexes, getBoostingType, get(lambdaL1), get(lambdaL2), getMaxPosition, getLabelGain,
get(isProvideTrainingMetric), get(metric), getEvalAt, get(minGainToSplit), get(maxDeltaStep),
getMaxBinByFeature, get(minDataInLeaf), getSlotNames, getDelegate, getDartParams,
getExecutionParams(numTasksPerExec), getObjectiveParams)
getExecutionParams(numTasksPerExec), getObjectiveParams, getSeedParams)
}

def getModel(trainParams: TrainParams, lightGBMBooster: LightGBMBooster): LightGBMRankerModel = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class LightGBMRegressor(override val uid: String)
getBoostFromAverage, getBoostingType, get(lambdaL1), get(lambdaL2), get(isProvideTrainingMetric),
get(metric), get(minGainToSplit), get(maxDeltaStep),
getMaxBinByFeature, get(minDataInLeaf), getSlotNames, getDelegate,
getDartParams, getExecutionParams(numTasksPerExec), getObjectiveParams)
getDartParams, getExecutionParams(numTasksPerExec), getObjectiveParams, getSeedParams)
}

def getModel(trainParams: TrainParams, lightGBMBooster: LightGBMBooster): LightGBMRegressionModel = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -317,12 +317,71 @@ trait LightGBMObjectiveParams extends Wrappable {
/** Stores the given value in the `fobj` param.
  *
  * @param value the FObjTrait instance to store.
  * @return the result of `set`, typed as this instance so calls can be chained.
  */
def setFObj(value: FObjTrait): this.type = {
  set(fobj, value)
}
}

/** Defines common parameters related to seed and determinism.
  *
  * Every param except `seed` is given a default via `setDefault` below.
  */
trait LightGBMSeedParams extends Wrappable {
  // Main seed. No default is registered here.
  // NOTE(review): getSeed presumably fails when the param was never set - confirm against Spark's `$`.
  val seed = new IntParam(this, "seed", "Main seed, used to generate other seeds")

  def getSeed: Int = $(seed)
  def setSeed(value: Int): this.type = set(seed, value)

  // Determinism switch; defaults to false.
  val deterministic = new BooleanParam(this, "deterministic", "Used only with cpu " +
    "device type. Setting this to true should ensure stable results when using the same data and the " +
    "same parameters. Note: setting this to true may slow down training. To avoid potential instability " +
    "due to numerical issues, please set force_col_wise=true or force_row_wise=true when setting " +
    "deterministic=true")
  setDefault(deterministic->false)

  def getDeterministic: Boolean = $(deterministic)
  def setDeterministic(value: Boolean): this.type = set(deterministic, value)

  // Seed for bagging; defaults to 3.
  val baggingSeed = new IntParam(this, "baggingSeed", "Bagging seed")
  setDefault(baggingSeed->3)

  def getBaggingSeed: Int = $(baggingSeed)
  def setBaggingSeed(value: Int): this.type = set(baggingSeed, value)

  // Seed for feature fraction; defaults to 2.
  val featureFractionSeed = new IntParam(this, "featureFractionSeed", "Feature fraction seed")
  setDefault(featureFractionSeed->2)

  def getFeatureFractionSeed: Int = $(featureFractionSeed)
  def setFeatureFractionSeed(value: Int): this.type = set(featureFractionSeed, value)

  // Seed used with extra_trees; defaults to 6.
  val extraSeed = new IntParam(this, "extraSeed", "Random seed for selecting threshold " +
    "when extra_trees is true")
  setDefault(extraSeed->6)

  def getExtraSeed: Int = $(extraSeed)
  def setExtraSeed(value: Int): this.type = set(extraSeed, value)

  // Seed for dart model dropping; defaults to 4.
  val dropSeed = new IntParam(this, "dropSeed", "Random seed to choose dropping models. " +
    "Only used in dart.")
  setDefault(dropSeed->4)

  def getDropSeed: Int = $(dropSeed)
  def setDropSeed(value: Int): this.type = set(dropSeed, value)

  // Seed for sampling data when building histogram bins; defaults to 1.
  val dataRandomSeed = new IntParam(this, "dataRandomSeed", "Random seed for sampling " +
    "data to construct histogram bins.")
  setDefault(dataRandomSeed->1)

  def getDataRandomSeed: Int = $(dataRandomSeed)
  def setDataRandomSeed(value: Int): this.type = set(dataRandomSeed, value)

  // Seed for objectives that need randomness; defaults to 5.
  val objectiveSeed = new IntParam(this, "objectiveSeed", "Random seed for objectives, " +
    "if random process is needed. Currently used only for rank_xendcg objective.")
  setDefault(objectiveSeed->5)

  def getObjectiveSeed: Int = $(objectiveSeed)
  def setObjectiveSeed(value: Int): this.type = set(objectiveSeed, value)
}

/** Defines common parameters across all LightGBM learners.
*/
trait LightGBMParams extends Wrappable with DefaultParamsWritable with HasWeightCol
with HasValidationIndicatorCol with HasInitScoreCol with LightGBMExecutionParams
with LightGBMSlotParams with LightGBMFractionParams with LightGBMBinParams with LightGBMLearnerParams
with LightGBMDartParams with LightGBMPredictionParams with LightGBMObjectiveParams {
with LightGBMDartParams with LightGBMPredictionParams with LightGBMObjectiveParams with LightGBMSeedParams {
val numIterations = new IntParam(this, "numIterations",
"Number of iterations, LightGBM constructs num_class * num_iterations trees")
setDefault(numIterations->100)
Expand All @@ -348,12 +407,6 @@ trait LightGBMParams extends Wrappable with DefaultParamsWritable with HasWeight
/** Reads the `baggingFreq` param via `$`.
  *
  * @return the value of `baggingFreq` as an Int.
  */
def getBaggingFreq: Int = {
  $(baggingFreq)
}
/** Stores the given value in the `baggingFreq` param.
  *
  * @param value the bagging frequency to store.
  * @return the result of `set`, typed as this instance so calls can be chained.
  */
def setBaggingFreq(value: Int): this.type = {
  set(baggingFreq, value)
}

val baggingSeed = new IntParam(this, "baggingSeed", "Bagging seed")
setDefault(baggingSeed->3)

def getBaggingSeed: Int = $(baggingSeed)
def setBaggingSeed(value: Int): this.type = set(baggingSeed, value)

val maxDepth = new IntParam(this, "maxDepth", "Max depth")
setDefault(maxDepth-> -1)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,22 @@ package com.microsoft.azure.synapse.ml.lightgbm.params

import com.microsoft.azure.synapse.ml.lightgbm.{LightGBMConstants, LightGBMDelegate}

/** Helpers that render optional parameter values as the `name=value` text passed to LightGBM. */
object ParamUtils {
  /** Renders a single parameter.
    *
    * @param paramName     Name to put on the left of the `=` sign.
    * @param paramValueOpt Optional value; its string form goes on the right of the `=` sign.
    * @return `name=value` when a value is present, otherwise the empty string.
    */
  def paramToString[T](paramName: String, paramValueOpt: Option[T]): String =
    paramValueOpt.fold("")(value => s"$paramName=$value")

  /** Renders several parameters, in order, joined by a single space.
    *
    * Entries without a value still occupy a (blank) slot, so the result can contain
    * consecutive, leading or trailing spaces.
    *
    * @param paramNamesToValues Pairs of parameter name and optional value.
    * @return The space-joined rendering of every pair.
    */
  def paramsToString(paramNamesToValues: Array[(String, Option[_])]): String =
    paramNamesToValues
      .iterator
      .map { case (name, valueOpt) => paramToString(name, valueOpt) }
      .mkString(" ")
}

/** Defines the common Booster parameters passed to the LightGBM learners.
*/
abstract class TrainParams extends Serializable {
Expand Down Expand Up @@ -43,25 +59,13 @@ abstract class TrainParams extends Serializable {
def dartModeParams: DartModeParams
def executionParams: ExecutionParams
def objectiveParams: ObjectiveParams

def paramToString[T](paramName: String, paramValueOpt: Option[T]): String = {
paramValueOpt match {
case Some(paramValue) => s"$paramName=$paramValue"
case None => ""
}
}

def paramsToString(paramNamesToValues: Array[(String, Option[_])]): String = {
paramNamesToValues.map {
case (paramName: String, paramValue: Option[_]) => paramToString(paramName, paramValue)
}.mkString(" ")
}
def seedParams: SeedParams

override def toString: String = {
// Since passing `isProvideTrainingMetric` to LightGBM as a config parameter won't work,
// let's fetch and print training metrics in `TrainUtils.scala` through JNI.
s"is_pre_partition=True boosting_type=$boostingType tree_learner=$parallelism " +
paramsToString(Array(("top_k", topK), ("num_leaves", numLeaves), ("max_bin", maxBin),
ParamUtils.paramsToString(Array(("top_k", topK), ("num_leaves", numLeaves), ("max_bin", maxBin),
("bagging_fraction", baggingFraction), ("pos_bagging_fraction", posBaggingFraction),
("neg_bagging_fraction", negBaggingFraction), ("bagging_freq", baggingFreq),
("bagging_seed", baggingSeed), ("feature_fraction", featureFraction), ("max_depth", maxDepth),
Expand All @@ -75,7 +79,8 @@ abstract class TrainParams extends Serializable {
(if (categoricalFeatures.isEmpty) "" else s"categorical_feature=${categoricalFeatures.mkString(",")} ") +
(if (maxBinByFeature.isEmpty) "" else s"max_bin_by_feature=${maxBinByFeature.mkString(",")} ") +
(if (boostingType == "dart") s"${dartModeParams.toString()} " else "") +
executionParams.toString()
executionParams.toString() +
seedParams.toString()
}
}

Expand Down Expand Up @@ -118,7 +123,8 @@ case class ClassifierTrainParams(parallelism: String,
delegate: Option[LightGBMDelegate],
dartModeParams: DartModeParams,
executionParams: ExecutionParams,
objectiveParams: ObjectiveParams)
objectiveParams: ObjectiveParams,
seedParams: SeedParams)
extends TrainParams {
override def toString: String = {
val extraStr =
Expand Down Expand Up @@ -167,7 +173,8 @@ case class RegressorTrainParams(parallelism: String,
delegate: Option[LightGBMDelegate],
dartModeParams: DartModeParams,
executionParams: ExecutionParams,
objectiveParams: ObjectiveParams)
objectiveParams: ObjectiveParams,
seedParams: SeedParams)
extends TrainParams {
override def toString: String = {
s"alpha=$alpha tweedie_variance_power=$tweedieVariancePower boost_from_average=${boostFromAverage.toString} " +
Expand Down Expand Up @@ -214,7 +221,8 @@ case class RankerTrainParams(parallelism: String,
delegate: Option[LightGBMDelegate],
dartModeParams: DartModeParams,
executionParams: ExecutionParams,
objectiveParams: ObjectiveParams)
objectiveParams: ObjectiveParams,
seedParams: SeedParams)
extends TrainParams {
override def toString: String = {
val labelGainStr =
Expand Down Expand Up @@ -266,3 +274,34 @@ case class ObjectiveParams(objective: String, fobj: Option[FObjTrait]) extends S
}
}
}

/** Defines parameters related to seed and determinism for lightgbm.
  *
  * `toString` renders the defined values as space-separated `name=value` pairs using
  * `ParamUtils.paramsToString`. `drop_seed` is only rendered when `boostingType` is
  * "dart", and `objective_seed` only when `objective` is "rank_xendcg".
  *
  * @param seed Main seed, used to generate other seeds.
  * @param deterministic Setting this to true should ensure stable results when using the
  *                      same data and the same parameters.
  * @param baggingSeed Bagging seed.
  * @param featureFractionSeed Feature fraction seed.
  * @param extraSeed Random seed for selecting threshold when extra_trees is true.
  * @param dropSeed Random seed to choose dropping models. Only used in dart.
  * @param dataRandomSeed Random seed for sampling data to construct histogram bins.
  * @param objectiveSeed Random seed for objectives, if random process is needed.
  *                      Currently used only for rank_xendcg objective.
  * @param boostingType Boosting type, used to determine if drop seed should be set.
  * @param objective Objective, used to determine if objective seed should be set.
  */
case class SeedParams(seed: Option[Int], deterministic: Option[Boolean],
                      baggingSeed: Option[Int], featureFractionSeed: Option[Int],
                      extraSeed: Option[Int], dropSeed: Option[Int],
                      dataRandomSeed: Option[Int], objectiveSeed: Option[Int],
                      boostingType: String, objective: String) extends Serializable {
  override def toString: String = {
    // Gate the conditional seeds as Options and let paramsToString unwrap them, so the
    // rendered text is `drop_seed=4` (not `drop_seed=Some(4)`) and every pair is
    // separated from its neighbour by a space.
    val dropSeedOpt: Option[Int] = if (boostingType == "dart") dropSeed else None
    val objectiveSeedOpt: Option[Int] = if (objective == "rank_xendcg") objectiveSeed else None
    ParamUtils.paramsToString(Array(("seed", seed), ("deterministic", deterministic),
      ("bagging_seed", baggingSeed), ("feature_fraction_seed", featureFractionSeed),
      ("extra_seed", extraSeed), ("data_random_seed", dataRandomSeed),
      ("drop_seed", dropSeedOpt), ("objective_seed", objectiveSeedOpt)))
  }
}

0 comments on commit 427a2ee

Please sign in to comment.