Skip to content

Commit

Permalink
feat: add number of threads parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
imatiach-msft committed May 21, 2021
1 parent 663d965 commit ca14514
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
}

/** Collects the execution-related settings of this estimator into an [[ExecutionParams]] instance. */
protected def getExecutionParams(): ExecutionParams = {
// NOTE(review): this page is a rendered commit diff; the next line appears to be the pre-change
// two-argument call that the commit replaced with the line after it -- confirm before compiling.
ExecutionParams(getChunkSize, getMatrixType)
// Post-change call: additionally forwards the value of the numThreads param.
ExecutionParams(getChunkSize, getMatrixType, getNumThreads)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,13 @@ trait LightGBMExecutionParams extends Wrappable {

/** Returns the value of the `matrixType` param. */
def getMatrixType: String = $(matrixType)
/** Sets the `matrixType` param; returns this instance so calls can be chained. */
def setMatrixType(value: String): this.type = set(matrixType, value)

/** Integer param holding the number of threads handed to LightGBM; its default is set to 0 below. */
val numThreads = new IntParam(this, "numThreads",
"Number of threads for LightGBM. For the best speed, set this to the number of real CPU cores.")
// NOTE(review): presumably 0 defers to LightGBM's own default thread count -- confirm against the LightGBM docs.
setDefault(numThreads -> 0)

/** Returns the value of the `numThreads` param. */
def getNumThreads: Int = $(numThreads)
/** Sets the `numThreads` param; returns this instance so calls can be chained. */
def setNumThreads(value: Int): this.type = set(numThreads, value)
}

/** Defines common parameters across all LightGBM learners related to learning score evolution.
Expand Down
18 changes: 11 additions & 7 deletions src/main/scala/com/microsoft/ml/spark/lightgbm/LightGBMUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -226,16 +226,22 @@ object LightGBMUtils {
numRowsForChunks
}

/** Builds the space-separated LightGBM parameter string used when constructing a dataset.
  *
  * The result always contains `max_bin`, `is_pre_partition=True`, `bin_construct_sample_cnt`
  * and `num_threads`. A `categorical_feature` entry is appended only when
  * `trainParams.categoricalFeatures` is non-empty.
  *
  * @param trainParams training parameters that supply the values
  * @return the dataset parameter string
  */
def getDatasetParams(trainParams: TrainParams): String = {
  // Every fragment ends with a space so the following key=value pair stays separated.
  val baseParams = s"max_bin=${trainParams.maxBin} is_pre_partition=True " +
    s"bin_construct_sample_cnt=${trainParams.binSampleCount} " +
    s"num_threads=${trainParams.executionParams.numThreads} "
  // Kept as its own value: a line starting with "(" after a completed expression is parsed as a
  // new statement, which previously caused the categorical features to be silently dropped.
  val categoricalParams =
    if (trainParams.categoricalFeatures.isEmpty) ""
    else s"categorical_feature=${trainParams.categoricalFeatures.mkString(",")}"
  baseParams + categoricalParams
}

def generateDenseDataset(numRows: Int, numCols: Int, featuresArray: doubleChunkedArray,
referenceDataset: Option[LightGBMDataset],
featureNamesOpt: Option[Array[String]],
trainParams: TrainParams, chunkSize: Int): LightGBMDataset = {
val isRowMajor = 1
val datasetOutPtr = lightgbmlib.voidpp_handle()
val datasetParams = s"max_bin=${trainParams.maxBin} is_pre_partition=True " +
s"bin_construct_sample_cnt=${trainParams.binSampleCount} " +
(if (trainParams.categoricalFeatures.isEmpty) ""
else s"categorical_feature=${trainParams.categoricalFeatures.mkString(",")}")
val datasetParams = getDatasetParams(trainParams)
val data64bitType = lightgbmlibConstants.C_API_DTYPE_FLOAT64
var data: Option[(SWIGTYPE_p_void, SWIGTYPE_p_double)] = None
val numRowsForChunks = getNumRowsForChunksArray(numRows, chunkSize)
Expand Down Expand Up @@ -268,9 +274,7 @@ object LightGBMUtils {
val numCols = sparseRows(0).size

val datasetOutPtr = lightgbmlib.voidpp_handle()
val datasetParams = s"max_bin=${trainParams.maxBin} is_pre_partition=True " +
(if (trainParams.categoricalFeatures.isEmpty) ""
else s"categorical_feature=${trainParams.categoricalFeatures.mkString(",")}")
val datasetParams = getDatasetParams(trainParams)
// Generate the dataset for features
LightGBMUtils.validate(lightgbmlib.LGBM_DatasetCreateFromCSRSpark(
sparseRows.asInstanceOf[Array[Object]],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ abstract class TrainParams extends Serializable {
s"max_delta_step=$maxDeltaStep min_data_in_leaf=$minDataInLeaf " +
(if (categoricalFeatures.isEmpty) "" else s"categorical_feature=${categoricalFeatures.mkString(",")} ") +
(if (maxBinByFeature.isEmpty) "" else s"max_bin_by_feature=${maxBinByFeature.mkString(",")} ") +
(if (boostingType == "dart") s"${dartModeParams.toString()}" else "")
(if (boostingType == "dart") s"${dartModeParams.toString()} " else "") +
executionParams.toString()
}
}

Expand Down Expand Up @@ -143,4 +144,8 @@ case class DartModeParams(dropRate: Double, maxDrop: Int, skipDrop: Double,
}
}

case class ExecutionParams(chunkSize: Int, matrixType: String) extends Serializable
/** Execution-related settings for LightGBM training.
  *
  * @param chunkSize  chunk size setting (populated from `getChunkSize` by the estimator)
  * @param matrixType matrix type setting (populated from `getMatrixType` by the estimator)
  * @param numThreads number of threads; an `Int` to match the `numThreads` IntParam and
  *                   `getNumThreads`, which the estimator passes in when building this object
  */
case class ExecutionParams(chunkSize: Int, matrixType: String, numThreads: Int) extends Serializable {
  /** Renders only the `num_threads` entry, with a trailing space so that it can be
    * concatenated into a space-separated LightGBM parameter string.
    */
  override def toString(): String = {
    s"num_threads=$numThreads "
  }
}

0 comments on commit ca14514

Please sign in to comment.