Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow parallelLevel as 1 to start torchrun nproc-per-node #2608

Merged
merged 4 commits into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public class ModelConfig {
*/
private List<Integer> deviceIds;
/** this variable is auto calculated based on torchrun nproc-per-node. */
private int parallelLevel = 1;
private int parallelLevel;
/** the model parallel type can be tp, pp, pptp */
private ParallelType parallelType = ParallelType.NONE;
/** torchrun config */
Expand Down Expand Up @@ -259,9 +259,8 @@ public int getParallelLevel() {
}

public void setParallelLevel(int parallelLevel) {
if (parallelLevel <= 0) {
logger.warn("Invalid parallelLevel:{}, set as 1", parallelLevel);
this.parallelLevel = 1;
if (parallelLevel < 0) {
logger.warn("Invalid parallelLevel:{}, set as 0", parallelLevel);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you set parallelLevel to 1 and then to -1, the level stays at 1 and is not set to 0 as the warning message indicates.

return;
}
this.parallelLevel = parallelLevel;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public void testInvalidYamlConfig() throws InvalidModelException, IOException {
Assert.assertEquals(modelConfig.getMaxBatchDelay(), 100);
Assert.assertEquals(modelConfig.getResponseTimeout(), 120);
Assert.assertNotEquals(modelConfig.getDeviceType(), ModelConfig.DeviceType.GPU);
Assert.assertEquals(modelConfig.getParallelLevel(), 1);
Assert.assertEquals(modelConfig.getParallelLevel(), 0);
Assert.assertNotEquals(modelConfig.getParallelType(), ModelConfig.ParallelType.PPTP);
Assert.assertNull(modelConfig.getDeviceIds());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public class Model {
private int maxWorkers;
private int batchSize;
private int maxBatchDelay;
private int parallelLevel = 1;
private int parallelLevel;
private long maxRetryTimeoutInMill = 5 * 60 * 1000;
private long clientTimeoutInMills;
private ModelConfig.ParallelType parallelType = ModelConfig.ParallelType.NONE;
Expand Down Expand Up @@ -71,7 +71,7 @@ public Model(ModelArchive modelArchive, int queueSize) {
this.modelArchive = modelArchive;
if (modelArchive != null && modelArchive.getModelConfig() != null) {
continuousBatching = modelArchive.getModelConfig().isContinuousBatching();
if (modelArchive.getModelConfig().getParallelLevel() > 1
if (modelArchive.getModelConfig().getParallelLevel() > 0
&& modelArchive.getModelConfig().getParallelType()
!= ModelConfig.ParallelType.NONE) {
parallelLevel = modelArchive.getModelConfig().getParallelLevel();
Expand Down Expand Up @@ -138,7 +138,7 @@ public JsonObject getModelState(boolean isDefaultVersion) {
modelInfo.addProperty(BATCH_SIZE, getBatchSize());
modelInfo.addProperty(MAX_BATCH_DELAY, getMaxBatchDelay());
modelInfo.addProperty(RESPONSE_TIMEOUT, getResponseTimeout());
if (parallelLevel > 1) {
if (parallelLevel > 0) {
modelInfo.addProperty(PARALLEL_LEVEL, parallelLevel);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ public CompletableFuture<Integer> updateModel(
throw new ModelVersionNotFoundException(
"Model version: " + versionId + " does not exist for model: " + modelName);
}
if (model.getParallelLevel() > 1 && model.getDeviceType() == ModelConfig.DeviceType.GPU) {
if (model.getParallelLevel() > 0 && model.getDeviceType() == ModelConfig.DeviceType.GPU) {
/**
* Current capacity check for LMI is based on single node. TODO: multiple nodes check
* will be based on --proc-per-node + numCores.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,14 +211,17 @@ private void addThreads(
int gpuId = -1;

if (maxGpu > 0) {
if (model.isHasCfgDeviceIds() || model.getParallelLevel() > 1) {
if (model.isHasCfgDeviceIds() || model.getParallelLevel() > 0) {
gpuId =
model.getGpuCounter()
.getAndAccumulate(
maxGpu,
(prev, maxGpuId) ->
(prev + model.getParallelLevel()) % maxGpuId);
if (model.getParallelLevel() == 1) {
(prev + model.getParallelLevel() > 0
? model.getParallelLevel()
: 1)
% maxGpuId);
if (model.getParallelLevel() == 0) {
gpuId = model.getDeviceIds().get(gpuId);
}
} else {
Expand All @@ -235,7 +238,7 @@ private void addThreads(
aggregator = new BatchAggregator(model);
}
int currentPort =
model.getParallelLevel() > 1
model.getParallelLevel() > 0
? configManager.isDebug()
? distributionPort.get()
: distributionPort.getAndAdd(model.getParallelLevel())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,9 @@ public void startWorker(int port, String deviceIds)
modelPath.getAbsolutePath(),
model.getModelArchive().getManifest().getModel().getHandler())));

if (model.getParallelLevel() > 1) {
if (model.getParallelLevel() > 0) {
attachRunner(argl, envp, port, deviceIds);
} else if (model.getParallelLevel() == 1) {
} else if (model.getParallelLevel() == 0) {
argl.add(EnvironmentUtils.getPythonRunTime(model));
}

Expand Down Expand Up @@ -153,7 +153,7 @@ public void startWorker(int port, String deviceIds)
argl.add(configManager.getMetricsConfigPath());

try {
latch = new CountDownLatch(model.getParallelLevel());
latch = new CountDownLatch(model.getParallelLevel() > 0 ? model.getParallelLevel() : 1);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: you could replace this and the following occurrences with the more concise Math.max(1, model.getParallelLevel())


String[] args = argl.toArray(new String[argl.size()]);
String[] envs = envp.toArray(new String[envp.size()]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,9 @@ public WorkerThread(
this.listener = listener;
startTime = System.currentTimeMillis();
lifeCycle = new WorkerLifeCycle(configManager, model);
replies = new ArrayBlockingQueue<>(model.getParallelLevel());
replies =
new ArrayBlockingQueue<>(
model.getParallelLevel() > 0 ? model.getParallelLevel() : 1);
this.workerThreadTimeMetric =
MetricCache.getInstance().getMetricFrontend("WorkerThreadTime");
this.workerLoadTimeMetric = MetricCache.getInstance().getMetricFrontend("WorkerLoadTime");
Expand Down Expand Up @@ -198,10 +200,10 @@ public void run() {
|| ((req.getCommand() == WorkerCommands.PREDICT
|| req.getCommand()
== WorkerCommands.STREAMPREDICT)
&& model.getParallelLevel() > 1
&& model.getParallelLevel() > 0
&& model.getParallelType()
!= ModelConfig.ParallelType.PP)
? model.getParallelLevel()
? model.getParallelLevel() > 0 ? model.getParallelLevel() : 1
: 1;
for (int i = 0; backendChannel.size() > 0 && i < repeats; i++) {
backendChannel.get(i).writeAndFlush(req).sync();
Expand Down Expand Up @@ -305,7 +307,10 @@ public void run() {
// WorkerThread is running in thread pool, the thread will be assigned to next
// Runnable once this worker is finished. If currentThread keep holding the reference
// of the thread, currentThread.interrupt() might kill next worker.
for (int i = 0; backendChannel.size() > 0 && i < model.getParallelLevel(); i++) {
for (int i = 0;
backendChannel.size() > 0
&& i < (model.getParallelLevel() > 0 ? model.getParallelLevel() : 1);
i++) {
backendChannel.get(i).disconnect();
}
currentThread.set(null);
Expand Down Expand Up @@ -346,7 +351,7 @@ private void connect() throws WorkerInitializationException, InterruptedExceptio
String modelName = model.getModelName();
String modelVersion = model.getVersion();
setState(WorkerState.WORKER_STARTED, HttpURLConnection.HTTP_OK);
final int parallelLevel = model.getParallelLevel();
final int parallelLevel = model.getParallelLevel() > 0 ? model.getParallelLevel() : 1;
final CountDownLatch latch = new CountDownLatch(parallelLevel);
final int responseBufferSize = configManager.getMaxResponseSize();
try {
Expand Down Expand Up @@ -449,7 +454,10 @@ public int getPid() {
public void shutdown() {
running.set(false);
setState(WorkerState.WORKER_SCALED_DOWN, HttpURLConnection.HTTP_OK);
for (int i = 0; backendChannel.size() > 0 && i < model.getParallelLevel(); i++) {
for (int i = 0;
backendChannel.size() > 0
&& i < (model.getParallelLevel() > 0 ? model.getParallelLevel() : 1);
i++) {
if (backendChannel.get(i) != null) {
backendChannel.get(i).close();
}
Expand Down Expand Up @@ -522,7 +530,7 @@ public void retry() {

private String getDeviceIds() {
List<Integer> deviceIds;
if (gpuId == -1 || model.getParallelLevel() == 1) {
if (gpuId == -1 || model.getParallelLevel() == 0) {
return null;
} else if (model.isHasCfgDeviceIds()) {
return model.getDeviceIds().subList(gpuId, gpuId + model.getParallelLevel()).stream()
Expand Down
Loading