Skip to content

Commit

Permalink
feat: add custom objective function to lightgbm learners
Browse files Browse the repository at this point in the history
  • Loading branch information
imatiach-msft committed May 24, 2021
1 parent d8bb51f commit 99bdb64
Show file tree
Hide file tree
Showing 20 changed files with 582 additions and 223 deletions.
2 changes: 1 addition & 1 deletion scalastyle-config.xml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
<parameter name="header">^// Copyright \(C\) Microsoft Corporation\. All rights reserved\.
// Licensed under the MIT License\. See LICENSE in project root for information\.

package (?:com\.microsoft\.ml\.spark|org\.apache\.spark|com\.microsoft\.CNTK|com\.microsoft\.ml\.lightgbm)[.
package (?:com\.microsoft\.ml\.spark|org\.apache\.spark|com\.microsoft\.CNTK|com\.microsoft\.ml\.lightgbm|com\.microsoft\.lightgbm)[.
]</parameter>
<parameter name="regex">true</parameter></parameters></check>
<check level="error" class="org.scalastyle.file.WhitespaceEndOfLineChecker" enabled="true"></check>
Expand Down
4 changes: 2 additions & 2 deletions scalastyle-test-config.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
<check level="error" class="org.scalastyle.file.FileLineLengthChecker" enabled="true"><parameters>
<parameter name="maxLineLength">120</parameter></parameters></check>
<check level="error" class="org.scalastyle.file.NewLineAtEofChecker" enabled="true"></check>
<check level="warning" class="org.scalastyle.scalariform.TokenChecker" enabled="true"><parameters>
<check level="warning" class="org.scalastyle.scalariform.TokenChecker" enabled="true"><parameters>
<parameter name="regex">.{33}</parameter></parameters></check>
<check level="error" class="org.scalastyle.file.HeaderMatchesChecker" enabled="true"><parameters>
<parameter name="header">^// Copyright \(C\) Microsoft Corporation\. All rights reserved\.
// Licensed under the MIT License\. See LICENSE in project root for information\.

package (?:com\.microsoft\.ml\.spark|org\.apache\.spark|com\.microsoft\.CNTK|com\.microsoft\.ml\.lightgbm)[.
package (?:com\.microsoft\.ml\.spark|org\.apache\.spark|com\.microsoft\.CNTK|com\.microsoft\.ml\.lightgbm|com\.microsoft\.lightgbm)[.
]</parameter>
<parameter name="regex">true</parameter></parameters></check>
<check level="error" class="org.scalastyle.file.WhitespaceEndOfLineChecker" enabled="true"></check>
Expand Down
13 changes: 13 additions & 0 deletions src/main/scala/com/microsoft/lightgbm/SWIG.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// Copyright (C) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License. See LICENSE in project root for information.

package com.microsoft.lightgbm

import com.microsoft.ml.lightgbm.SWIGTYPE_p_void

/** Wraps a SWIG void pointer object and exposes its underlying pointer address.
  * @param value The SWIG void pointer object being wrapped.
  */
class SwigPtrWrapper(val value: SWIGTYPE_p_void) extends SWIGTYPE_p_void {
  /** Looks up the pointer address stored in the wrapped SWIG pointer object.
    * @return The value returned by SWIGTYPE_p_void.getCPtr for the wrapped pointer, as a long.
    */
  def getCPtrValue(): Long = {
    val address: Long = SWIGTYPE_p_void.getCPtr(value)
    address
  }
}
19 changes: 19 additions & 0 deletions src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMBase.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
package com.microsoft.ml.spark.lightgbm

import com.microsoft.ml.spark.core.utils.ClusterUtil
import com.microsoft.ml.spark.lightgbm.booster.LightGBMBooster
import com.microsoft.ml.spark.lightgbm.params.{DartModeParams, ExecutionParams, LightGBMParams,
ObjectiveParams, TrainParams}
import com.microsoft.ml.spark.logging.BasicLogging
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
Expand Down Expand Up @@ -180,14 +183,30 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
}
}

