Skip to content

Commit

Permalink
fix: lightgbm default params should not be specified if optional (#1232)
Browse files Browse the repository at this point in the history
  • Loading branch information
imatiach-msft authored Oct 28, 2021
1 parent 3d92dd7 commit 336eff5
Show file tree
Hide file tree
Showing 5 changed files with 189 additions and 97 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ class LightGBMClassifier(override val uid: String)
val actualNumClasses = getNumClasses(dataset)
val categoricalIndexes = getCategoricalIndexes(dataset.schema(getFeaturesCol))
val modelStr = if (getModelString == null || getModelString.isEmpty) None else get(modelString)
ClassifierTrainParams(getParallelism, getTopK, getNumIterations, getLearningRate, getNumLeaves, getMaxBin,
getBinSampleCount, getBaggingFraction, getPosBaggingFraction, getNegBaggingFraction,
getBaggingFreq, getBaggingSeed, getEarlyStoppingRound, getImprovementTolerance,
getFeatureFraction, getMaxDepth, getMinSumHessianInLeaf, numTasks, modelStr,
ClassifierTrainParams(getParallelism, get(topK), getNumIterations, getLearningRate,
get(numLeaves), get(maxBin), get(binSampleCount), get(baggingFraction), get(posBaggingFraction),
get(negBaggingFraction), get(baggingFreq), get(baggingSeed), getEarlyStoppingRound, getImprovementTolerance,
get(featureFraction), get(maxDepth), get(minSumHessianInLeaf), numTasks, modelStr,
getIsUnbalance, getVerbosity, categoricalIndexes, actualNumClasses, getBoostFromAverage,
getBoostingType, getLambdaL1, getLambdaL2, getIsProvideTrainingMetric,
getMetric, getMinGainToSplit, getMaxDeltaStep, getMaxBinByFeature, getMinDataInLeaf, getSlotNames,
getBoostingType, get(lambdaL1), get(lambdaL2), get(isProvideTrainingMetric),
get(metric), get(minGainToSplit), get(maxDeltaStep), getMaxBinByFeature, get(minDataInLeaf), getSlotNames,
getDelegate, getDartParams, getExecutionParams, getObjectiveParams)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,13 @@ class LightGBMRanker(override val uid: String)
/**
  * Assembles the LightGBM training parameters for ranking from this estimator's Params.
  *
  * Optional params are passed through `get(param)` (an `Option`) rather than the
  * `getX` accessors so that unset params are NOT sent to LightGBM and its native
  * defaults apply — forwarding accessor defaults would override LightGBM's own
  * defaults (see commit message: "default params should not be specified if optional").
  *
  * @param numTasks        total number of LightGBM tasks across the cluster
  * @param dataset         input dataset; only its schema is consulted here, to
  *                        resolve categorical feature indexes from the features column
  * @param numTasksPerExec tasks per executor (unused in this assembly step)
  * @return the populated RankerTrainParams
  */
def getTrainParams(numTasks: Int, dataset: Dataset[_], numTasksPerExec: Int): TrainParams = {
  val categoricalIndexes = getCategoricalIndexes(dataset.schema(getFeaturesCol))
  // Treat a null or empty model string as absent; otherwise pass the Option through.
  val modelStr = if (getModelString == null || getModelString.isEmpty) None else get(modelString)
  RankerTrainParams(getParallelism, get(topK), getNumIterations, getLearningRate,
    get(numLeaves), get(maxBin), get(binSampleCount), get(baggingFraction), get(posBaggingFraction),
    get(negBaggingFraction), get(baggingFreq), get(baggingSeed), getEarlyStoppingRound, getImprovementTolerance,
    get(featureFraction), get(maxDepth), get(minSumHessianInLeaf), numTasks, modelStr,
    getVerbosity, categoricalIndexes, getBoostingType, get(lambdaL1), get(lambdaL2), getMaxPosition, getLabelGain,
    get(isProvideTrainingMetric), get(metric), getEvalAt, get(minGainToSplit), get(maxDeltaStep),
    getMaxBinByFeature, get(minDataInLeaf), getSlotNames, getDelegate, getDartParams,
    getExecutionParams, getObjectiveParams)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,15 @@ class LightGBMRegressor(override val uid: String)
/**
  * Assembles the LightGBM training parameters for regression from this estimator's Params.
  *
  * Optional params are passed through `get(param)` (an `Option`) rather than the
  * `getX` accessors so that unset params are NOT sent to LightGBM and its native
  * defaults apply — forwarding accessor defaults would override LightGBM's own
  * defaults (see commit message: "default params should not be specified if optional").
  *
  * @param numTasks        total number of LightGBM tasks across the cluster
  * @param dataset         input dataset; only its schema is consulted here, to
  *                        resolve categorical feature indexes from the features column
  * @param numTasksPerExec tasks per executor (unused in this assembly step)
  * @return the populated RegressorTrainParams
  */
def getTrainParams(numTasks: Int, dataset: Dataset[_], numTasksPerExec: Int): TrainParams = {
  val categoricalIndexes = getCategoricalIndexes(dataset.schema(getFeaturesCol))
  // Treat a null or empty model string as absent; otherwise pass the Option through.
  val modelStr = if (getModelString == null || getModelString.isEmpty) None else get(modelString)
  RegressorTrainParams(getParallelism, get(topK), getNumIterations, getLearningRate,
    get(numLeaves), getAlpha, getTweedieVariancePower,
    get(maxBin), get(binSampleCount), get(baggingFraction), get(posBaggingFraction),
    get(negBaggingFraction), get(baggingFreq), get(baggingSeed), getEarlyStoppingRound, getImprovementTolerance,
    get(featureFraction), get(maxDepth), get(minSumHessianInLeaf),
    numTasks, modelStr, getVerbosity, categoricalIndexes,
    getBoostFromAverage, getBoostingType, get(lambdaL1), get(lambdaL2), get(isProvideTrainingMetric),
    get(metric), get(minGainToSplit), get(maxDeltaStep),
    getMaxBinByFeature, get(minDataInLeaf), getSlotNames, getDelegate,
    getDartParams, getExecutionParams, getObjectiveParams)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,13 +114,14 @@ private object TrainUtils extends Serializable {

isFinished = updateOneIteration(trainParams, booster, log, iters)

val trainEvalResults: Option[Map[String, Double]] = if (trainParams.isProvideTrainingMetric && !isFinished) {
val evalResults: Array[(String, Double)] = booster.getEvalResults(evalNames, 0)
evalResults.foreach { case (evalName: String, score: Double) => log.info(s"Train $evalName=$score") }
Option(Map(evalResults:_*))
} else {
None
}
val trainEvalResults: Option[Map[String, Double]] =
if (trainParams.isProvideTrainingMetric.getOrElse(false) && !isFinished) {
val evalResults: Array[(String, Double)] = booster.getEvalResults(evalNames, 0)
evalResults.foreach { case (evalName: String, score: Double) => log.info(s"Train $evalName=$score") }
Option(Map(evalResults:_*))
} else {
None
}

val validEvalResults: Option[Map[String, Double]] = if (hasValid && !isFinished) {
val evalResults: Array[(String, Double)] = booster.getEvalResults(evalNames, 1)
Expand Down
Loading

0 comments on commit 336eff5

Please sign in to comment.