Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: add seed parameters to lightgbm #1387

Merged
merged 3 commits into from
Mar 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@ package com.microsoft.azure.synapse.ml.lightgbm
import com.microsoft.azure.synapse.ml.core.utils.ClusterUtil
import com.microsoft.azure.synapse.ml.io.http.SharedSingleton
import com.microsoft.azure.synapse.ml.lightgbm.ConnectionState.Finished
import com.microsoft.azure.synapse.ml.lightgbm.LightGBMUtils.{closeConnections, handleConnection, sendDataToExecutors}
import com.microsoft.azure.synapse.ml.lightgbm.LightGBMUtils.{closeConnections, getExecutorId,
getPartitionId, handleConnection, sendDataToExecutors}
import com.microsoft.azure.synapse.ml.lightgbm.TaskTrainingMethods.{isWorkerEnabled, prepareDatasets}
import com.microsoft.azure.synapse.ml.lightgbm.TrainUtils._
import com.microsoft.azure.synapse.ml.lightgbm.booster.LightGBMBooster
Expand All @@ -25,6 +26,7 @@ import org.apache.spark.sql.types._
import java.net.{ServerSocket, Socket}
import java.util.concurrent.Executors
import scala.collection.immutable.HashSet
import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import scala.concurrent.duration.{Duration, SECONDS}
import scala.concurrent.{Await, ExecutionContext, ExecutionContextExecutor, Future}
Expand Down Expand Up @@ -255,6 +257,11 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
ExecutionParams(getChunkSize, getMatrixType, execNumThreads, getUseSingleDatasetMode)
}

/**
  * Constructs the ColumnParams.
  *
  * Optional columns (weight, initial score) are passed as Option values so that
  * downstream code can distinguish "not set" from an explicit column name.
  *
  * @return ColumnParams object containing the parameters related to LightGBM columns
  *         (label, features, optional weight/init-score, and the optional group column).
  */
protected def getColumnParams: ColumnParams = {
  ColumnParams(getLabelCol, getFeaturesCol, get(weightCol), get(initScoreCol), getOptGroupCol)
}
Expand All @@ -268,13 +275,25 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
ObjectiveParams(getObjective, if (isDefined(fobj)) Some(getFObj) else None)
}

/**
  * Constructs the SeedParams.
  *
  * Each optional seed is passed as an Option (None when unset) so LightGBM's own
  * defaults apply; boosting type and objective are forwarded because some seeds
  * (e.g. drop seed, objective seed) only apply to particular configurations.
  *
  * @return SeedParams object containing the parameters related to LightGBM seeds and determinism.
  */
protected def getSeedParams: SeedParams = {
  SeedParams(get(seed), get(deterministic), get(baggingSeed), get(featureFractionSeed),
    get(extraSeed), get(dropSeed), get(dataRandomSeed), get(objectiveSeed), getBoostingType, getObjective)
}

/**
  * Builds the space-separated parameter string passed to the native LightGBM
  * Dataset construction.
  *
  * NOTE(review): the rendered diff duplicated the removed `categorical_feature`
  * line next to its replacement; this body keeps only the post-change version.
  *
  * @param categoricalIndexes Indexes of categorical feature columns (empty for none).
  * @param numThreads Number of native threads to use.
  * @return The LightGBM dataset parameters string, e.g. "max_bin=... num_threads=...".
  */
def getDatasetParams(categoricalIndexes: Array[Int], numThreads: Int): String = {
  // An explicit dataRandomSeed overrides the generic seed for dataset creation.
  val seedParam = get(dataRandomSeed).orElse(get(seed))
  val datasetParams = s"max_bin=$getMaxBin is_pre_partition=True " +
    s"bin_construct_sample_cnt=$getBinSampleCount " +
    s"min_data_in_leaf=$getMinDataInLeaf " +
    s"num_threads=$numThreads " +
    (if (categoricalIndexes.isEmpty) ""
     else s"categorical_feature=${categoricalIndexes.mkString(",")} ") +
    seedParam.map(dataRandomSeedOpt => s"data_random_seed=$dataRandomSeedOpt ").getOrElse("")
  datasetParams
}