/**
* Constructs the DartModeParams.
* @return DartModeParams object containing parameters related to dart mode.
*/
protected def getDartParams(): DartModeParams = {
  // Read each getter in the same left-to-right order as the constructor arguments,
  // then bundle the values into a single DartModeParams object.
  val dropRateValue = getDropRate
  val maxDropValue = getMaxDrop
  val skipDropValue = getSkipDrop
  val xgboostDartModeValue = getXGBoostDartMode
  val uniformDropValue = getUniformDrop
  DartModeParams(dropRateValue, maxDropValue, skipDropValue, xgboostDartModeValue, uniformDropValue)
}

/**
* Constructs the ExecutionParams.
* @return ExecutionParams object containing parameters related to LightGBM execution.
*/
protected def getExecutionParams(): ExecutionParams = {
  // Capture each setting in argument order before building the parameter object.
  val chunkSizeValue = getChunkSize
  val matrixTypeValue = getMatrixType
  ExecutionParams(chunkSizeValue, matrixTypeValue)
}

/**
* Constructs the ObjectiveParams.
* @return ObjectiveParams object containing parameters related to the objective function.
*/
protected def getObjectiveParams(): ObjectiveParams = {
  // The objective is read first, matching the original argument evaluation order.
  val objectiveName = getObjective
  // The custom objective function is only read when the fobj param is defined; otherwise None is passed.
  val customObjective = if (isDefined(fobj)) Some(getFObj) else None
  ObjectiveParams(objectiveName, customObjective)
}

