diff --git a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMBase.scala b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMBase.scala index 1ea65bf157..3d29f9c96b 100644 --- a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMBase.scala +++ b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMBase.scala @@ -6,8 +6,8 @@ package com.microsoft.azure.synapse.ml.lightgbm import com.microsoft.azure.synapse.ml.core.utils.ClusterUtil import com.microsoft.azure.synapse.ml.io.http.SharedSingleton import com.microsoft.azure.synapse.ml.lightgbm.ConnectionState.Finished -import com.microsoft.azure.synapse.ml.lightgbm.LightGBMUtils.{closeConnections, getPartitionId, - getExecutorId, handleConnection, sendDataToExecutors} +import com.microsoft.azure.synapse.ml.lightgbm.LightGBMUtils.{closeConnections, getExecutorId, + getPartitionId, handleConnection, sendDataToExecutors} import com.microsoft.azure.synapse.ml.lightgbm.TaskTrainingMethods.{isWorkerEnabled, prepareDatasets} import com.microsoft.azure.synapse.ml.lightgbm.TrainUtils._ import com.microsoft.azure.synapse.ml.lightgbm.booster.LightGBMBooster @@ -26,6 +26,7 @@ import org.apache.spark.sql.types._ import java.net.{ServerSocket, Socket} import java.util.concurrent.Executors import scala.collection.immutable.HashSet +import scala.collection.mutable import scala.collection.mutable.ListBuffer import scala.concurrent.duration.{Duration, SECONDS} import scala.concurrent.{Await, ExecutionContext, ExecutionContextExecutor, Future} @@ -428,22 +429,27 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine val f = Future { var emptyTaskCounter = 0 val hostAndPorts = ListBuffer[(Socket, String)]() + val hostToMinPartition = mutable.Map[String, String]() if (getUseBarrierExecutionMode) { log.info(s"driver using barrier execution mode") def connectToWorkers: Boolean = handleConnection(driverServerSocket, log, - hostAndPorts) == Finished || connectToWorkers + hostAndPorts, hostToMinPartition) == Finished || connectToWorkers connectToWorkers } else { log.info(s"driver expecting $numTasks connections...") while (hostAndPorts.size + emptyTaskCounter < numTasks) { - val connectionResult = handleConnection(driverServerSocket, log, hostAndPorts) + val connectionResult = handleConnection(driverServerSocket, log, hostAndPorts, hostToMinPartition) if (connectionResult == ConnectionState.EmptyTask) emptyTaskCounter += 1 } } // Concatenate with commas, eg: host1:port1,host2:port2, ... etc - val allConnections = hostAndPorts.map(_._2).sorted.mkString(",") + val hostPortsList = hostAndPorts.map(_._2).sortBy(hostPort => { + val host = hostPort.split(":")(0) + Integer.parseInt(hostToMinPartition(host)) + }) + val allConnections = hostPortsList.mkString(",") log.info(s"driver writing back to all connections: $allConnections") // Send data back to all tasks and helper tasks on executors sendDataToExecutors(hostAndPorts, allConnections) diff --git a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMUtils.scala b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMUtils.scala index bdcfc3b0eb..801444316c 100644 --- a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMUtils.scala +++ b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/LightGBMUtils.scala @@ -14,6 +14,7 @@ import org.slf4j.Logger import java.io._ import java.net.{ServerSocket, Socket} +import scala.collection.mutable import scala.collection.mutable.ListBuffer /** Helper utilities for LightGBM learners */ @@ -82,10 +83,12 @@ object LightGBMUtils { * @param driverServerSocket The driver socket. * @param log The log4j logger. * @param hostAndPorts A list of host and ports of connected tasks. + * @param hostToMinPartition A list of host to the minimum partition id, used for determinism. * @return The connection status, can be finished for barrier mode, empty task or connected. */ def handleConnection(driverServerSocket: ServerSocket, log: Logger, - hostAndPorts: ListBuffer[(Socket, String)]): ConnectionState = { + hostAndPorts: ListBuffer[(Socket, String)], + hostToMinPartition: mutable.Map[String, String]): ConnectionState = { log.info("driver accepting a new connection...") val driverSocket = driverServerSocket.accept() val reader = new BufferedReader(new InputStreamReader(driverSocket.getInputStream)) @@ -93,15 +96,31 @@ object LightGBMUtils { if (comm == LightGBMConstants.FinishedStatus) { log.info("driver received all tasks from barrier stage") Finished - } else if (comm == LightGBMConstants.IgnoreStatus) { + } else if (comm.startsWith(LightGBMConstants.IgnoreStatus)) { log.info("driver received ignore status from task") + val hostPartition = comm.split(":") + val host = hostPartition(1) + val partitionId = hostPartition(2) + updateHostToMinPartition(hostToMinPartition, host, partitionId) EmptyTask } else { - addSocketAndComm(hostAndPorts, log, comm, driverSocket) + val hostPortPartition = comm.split(":") + val host = hostPortPartition(0) + val port = hostPortPartition(1) + val partitionId = hostPortPartition(2) + updateHostToMinPartition(hostToMinPartition, host, partitionId) + addSocketAndComm(hostAndPorts, log, s"$host:$port", driverSocket) Connected } } + def updateHostToMinPartition(hostToMinPartition: mutable.Map[String, String], + host: String, partitionId: String): Unit = { + if (!hostToMinPartition.contains(host) || hostToMinPartition(host) > partitionId) { + hostToMinPartition(host) = partitionId + } + } + /** Returns an integer ID for the current worker. * @return In cluster, returns the executor id. In local case, returns the partition id. diff --git a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/TrainUtils.scala b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/TrainUtils.scala index ce10394c3e..2c62840f80 100644 --- a/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/TrainUtils.scala +++ b/lightgbm/src/main/scala/com/microsoft/azure/synapse/ml/lightgbm/TrainUtils.scala @@ -244,17 +244,18 @@ private object TrainUtils extends Serializable { io => val driverInput = io(0).asInstanceOf[BufferedReader] val driverOutput = io(1).asInstanceOf[BufferedWriter] + val partitionId = LightGBMUtils.getPartitionId + val taskHost = driverSocket.getLocalAddress.getHostAddress val taskStatus = if (ignoreTask) { - log.info("send empty status to driver") - LightGBMConstants.IgnoreStatus + log.info(s"send empty status to driver with partitionId: $partitionId") + s"${LightGBMConstants.IgnoreStatus}:$taskHost:$partitionId" } else { - val taskHost = driverSocket.getLocalAddress.getHostAddress - val taskInfo = s"$taskHost:$localListenPort" + val taskInfo = s"$taskHost:$localListenPort:$partitionId" log.info(s"send current task info to driver: $taskInfo ") taskInfo } - // Send the current host:port to the driver + // Send the current host:port:partitionId to the driver driverOutput.write(s"$taskStatus\n") driverOutput.flush() // If barrier execution mode enabled, create a barrier across tasks @@ -265,7 +266,7 @@ private object TrainUtils extends Serializable { setFinishedStatus(networkParams, localListenPort, log) } } - if (taskStatus != LightGBMConstants.IgnoreStatus) { + if (!taskStatus.startsWith(LightGBMConstants.IgnoreStatus)) { // Wait to get the list of nodes from the driver val nodes = driverInput.readLine() log.info(s"LightGBM worker got nodes for network init: $nodes")