Expand Down Expand Up @@ -349,6 +368,9 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
schema: StructType,
sharedState: SharedState)
(inputRows: Iterator[Row]): Iterator[LightGBMBooster] = {
if (trainParams.verbosity > 1) {
log.info(s"LightGBM partition $getPartitionId running on executor $getExecutorId")
}
val useSingleDatasetMode = trainParams.executionParams.useSingleDatasetMode
val emptyPartition = !inputRows.hasNext
val isEnabledWorker = if (!emptyPartition) isWorkerEnabled(trainParams, log, sharedState) else false
Expand Down Expand Up @@ -409,22 +431,27 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
val f = Future {
var emptyTaskCounter = 0
val hostAndPorts = ListBuffer[(Socket, String)]()
val hostToMinPartition = mutable.Map[String, String]()
if (getUseBarrierExecutionMode) {
log.info(s"driver using barrier execution mode")

def connectToWorkers: Boolean = handleConnection(driverServerSocket, log,
hostAndPorts) == Finished || connectToWorkers
hostAndPorts, hostToMinPartition) == Finished || connectToWorkers

connectToWorkers
} else {
log.info(s"driver expecting $numTasks connections...")
while (hostAndPorts.size + emptyTaskCounter < numTasks) {
val connectionResult = handleConnection(driverServerSocket, log, hostAndPorts)
val connectionResult = handleConnection(driverServerSocket, log, hostAndPorts, hostToMinPartition)
if (connectionResult == ConnectionState.EmptyTask) emptyTaskCounter += 1
}
}
// Concatenate with commas, eg: host1:port1,host2:port2, ... etc
val allConnections = hostAndPorts.map(_._2).mkString(",")
val hostPortsList = hostAndPorts.map(_._2).sortBy(hostPort => {
val host = hostPort.split(":")(0)
Integer.parseInt(hostToMinPartition(host))
})
val allConnections = hostPortsList.mkString(",")
log.info(s"driver writing back to all connections: $allConnections")
// Send data back to all tasks and helper tasks on executors
sendDataToExecutors(hostAndPorts, allConnections)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class LightGBMClassifier(override val uid: String)
getIsUnbalance, getVerbosity, categoricalIndexes, actualNumClasses, getBoostFromAverage,
getBoostingType, get(lambdaL1), get(lambdaL2), get(isProvideTrainingMetric),
get(metric), get(minGainToSplit), get(maxDeltaStep), getMaxBinByFeature, get(minDataInLeaf), getSlotNames,
getDelegate, getDartParams, getExecutionParams(numTasksPerExec), getObjectiveParams)
getDelegate, getDartParams, getExecutionParams(numTasksPerExec), getObjectiveParams, getSeedParams)
}

def getModel(trainParams: TrainParams, lightGBMBooster: LightGBMBooster): LightGBMClassificationModel = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class LightGBMRanker(override val uid: String)
getVerbosity, categoricalIndexes, getBoostingType, get(lambdaL1), get(lambdaL2), getMaxPosition, getLabelGain,
get(isProvideTrainingMetric), get(metric), getEvalAt, get(minGainToSplit), get(maxDeltaStep),
getMaxBinByFeature, get(minDataInLeaf), getSlotNames, getDelegate, getDartParams,
getExecutionParams(numTasksPerExec), getObjectiveParams)
getExecutionParams(numTasksPerExec), getObjectiveParams, getSeedParams)
}

def getModel(trainParams: TrainParams, lightGBMBooster: LightGBMBooster): LightGBMRankerModel = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class LightGBMRegressor(override val uid: String)
getBoostFromAverage, getBoostingType, get(lambdaL1), get(lambdaL2), get(isProvideTrainingMetric),
get(metric), get(minGainToSplit), get(maxDeltaStep),
getMaxBinByFeature, get(minDataInLeaf), getSlotNames, getDelegate,
getDartParams, getExecutionParams(numTasksPerExec), getObjectiveParams)
getDartParams, getExecutionParams(numTasksPerExec), getObjectiveParams, getSeedParams)
}

