Skip to content

Commit

Permalink
sort based on min partition per machine
Browse files Browse the repository at this point in the history
  • Loading branch information
imatiach-msft committed Mar 17, 2022
1 parent 3fc9460 commit cd399dd
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ package com.microsoft.azure.synapse.ml.lightgbm
import com.microsoft.azure.synapse.ml.core.utils.ClusterUtil
import com.microsoft.azure.synapse.ml.io.http.SharedSingleton
import com.microsoft.azure.synapse.ml.lightgbm.ConnectionState.Finished
import com.microsoft.azure.synapse.ml.lightgbm.LightGBMUtils.{closeConnections, getPartitionId,
getExecutorId, handleConnection, sendDataToExecutors}
import com.microsoft.azure.synapse.ml.lightgbm.LightGBMUtils.{closeConnections, getExecutorId,
getPartitionId, handleConnection, sendDataToExecutors}
import com.microsoft.azure.synapse.ml.lightgbm.TaskTrainingMethods.{isWorkerEnabled, prepareDatasets}
import com.microsoft.azure.synapse.ml.lightgbm.TrainUtils._
import com.microsoft.azure.synapse.ml.lightgbm.booster.LightGBMBooster
Expand All @@ -26,6 +26,7 @@ import org.apache.spark.sql.types._
import java.net.{ServerSocket, Socket}
import java.util.concurrent.Executors
import scala.collection.immutable.HashSet
import scala.collection.mutable
import scala.collection.mutable.ListBuffer
import scala.concurrent.duration.{Duration, SECONDS}
import scala.concurrent.{Await, ExecutionContext, ExecutionContextExecutor, Future}
Expand Down Expand Up @@ -428,22 +429,27 @@ trait LightGBMBase[TrainedModel <: Model[TrainedModel]] extends Estimator[Traine
val f = Future {
var emptyTaskCounter = 0
val hostAndPorts = ListBuffer[(Socket, String)]()
val hostToMinPartition = mutable.Map[String, String]()
if (getUseBarrierExecutionMode) {
log.info(s"driver using barrier execution mode")

def connectToWorkers: Boolean = handleConnection(driverServerSocket, log,
hostAndPorts) == Finished || connectToWorkers
hostAndPorts, hostToMinPartition) == Finished || connectToWorkers

connectToWorkers
} else {
log.info(s"driver expecting $numTasks connections...")
while (hostAndPorts.size + emptyTaskCounter < numTasks) {
val connectionResult = handleConnection(driverServerSocket, log, hostAndPorts)
val connectionResult = handleConnection(driverServerSocket, log, hostAndPorts, hostToMinPartition)
if (connectionResult == ConnectionState.EmptyTask) emptyTaskCounter += 1
}
}
// Concatenate with commas, eg: host1:port1,host2:port2, ... etc
val allConnections = hostAndPorts.map(_._2).sorted.mkString(",")
val hostPortsList = hostAndPorts.map(_._2).sortBy(hostPort => {
val host = hostPort.split(":")(0)
Integer.parseInt(hostToMinPartition(host))
})
val allConnections = hostPortsList.mkString(",")
log.info(s"driver writing back to all connections: $allConnections")
// Send data back to all tasks and helper tasks on executors
sendDataToExecutors(hostAndPorts, allConnections)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import org.slf4j.Logger

import java.io._
import java.net.{ServerSocket, Socket}
import scala.collection.mutable
import scala.collection.mutable.ListBuffer

/** Helper utilities for LightGBM learners */
Expand Down Expand Up @@ -82,26 +83,44 @@ object LightGBMUtils {
* @param driverServerSocket The driver socket.
* @param log The log4j logger.
* @param hostAndPorts A list of host and ports of connected tasks.
* @param hostToMinPartition A list of host to the minimum partition id, used for determinism.
* @return The connection status, can be finished for barrier mode, empty task or connected.
*/
def handleConnection(driverServerSocket: ServerSocket, log: Logger,
hostAndPorts: ListBuffer[(Socket, String)]): ConnectionState = {
hostAndPorts: ListBuffer[(Socket, String)],
hostToMinPartition: mutable.Map[String, String]): ConnectionState = {
log.info("driver accepting a new connection...")
val driverSocket = driverServerSocket.accept()
val reader = new BufferedReader(new InputStreamReader(driverSocket.getInputStream))
val comm = reader.readLine()
if (comm == LightGBMConstants.FinishedStatus) {
log.info("driver received all tasks from barrier stage")
Finished
} else if (comm == LightGBMConstants.IgnoreStatus) {
} else if (comm.startsWith(LightGBMConstants.IgnoreStatus)) {
log.info("driver received ignore status from task")
val hostPartition = comm.split(":")
val host = hostPartition(1)
val partitionId = hostPartition(2)
updateHostToMinPartition(hostToMinPartition, host, partitionId)
EmptyTask
} else {
addSocketAndComm(hostAndPorts, log, comm, driverSocket)
val hostPortPartition = comm.split(":")
val host = hostPortPartition(0)
val port = hostPortPartition(1)
val partitionId = hostPortPartition(2)
updateHostToMinPartition(hostToMinPartition, host, partitionId)
addSocketAndComm(hostAndPorts, log, s"$host:$port", driverSocket)
Connected
}
}

def updateHostToMinPartition(hostToMinPartition: mutable.Map[String, String],
host: String, partitionId: String): Unit = {
if (!hostToMinPartition.contains(host) || hostToMinPartition(host) > partitionId) {
hostToMinPartition(host) = partitionId
}
}


/** Returns an integer ID for the current worker.
* @return In cluster, returns the executor id. In local case, returns the partition id.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,17 +244,18 @@ private object TrainUtils extends Serializable {
io =>
val driverInput = io(0).asInstanceOf[BufferedReader]
val driverOutput = io(1).asInstanceOf[BufferedWriter]
val partitionId = LightGBMUtils.getPartitionId
val taskHost = driverSocket.getLocalAddress.getHostAddress
val taskStatus =
if (ignoreTask) {
log.info("send empty status to driver")
LightGBMConstants.IgnoreStatus
log.info(s"send empty status to driver with partitionId: $partitionId")
s"${LightGBMConstants.IgnoreStatus}:$taskHost:$partitionId"
} else {
val taskHost = driverSocket.getLocalAddress.getHostAddress
val taskInfo = s"$taskHost:$localListenPort"
val taskInfo = s"$taskHost:$localListenPort:$partitionId"
log.info(s"send current task info to driver: $taskInfo ")
taskInfo
}
// Send the current host:port to the driver
// Send the current host:port:partitionId to the driver
driverOutput.write(s"$taskStatus\n")
driverOutput.flush()
// If barrier execution mode enabled, create a barrier across tasks
Expand All @@ -265,7 +266,7 @@ private object TrainUtils extends Serializable {
setFinishedStatus(networkParams, localListenPort, log)
}
}
if (taskStatus != LightGBMConstants.IgnoreStatus) {
if (!taskStatus.startsWith(LightGBMConstants.IgnoreStatus)) {
// Wait to get the list of nodes from the driver
val nodes = driverInput.readLine()
log.info(s"LightGBM worker got nodes for network init: $nodes")
Expand Down

0 comments on commit cd399dd

Please sign in to comment.