/**
* Inner train method for LightGBM learners. Calculates the number of workers,
* creates a driver thread, and runs mapPartitions on the dataset.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

package com.microsoft.ml.spark.lightgbm

import com.microsoft.ml.spark.lightgbm.booster.LightGBMBooster
import com.microsoft.ml.spark.lightgbm.params.{ClassifierTrainParams, LightGBMModelParams,
LightGBMPredictionParams, TrainParams}
import com.microsoft.ml.spark.logging.BasicLogging
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable}
import org.apache.spark.ml.param._
Expand Down Expand Up @@ -46,11 +49,11 @@ class LightGBMClassifier(override val uid: String)
ClassifierTrainParams(getParallelism, getTopK, getNumIterations, getLearningRate, getNumLeaves, getMaxBin,
getBinSampleCount, getBaggingFraction, getPosBaggingFraction, getNegBaggingFraction,
getBaggingFreq, getBaggingSeed, getEarlyStoppingRound, getImprovementTolerance,
getFeatureFraction, getMaxDepth, getMinSumHessianInLeaf, numTasks, getObjective, modelStr,
getFeatureFraction, getMaxDepth, getMinSumHessianInLeaf, numTasks, modelStr,
getIsUnbalance, getVerbosity, categoricalIndexes, actualNumClasses, getBoostFromAverage,
getBoostingType, getLambdaL1, getLambdaL2, getIsProvideTrainingMetric,
getMetric, getMinGainToSplit, getMaxDeltaStep, getMaxBinByFeature, getMinDataInLeaf, getSlotNames,
getDelegate, getDartParams(), getExecutionParams())
getDelegate, getDartParams(), getExecutionParams(), getObjectiveParams())
}

def getModel(trainParams: TrainParams, lightGBMBooster: LightGBMBooster): LightGBMClassificationModel = {
Expand All @@ -69,7 +72,7 @@ class LightGBMClassifier(override val uid: String)
}

def stringFromTrainedModel(model: LightGBMClassificationModel): String = {
model.getModel.model
model.getModel.modelStr.get
}

override def copy(extra: ParamMap): LightGBMClassifier = defaultCopy(extra)
Expand Down Expand Up @@ -187,7 +190,7 @@ class LightGBMClassificationModel(override val uid: String)

object LightGBMClassificationModel extends ComplexParamsReadable[LightGBMClassificationModel] {
def loadNativeModelFromFile(filename: String): LightGBMClassificationModel = {
val uid = Identifiable.randomUID("LightGBMClassifier")
val uid = Identifiable.randomUID("LightGBMClassificationModel")
val session = SparkSession.builder().getOrCreate()
val textRdd = session.read.text(filename)
val text = textRdd.collect().map { row => row.getString(0) }.mkString("\n")
Expand All @@ -197,7 +200,7 @@ object LightGBMClassificationModel extends ComplexParamsReadable[LightGBMClassif
}

def loadNativeModelFromString(model: String): LightGBMClassificationModel = {
val uid = Identifiable.randomUID("LightGBMClassifier")
val uid = Identifiable.randomUID("LightGBMClassificationModel")
val lightGBMBooster = new LightGBMBooster(model)
val actualNumClasses = lightGBMBooster.numClasses
new LightGBMClassificationModel(uid).setLightGBMBooster(lightGBMBooster).setActualNumClasses(actualNumClasses)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@

package com.microsoft.ml.spark.lightgbm

import com.microsoft.ml.lightgbm.SWIGTYPE_p_void
import com.microsoft.ml.spark.lightgbm.booster.LightGBMBooster
import com.microsoft.ml.spark.lightgbm.params.TrainParams
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types.StructType
import org.slf4j.Logger
Expand Down Expand Up @@ -40,12 +41,12 @@ trait LightGBMDelegate extends Serializable {
}

def beforeTrainIteration(batchIndex: Int, partitionId: Int, curIters: Int, log: Logger,
trainParams: TrainParams, boosterPtr: Option[SWIGTYPE_p_void], hasValid: Boolean): Unit = {
trainParams: TrainParams, booster: LightGBMBooster, hasValid: Boolean): Unit = {
// override this function and write code
}

def afterTrainIteration(batchIndex: Int, partitionId: Int, curIters: Int, log: Logger,
trainParams: TrainParams, boosterPtr: Option[SWIGTYPE_p_void], hasValid: Boolean,
trainParams: TrainParams, booster: LightGBMBooster, hasValid: Boolean,
isFinished: Boolean,
trainEvalResults: Option[Map[String, Double]],
validEvalResults: Option[Map[String, Double]]): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

package com.microsoft.ml.spark.lightgbm

import com.microsoft.ml.spark.lightgbm.params.LightGBMModelParams
import org.apache.spark.internal.Logging
import org.apache.spark.ml.linalg.{Vector, Vectors}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

package com.microsoft.ml.spark.lightgbm

import com.microsoft.ml.spark.lightgbm.booster.LightGBMBooster
import com.microsoft.ml.spark.lightgbm.params.{LightGBMModelParams, LightGBMPredictionParams,
RankerTrainParams, TrainParams}
import com.microsoft.ml.spark.logging.BasicLogging
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Ranker, RankerModel}
import org.apache.spark.ml.param._
Expand Down Expand Up @@ -51,12 +54,13 @@ class LightGBMRanker(override val uid: String)
def getTrainParams(numTasks: Int, categoricalIndexes: Array[Int], dataset: Dataset[_]): TrainParams = {
val modelStr = if (getModelString == null || getModelString.isEmpty) None else get(modelString)
RankerTrainParams(getParallelism, getTopK, getNumIterations, getLearningRate, getNumLeaves,
getObjective, getMaxBin, getBinSampleCount, getBaggingFraction, getPosBaggingFraction, getNegBaggingFraction,
getMaxBin, getBinSampleCount, getBaggingFraction, getPosBaggingFraction, getNegBaggingFraction,
getBaggingFreq, getBaggingSeed, getEarlyStoppingRound, getImprovementTolerance,
getFeatureFraction, getMaxDepth, getMinSumHessianInLeaf, numTasks, modelStr,
getVerbosity, categoricalIndexes, getBoostingType, getLambdaL1, getLambdaL2, getMaxPosition, getLabelGain,
getIsProvideTrainingMetric, getMetric, getEvalAt, getMinGainToSplit, getMaxDeltaStep,
getMaxBinByFeature, getMinDataInLeaf, getSlotNames, getDelegate, getDartParams(), getExecutionParams())
getMaxBinByFeature, getMinDataInLeaf, getSlotNames, getDelegate, getDartParams(), getExecutionParams(),
getObjectiveParams())
}

def getModel(trainParams: TrainParams, lightGBMBooster: LightGBMBooster): LightGBMRankerModel = {
Expand All @@ -70,7 +74,7 @@ class LightGBMRanker(override val uid: String)
}

def stringFromTrainedModel(model: LightGBMRankerModel): String = {
model.getModel.model
model.getModel.modelStr.get
}

override def getOptGroupCol: Option[String] = Some(getGroupCol)
Expand Down Expand Up @@ -157,7 +161,7 @@ class LightGBMRankerModel(override val uid: String)

object LightGBMRankerModel extends ComplexParamsReadable[LightGBMRankerModel] {
def loadNativeModelFromFile(filename: String): LightGBMRankerModel = {
val uid = Identifiable.randomUID("LightGBMRanker")
val uid = Identifiable.randomUID("LightGBMRankerModel")
val session = SparkSession.builder().getOrCreate()
val textRdd = session.read.text(filename)
val text = textRdd.collect().map { row => row.getString(0) }.mkString("\n")
Expand All @@ -166,7 +170,7 @@ object LightGBMRankerModel extends ComplexParamsReadable[LightGBMRankerModel] {
}

def loadNativeModelFromString(model: String): LightGBMRankerModel = {
val uid = Identifiable.randomUID("LightGBMRanker")
val uid = Identifiable.randomUID("LightGBMRankerModel")
val lightGBMBooster = new LightGBMBooster(model)
new LightGBMRankerModel(uid).setLightGBMBooster(lightGBMBooster)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@

package com.microsoft.ml.spark.lightgbm

import com.microsoft.ml.spark.lightgbm.booster.LightGBMBooster
import com.microsoft.ml.spark.lightgbm.params.{LightGBMModelParams, LightGBMPredictionParams,
RegressorTrainParams, TrainParams}
import com.microsoft.ml.spark.logging.BasicLogging
import org.apache.spark.ml.{BaseRegressor, ComplexParamsReadable, ComplexParamsWritable}
import org.apache.spark.ml.param._
Expand Down Expand Up @@ -58,12 +61,12 @@ class LightGBMRegressor(override val uid: String)
def getTrainParams(numTasks: Int, categoricalIndexes: Array[Int], dataset: Dataset[_]): TrainParams = {
val modelStr = if (getModelString == null || getModelString.isEmpty) None else get(modelString)
RegressorTrainParams(getParallelism, getTopK, getNumIterations, getLearningRate, getNumLeaves,
getObjective, getAlpha, getTweedieVariancePower, getMaxBin, getBinSampleCount,
getBaggingFraction, getPosBaggingFraction, getNegBaggingFraction, getBaggingFreq, getBaggingSeed,
getEarlyStoppingRound, getImprovementTolerance, getFeatureFraction, getMaxDepth, getMinSumHessianInLeaf,
numTasks, modelStr, getVerbosity, categoricalIndexes, getBoostFromAverage, getBoostingType, getLambdaL1,
getLambdaL2, getIsProvideTrainingMetric, getMetric, getMinGainToSplit, getMaxDeltaStep,
getMaxBinByFeature, getMinDataInLeaf, getSlotNames, getDelegate, getDartParams(), getExecutionParams())
getAlpha, getTweedieVariancePower, getMaxBin, getBinSampleCount, getBaggingFraction, getPosBaggingFraction,
getNegBaggingFraction, getBaggingFreq, getBaggingSeed, getEarlyStoppingRound, getImprovementTolerance,
getFeatureFraction, getMaxDepth, getMinSumHessianInLeaf, numTasks, modelStr, getVerbosity, categoricalIndexes,
getBoostFromAverage, getBoostingType, getLambdaL1, getLambdaL2, getIsProvideTrainingMetric, getMetric,
getMinGainToSplit, getMaxDeltaStep, getMaxBinByFeature, getMinDataInLeaf, getSlotNames, getDelegate,
getDartParams(), getExecutionParams(), getObjectiveParams())
}

def getModel(trainParams: TrainParams, lightGBMBooster: LightGBMBooster): LightGBMRegressionModel = {
Expand All @@ -77,7 +80,7 @@ class LightGBMRegressor(override val uid: String)
}

def stringFromTrainedModel(model: LightGBMRegressionModel): String = {
model.getModel.model
model.getModel.modelStr.get
}

override def copy(extra: ParamMap): LightGBMRegressor = defaultCopy(extra)
Expand Down Expand Up @@ -134,7 +137,7 @@ class LightGBMRegressionModel(override val uid: String)

object LightGBMRegressionModel extends ComplexParamsReadable[LightGBMRegressionModel] {
def loadNativeModelFromFile(filename: String): LightGBMRegressionModel = {
val uid = Identifiable.randomUID("LightGBMRegressor")
val uid = Identifiable.randomUID("LightGBMRegressionModel")
val session = SparkSession.builder().getOrCreate()
val textRdd = session.read.text(filename)
val text = textRdd.collect().map { row => row.getString(0) }.mkString("\n")
Expand All @@ -143,7 +146,7 @@ object LightGBMRegressionModel extends ComplexParamsReadable[LightGBMRegressionM
}

def loadNativeModelFromString(model: String): LightGBMRegressionModel = {
val uid = Identifiable.randomUID("LightGBMRegressor")
val uid = Identifiable.randomUID("LightGBMRegressionModel")
val lightGBMBooster = new LightGBMBooster(model)
new LightGBMRegressionModel(uid).setLightGBMBooster(lightGBMBooster)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import com.microsoft.ml.lightgbm._
import com.microsoft.ml.spark.core.env.NativeLoader
import com.microsoft.ml.spark.core.utils.ClusterUtil
import com.microsoft.ml.spark.featurize.{Featurize, FeaturizeUtilities}
import com.microsoft.ml.spark.lightgbm.dataset.LightGBMDataset
import com.microsoft.ml.spark.lightgbm.params.TrainParams
import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.attribute._
Expand Down Expand Up @@ -245,7 +247,7 @@ object LightGBMUtils {
LightGBMUtils.validate(lightgbmlib.LGBM_DatasetCreateFromMats(featuresArray.get_chunks_count().toInt,
featuresArray.data_as_void(), data64bitType,
numRowsForChunks, numCols,
isRowMajor, datasetParams, referenceDataset.map(_.dataset).orNull, datasetOutPtr),
isRowMajor, datasetParams, referenceDataset.map(_.datasetPtr).orNull, datasetOutPtr),
"Dataset create")
} finally {
featuresArray.release()
Expand Down Expand Up @@ -275,7 +277,7 @@ object LightGBMUtils {
LightGBMUtils.validate(lightgbmlib.LGBM_DatasetCreateFromCSRSpark(
sparseRows.asInstanceOf[Array[Object]],
sparseRows.length,
numCols, datasetParams, referenceDataset.map(_.dataset).orNull,
numCols, datasetParams, referenceDataset.map(_.datasetPtr).orNull,
datasetOutPtr),
"Dataset create")
val dataset = new LightGBMDataset(lightgbmlib.voidpp_value(datasetOutPtr))
Expand Down
Loading

0 comments on commit 99bdb64

Please sign in to comment.