def getModel(trainParams: TrainParams, lightGBMBooster: LightGBMBooster): LightGBMRegressionModel = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import org.slf4j.Logger

import java.io._
import java.net.{ServerSocket, Socket}
import scala.collection.mutable
import scala.collection.mutable.ListBuffer

/** Helper utilities for LightGBM learners */
Expand Down Expand Up @@ -82,26 +83,44 @@ object LightGBMUtils {
* @param driverServerSocket The driver socket.
* @param log The log4j logger.
* @param hostAndPorts A list of host and ports of connected tasks.
* @param hostToMinPartition A list of host to the minimum partition id, used for determinism.
* @return The connection status, can be finished for barrier mode, empty task or connected.
*/
/**
  * Handles a single incoming task connection on the driver socket.
  *
  * Tasks send either the finished status (barrier mode), an ignore status of the
  * form "ignoreStatus:host:partitionId" for empty partitions, or a connection
  * message of the form "host:port:partitionId". In the latter two cases the
  * host-to-minimum-partition map is updated so the final node list can be
  * ordered deterministically.
  *
  * NOTE(review): the rendered diff duplicated the removed signature and branch
  * lines next to their replacements; this body keeps only the post-change version.
  *
  * @param driverServerSocket The driver socket.
  * @param log The log4j logger.
  * @param hostAndPorts A list of host and ports of connected tasks.
  * @param hostToMinPartition A map of host to the minimum partition id, used for determinism.
  * @return The connection status: Finished (barrier mode), EmptyTask or Connected.
  */
def handleConnection(driverServerSocket: ServerSocket, log: Logger,
                     hostAndPorts: ListBuffer[(Socket, String)],
                     hostToMinPartition: mutable.Map[String, String]): ConnectionState = {
  log.info("driver accepting a new connection...")
  val driverSocket = driverServerSocket.accept()
  val reader = new BufferedReader(new InputStreamReader(driverSocket.getInputStream))
  val comm = reader.readLine()
  if (comm == LightGBMConstants.FinishedStatus) {
    log.info("driver received all tasks from barrier stage")
    Finished
  } else if (comm.startsWith(LightGBMConstants.IgnoreStatus)) {
    // Message layout: <ignoreStatus>:<host>:<partitionId>
    log.info("driver received ignore status from task")
    val hostPartition = comm.split(":")
    val host = hostPartition(1)
    val partitionId = hostPartition(2)
    updateHostToMinPartition(hostToMinPartition, host, partitionId)
    EmptyTask
  } else {
    // Message layout: <host>:<port>:<partitionId>
    val hostPortPartition = comm.split(":")
    val host = hostPortPartition(0)
    val port = hostPortPartition(1)
    val partitionId = hostPortPartition(2)
    updateHostToMinPartition(hostToMinPartition, host, partitionId)
    // Only host:port is kept for the network ring; partition id served its purpose above.
    addSocketAndComm(hostAndPorts, log, s"$host:$port", driverSocket)
    Connected
  }
}

/**
  * Records the minimum partition id seen for a host, used to order hosts
  * deterministically when building the node list.
  *
  * Fix: the original compared partition ids as Strings, and lexicographic
  * ordering ranks "10" below "9", so the stored "minimum" was wrong once a host
  * saw partition ids of differing digit counts. The driver parses these values
  * with Integer.parseInt when sorting, so a numeric comparison is the intent.
  *
  * @param hostToMinPartition Map of host to its minimum partition id (as a string).
  * @param host The task's host address.
  * @param partitionId The task's partition id, as a decimal string.
  */
def updateHostToMinPartition(hostToMinPartition: mutable.Map[String, String],
                             host: String, partitionId: String): Unit = {
  val newId = partitionId.toInt
  if (!hostToMinPartition.contains(host) || hostToMinPartition(host).toInt > newId) {
    hostToMinPartition(host) = partitionId
  }
}


/** Returns an integer ID for the current worker.
* @return In cluster, returns the executor id. In local case, returns the partition id.
Expand All @@ -116,6 +135,25 @@ object LightGBMUtils {
idAsInt
}

/** Returns the Spark partition ID of the currently running task.
  *
  * Used to make operations deterministic when re-run on the same dataset.
  *
  * @return The partition id from the active TaskContext.
  */
def getPartitionId: Int = TaskContext.get.partitionId

/** Returns the Spark executor ID this task is running on.
  *
  * @return The executor id string from the current SparkEnv.
  */
def getExecutorId: String = {
  val sparkEnv = SparkEnv.get
  sparkEnv.executorId
}

/** Returns true if spark is run in local mode.
* @return True if spark is run in local mode.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,17 +244,18 @@ private object TrainUtils extends Serializable {
io =>
val driverInput = io(0).asInstanceOf[BufferedReader]
val driverOutput = io(1).asInstanceOf[BufferedWriter]
val partitionId = LightGBMUtils.getPartitionId
val taskHost = driverSocket.getLocalAddress.getHostAddress
val taskStatus =
if (ignoreTask) {
log.info("send empty status to driver")
LightGBMConstants.IgnoreStatus
log.info(s"send empty status to driver with partitionId: $partitionId")
s"${LightGBMConstants.IgnoreStatus}:$taskHost:$partitionId"
} else {
val taskHost = driverSocket.getLocalAddress.getHostAddress
val taskInfo = s"$taskHost:$localListenPort"
val taskInfo = s"$taskHost:$localListenPort:$partitionId"
log.info(s"send current task info to driver: $taskInfo ")
taskInfo
}
// Send the current host:port to the driver
// Send the current host:port:partitionId to the driver
driverOutput.write(s"$taskStatus\n")
driverOutput.flush()
// If barrier execution mode enabled, create a barrier across tasks
Expand All @@ -265,7 +266,7 @@ private object TrainUtils extends Serializable {
setFinishedStatus(networkParams, localListenPort, log)
}
}
if (taskStatus != LightGBMConstants.IgnoreStatus) {
if (!taskStatus.startsWith(LightGBMConstants.IgnoreStatus)) {
// Wait to get the list of nodes from the driver
val nodes = driverInput.readLine()
log.info(s"LightGBM worker got nodes for network init: $nodes")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import org.apache.spark.sql.types.StructType

import java.util.concurrent.atomic.AtomicLong
import scala.collection.mutable.ListBuffer
import scala.collection.concurrent.TrieMap

private[lightgbm] object ChunkedArrayUtils {
def copyChunkedArray[T: Numeric](chunkedArray: ChunkedArray[T],
Expand Down Expand Up @@ -193,6 +194,8 @@ private[lightgbm] abstract class BaseAggregatedColumns(val chunkSize: Int) exten
*/
protected val rowCount = new AtomicLong(0L)
protected val initScoreCount = new AtomicLong(0L)
protected val pIdToRowCountOffset = new TrieMap[Long, Long]()
protected val pIdToInitScoreCountOffset = new TrieMap[Long, Long]()

protected var numCols = 0

Expand All @@ -216,7 +219,10 @@ private[lightgbm] abstract class BaseAggregatedColumns(val chunkSize: Int) exten

/** Accumulates global row/init-score totals and records this partition's own
  * counts, which are later converted into per-partition copy offsets.
  *
  * @param chunkedCols The chunked columns whose counts are being added.
  */
def incrementCount(chunkedCols: BaseChunkedColumns): Unit = {
  val partitionId = LightGBMUtils.getPartitionId
  rowCount.addAndGet(chunkedCols.rowCount)
  initScoreCount.addAndGet(chunkedCols.numInitScores)
  // Store raw per-partition counts keyed by partition id; these entries are
  // rewritten into cumulative offsets once all partitions have reported.
  pIdToRowCountOffset.update(partitionId, chunkedCols.rowCount)
  pIdToInitScoreCountOffset.update(partitionId, chunkedCols.numInitScores)
}

def addRows(chunkedCols: BaseChunkedColumns): Unit = {
Expand All @@ -232,6 +238,18 @@ private[lightgbm] abstract class BaseAggregatedColumns(val chunkSize: Int) exten
initScores = chunkedCols.initScores.map(_ => new DoubleSwigArray(isc))
initializeFeatures(chunkedCols, rc)
groups = new Array[Any](rc.toInt)
updateConcurrentMapOffsets(pIdToRowCountOffset)
updateConcurrentMapOffsets(pIdToInitScoreCountOffset)
}

/**
  * Converts a map of partition-id -> count into partition-id -> start offset,
  * i.e. a prefix sum in ascending partition-id order: each partition's offset is
  * initialValue plus the counts of all lower-id partitions.
  *
  * Fix: the original used foldRight, which visits the sorted keys in descending
  * order and therefore assigned the offsets in reverse (the highest partition
  * got initialValue). foldLeft walks ascending ids, matching the sorted,
  * deterministic ordering used when laying out the aggregated arrays.
  *
  * @param concurrentIdToOffset Map holding per-partition counts; updated in place to offsets.
  * @param initialValue Starting offset for the first partition (e.g. 1 for indptr's leading 0).
  */
protected def updateConcurrentMapOffsets(concurrentIdToOffset: TrieMap[Long, Long],
                                         initialValue: Long = 0L): Unit = {
  val sortedKeys = concurrentIdToOffset.keys.toSeq.sorted
  sortedKeys.foldLeft(initialValue: Long)((offset, key) => {
    val partitionRowCount = concurrentIdToOffset(key)
    concurrentIdToOffset.update(key, offset)
    partitionRowCount + offset
  })
}

}
Expand All @@ -254,12 +272,6 @@ private[lightgbm] trait DisjointAggregatedColumns extends BaseAggregatedColumns
}

