Skip to content

Commit

Permalink
fix: lightgbm getting stuck when empty partition is chosen as the mai…
Browse files Browse the repository at this point in the history
…n worker in singleDatasetMode (#1458)
  • Loading branch information
imatiach-msft authored Mar 31, 2022
1 parent 3eb4661 commit 63c1235
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ package com.microsoft.azure.synapse.ml.lightgbm
import com.microsoft.azure.synapse.ml.core.utils.{ClusterUtil, ParamsStringBuilder}
import com.microsoft.azure.synapse.ml.io.http.SharedSingleton
import com.microsoft.azure.synapse.ml.lightgbm.ConnectionState.Finished
import com.microsoft.azure.synapse.ml.lightgbm.LightGBMUtils.{closeConnections, getExecutorId,
getPartitionId, handleConnection, sendDataToExecutors}
import com.microsoft.azure.synapse.ml.lightgbm.LightGBMUtils.{closeConnections, getExecutorId, getPartitionId,
handleConnection, sendDataToExecutors}
import com.microsoft.azure.synapse.ml.lightgbm.TaskTrainingMethods.{isWorkerEnabled, prepareDatasets}
import com.microsoft.azure.synapse.ml.lightgbm.TrainUtils._
import com.microsoft.azure.synapse.ml.lightgbm.booster.LightGBMBooster
Expand Down Expand Up @@ -460,6 +460,8 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
}
val useSingleDatasetMode = trainParams.executionParams.useSingleDatasetMode
val emptyPartition = !inputRows.hasNext
// Note: the first valid worker with non-empty partitions sets the main executor worker, other workers read it
if (useSingleDatasetMode && !emptyPartition) sharedState.linkMainExecutorWorker()
val isEnabledWorker = if (!emptyPartition) isWorkerEnabled(trainParams, log, sharedState) else false
// Initialize the native library
LightGBMUtils.initializeNativeLibrary()
Expand All @@ -469,12 +471,7 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
log.warn("LightGBM task encountered empty partition, for best performance ensure no partitions empty")
List[LightGBMBooster]().toIterator
} else {
if (isEnabledWorker) {
log.info(s"LightGBM task listening on: $localListenPort")
if (useSingleDatasetMode) sharedState.helperStartSignal.countDown()
} else {
sharedState.helperStartSignal.await()
}
updateHelperStartSignal(useSingleDatasetMode, sharedState, isEnabledWorker, localListenPort)
val (aggregatedColumns, aggregatedValidationColumns) = prepareDatasets(
inputRows, validationData, sharedState)
// Return booster only from main worker to reduce network communication overhead
Expand All @@ -497,6 +494,24 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
}
}

/** Prints the listening port and, in single dataset mode, forces helper tasks to wait for the main worker
* before continuing to prepare and merge the dataset.
*
* @param useSingleDatasetMode If true, indicates whether SingleDatasetMode is enabled.
* @param sharedState The shared state across spark tasks.
* @param isEnabledWorker Whether the current work is enabled to initialize the network ring of communication.
* @param localListenPort The local port for creating the network ring of communication.
*/
private def updateHelperStartSignal(useSingleDatasetMode: Boolean, sharedState: SharedState,
isEnabledWorker: Boolean, localListenPort: Int) = {
if (isEnabledWorker) {
log.info(s"LightGBM task listening on: $localListenPort")
if (useSingleDatasetMode) sharedState.helperStartSignal.countDown()
} else {
sharedState.helperStartSignal.await()
}
}

/**
* Opens a socket communications channel on the driver, starts a thread that
* waits for the host:port from the executors, and then sends back the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ class SharedDatasetState(columnParams: ColumnParams,
class SharedState(columnParams: ColumnParams,
schema: StructType,
trainParams: BaseTrainParams) {
val mainExecutorWorker: Long = LightGBMUtils.getTaskId
val useSingleDataset: Boolean = trainParams.executionParams.useSingleDatasetMode
val chunkSize: Int = trainParams.executionParams.chunkSize
val matrixType: String = trainParams.executionParams.matrixType
Expand All @@ -92,6 +91,7 @@ class SharedState(columnParams: ColumnParams,
val validationDatasetState: SharedDatasetState = new SharedDatasetState(columnParams, schema, trainParams, this)

@volatile var isSparse: Option[Boolean] = None
@volatile var mainExecutorWorker: Option[Long] = None

def linkIsSparse(isSparse: Boolean): Unit = {
if (this.isSparse.isEmpty) {
Expand All @@ -103,6 +103,16 @@ class SharedState(columnParams: ColumnParams,
}
}

def linkMainExecutorWorker(): Unit = {
if (this.mainExecutorWorker.isEmpty) {
this.synchronized {
if (this.mainExecutorWorker.isEmpty) {
this.mainExecutorWorker = Some(LightGBMUtils.getTaskId)
}
}
}
}

def incrementArrayProcessedSignal(log: Logger): Int = {
datasetState.incrementArrayProcessedSignal(log)
validationDatasetState.incrementArrayProcessedSignal(log)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,13 @@ object TaskTrainingMethods {
* Otherwise, returns true for all tasks.
* @param trainParams The training parameters.
* @param log The logger.
* @param sharedState The shared state.
* @return Whether the current task is enabled.
*/
def isWorkerEnabled(trainParams: BaseTrainParams, log: Logger, sharedState: SharedState): Boolean = {
if (trainParams.executionParams.useSingleDatasetMode) {
// Find all workers in current JVM
val mainExecutorWorker = sharedState.mainExecutorWorker
val myTaskId = LightGBMUtils.getTaskId
val isMainWorker = mainExecutorWorker == myTaskId
log.info(s"Using singleDatasetMode. " +
s"Is main worker: ${isMainWorker} for task id: ${myTaskId} and main task id: ${mainExecutorWorker}")
val isMainWorker = isCurrentTaskMainWorker(log, sharedState)
sharedState.incrementArrayProcessedSignal(log)
if (!isMainWorker) {
sharedState.incrementDoneSignal(log)
Expand All @@ -34,6 +31,21 @@ object TaskTrainingMethods {
}
}

/** Determines if the current task is the main worker in the current JVM.
*
* @param log The logger.
* @param sharedState The shared state.
* @return True if the current task in the main worker, false otherwise.
*/
def isCurrentTaskMainWorker(log: Logger, sharedState: SharedState): Boolean = {
val mainExecutorWorker = sharedState.mainExecutorWorker.get
val myTaskId = LightGBMUtils.getTaskId
val isMainWorker = mainExecutorWorker == myTaskId
log.info(s"Using singleDatasetMode. " +
s"Is main worker: ${isMainWorker} for task id: ${myTaskId} and main task id: ${mainExecutorWorker}")
isMainWorker
}

def prepareDatasets(inputRows: Iterator[Row],
validationData: Option[Broadcast[Array[Row]]],
sharedState: SharedState): (BaseAggregatedColumns, Option[BaseAggregatedColumns]) = {
Expand Down

0 comments on commit 63c1235

Please sign in to comment.