private[lightgbm] trait SyncAggregatedColumns extends BaseAggregatedColumns {
/**
* Variables for current thread to use in order to update common arrays in parallel
*/
protected val threadRowStartIndex = new AtomicLong(0L)
protected val threadInitScoreStartIndex = new AtomicLong(0L)

/** Adds the rows to the internal data structure.
*/
override def addRows(chunkedCols: BaseChunkedColumns): Unit = {
Expand Down Expand Up @@ -289,10 +301,9 @@ private[lightgbm] trait SyncAggregatedColumns extends BaseAggregatedColumns {
var threadInitScoreStartIndex = 0L
val featureIndexes =
this.synchronized {
val labelsSize = chunkedCols.labels.getAddCount
threadRowStartIndex = this.threadRowStartIndex.getAndAdd(labelsSize.toInt)
val initScoreSize = chunkedCols.initScores.map(_.getAddCount)
initScoreSize.foreach(size => threadInitScoreStartIndex = this.threadInitScoreStartIndex.getAndAdd(size))
val partitionId = LightGBMUtils.getPartitionId
threadRowStartIndex = pIdToRowCountOffset.get(partitionId).get
threadInitScoreStartIndex = chunkedCols.initScores.map(_ => pIdToInitScoreCountOffset(partitionId)).getOrElse(0)
updateThreadLocalIndices(chunkedCols, threadRowStartIndex)
}
ChunkedArrayUtils.copyChunkedArray(chunkedCols.labels, labels, threadRowStartIndex, chunkSize)
Expand Down Expand Up @@ -393,6 +404,8 @@ private[lightgbm] abstract class BaseSparseAggregatedColumns(chunkSize: Int)
*/
protected var indexesCount = new AtomicLong(0L)
protected var indptrCount = new AtomicLong(0L)
protected val pIdToIndexesCountOffset = new TrieMap[Long, Long]()
protected val pIdToIndptrCountOffset = new TrieMap[Long, Long]()

def getNumColsFromChunkedArray(chunkedCols: BaseChunkedColumns): Int = {
chunkedCols.asInstanceOf[SparseChunkedColumns].numCols
Expand All @@ -402,7 +415,9 @@ private[lightgbm] abstract class BaseSparseAggregatedColumns(chunkSize: Int)
super.incrementCount(chunkedCols)
val sparseChunkedCols = chunkedCols.asInstanceOf[SparseChunkedColumns]
indexesCount.addAndGet(sparseChunkedCols.getNumIndexes)
pIdToIndexesCountOffset.update(LightGBMUtils.getPartitionId, sparseChunkedCols.getNumIndexes)
indptrCount.addAndGet(sparseChunkedCols.getNumIndexPointers)
pIdToIndptrCountOffset.update(LightGBMUtils.getPartitionId, sparseChunkedCols.getNumIndexPointers)
}

protected def initializeFeatures(chunkedCols: BaseChunkedColumns, rowCount: Long): Unit = {
Expand All @@ -412,6 +427,8 @@ private[lightgbm] abstract class BaseSparseAggregatedColumns(chunkSize: Int)
values = new DoubleSwigArray(indexesCount)
indexPointers = new IntSwigArray(indptrCount)
indexPointers.setItem(0, 0)
updateConcurrentMapOffsets(pIdToIndexesCountOffset)
updateConcurrentMapOffsets(pIdToIndptrCountOffset, 1L)
}

def getIndexes: IntSwigArray = indexes
Expand Down Expand Up @@ -489,25 +506,16 @@ private[lightgbm] final class SparseAggregatedColumns(chunkSize: Int)
*/
private[lightgbm] final class SparseSyncAggregatedColumns(chunkSize: Int)
extends BaseSparseAggregatedColumns(chunkSize) with SyncAggregatedColumns {
/**
* Variables for current thread to use in order to update common arrays in parallel
*/
protected val threadIndexesStartIndex = new AtomicLong(0L)
protected val threadIndptrStartIndex = new AtomicLong(1L)

/** Initializes the aggregated row storage, reserving one extra index-pointer
  * slot for the leading 0 entry (set via indexPointers.setItem(0, 0) during
  * feature initialization) needed in the parallel/synced case.
  */
override protected def initializeRows(chunkedCols: BaseChunkedColumns): Unit = {
  // Add extra 0 for start of indptr in parallel case
  this.indptrCount.addAndGet(1L)
  super.initializeRows(chunkedCols)
}

/**
  * Looks up this partition's precomputed start offsets into the shared sparse
  * arrays (indexes and index pointers), making copy destinations deterministic
  * regardless of task-thread arrival order.
  *
  * NOTE(review): the rendered diff duplicated the removed AtomicLong-based lines
  * next to their replacements; this body keeps only the post-change, offset-map
  * version, drops the unused local, and uses map.apply instead of Option.get.
  *
  * @param chunkedCols The chunked columns for this partition (unused here; kept for the interface).
  * @param threadRowStartIndex This partition's row start offset (unused here; kept for the interface).
  * @return List of (indexes start offset, index-pointer start offset) for this partition.
  */
protected def updateThreadLocalIndices(chunkedCols: BaseChunkedColumns, threadRowStartIndex: Long): List[Long] = {
  val partitionId = LightGBMUtils.getPartitionId
  val threadIndexesStartIndex = pIdToIndexesCountOffset(partitionId)
  val threadIndPtrStartIndex = pIdToIndptrCountOffset(partitionId)
  List(threadIndexesStartIndex, threadIndPtrStartIndex)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
package com.microsoft.azure.synapse.ml.lightgbm.dataset

import com.microsoft.azure.synapse.ml.lightgbm.ColumnParams
import com.microsoft.azure.synapse.ml.lightgbm.swig.DoubleChunkedArray
import com.microsoft.ml.lightgbm.{doubleChunkedArray, floatChunkedArray}
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.ml.linalg.{DenseVector, SparseVector}
Expand Down
Loading