From 88e1486689d4057640a8261d6a9417271eb01ef0 Mon Sep 17 00:00:00 2001 From: lxning Date: Thu, 21 Sep 2023 12:18:51 -0700 Subject: [PATCH] allow paralllevel as 1 to start torchrun npro-per-node --- .../serve/archive/model/ModelConfig.java | 7 +++--- .../serve/archive/model/ModelConfigTest.java | 2 +- .../java/org/pytorch/serve/wlm/Model.java | 6 ++--- .../org/pytorch/serve/wlm/ModelManager.java | 2 +- .../pytorch/serve/wlm/WorkLoadManager.java | 11 ++++++---- .../pytorch/serve/wlm/WorkerLifeCycle.java | 6 ++--- .../org/pytorch/serve/wlm/WorkerThread.java | 22 +++++++++++++------ 7 files changed, 33 insertions(+), 23 deletions(-) diff --git a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java index c80235fc74..f4765a0aec 100644 --- a/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java +++ b/frontend/archive/src/main/java/org/pytorch/serve/archive/model/ModelConfig.java @@ -32,7 +32,7 @@ public class ModelConfig { */ private List deviceIds; /** this variable is auto calculated based on torchrun nproc-per-node. */ - private int parallelLevel = 1; + private int parallelLevel; /** the model parallel type can be tp, pp, pptp */ private ParallelType parallelType = ParallelType.NONE; /** torchrun config */ @@ -247,9 +247,8 @@ public int getParallelLevel() { } public void setParallelLevel(int parallelLevel) { - if (parallelLevel <= 0) { - logger.warn("Invalid parallelLevel:{}, set as 1", parallelLevel); - this.parallelLevel = 1; + if (parallelLevel < 0) { + logger.warn("Invalid parallelLevel:{}, set as 0", parallelLevel); return; } this.parallelLevel = parallelLevel; diff --git a/frontend/archive/src/test/java/org/pytorch/serve/archive/model/ModelConfigTest.java b/frontend/archive/src/test/java/org/pytorch/serve/archive/model/ModelConfigTest.java index bc567165a5..7a6171d198 100644 --- a/frontend/archive/src/test/java/org/pytorch/serve/archive/model/ModelConfigTest.java +++ b/frontend/archive/src/test/java/org/pytorch/serve/archive/model/ModelConfigTest.java @@ -43,7 +43,7 @@ public void testInvalidYamlConfig() throws InvalidModelException, IOException { Assert.assertEquals(modelConfig.getMaxBatchDelay(), 100); Assert.assertEquals(modelConfig.getResponseTimeout(), 120); Assert.assertNotEquals(modelConfig.getDeviceType(), ModelConfig.DeviceType.GPU); - Assert.assertEquals(modelConfig.getParallelLevel(), 1); + Assert.assertEquals(modelConfig.getParallelLevel(), 0); Assert.assertNotEquals(modelConfig.getParallelType(), ModelConfig.ParallelType.PPTP); Assert.assertNull(modelConfig.getDeviceIds()); } diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java index b8e5fc414b..1bc2bd58cf 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/Model.java @@ -40,7 +40,7 @@ public class Model { private int maxWorkers; private int batchSize; private int maxBatchDelay; - private int parallelLevel = 1; + private int parallelLevel; private long maxRetryTimeoutInMill = 5 * 60 * 1000; private long clientTimeoutInMills; private ModelConfig.ParallelType parallelType = ModelConfig.ParallelType.NONE; @@ -69,7 +69,7 @@ public class Model { public Model(ModelArchive modelArchive, int queueSize) { this.modelArchive = modelArchive; if (modelArchive != null && modelArchive.getModelConfig() != null) { - if (modelArchive.getModelConfig().getParallelLevel() > 1 + if (modelArchive.getModelConfig().getParallelLevel() > 0 && modelArchive.getModelConfig().getParallelType() != ModelConfig.ParallelType.NONE) { parallelLevel = modelArchive.getModelConfig().getParallelLevel(); @@ -136,7 +136,7 @@ public JsonObject getModelState(boolean isDefaultVersion) { modelInfo.addProperty(BATCH_SIZE, getBatchSize()); modelInfo.addProperty(MAX_BATCH_DELAY, getMaxBatchDelay()); modelInfo.addProperty(RESPONSE_TIMEOUT, getResponseTimeout()); - if (parallelLevel > 1) { + if (parallelLevel > 0) { modelInfo.addProperty(PARALLEL_LEVEL, parallelLevel); } diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java index d4060566cf..0b9519beb9 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/ModelManager.java @@ -451,7 +451,7 @@ public CompletableFuture updateModel( throw new ModelVersionNotFoundException( "Model version: " + versionId + " does not exist for model: " + modelName); } - if (model.getParallelLevel() > 1 && model.getDeviceType() == ModelConfig.DeviceType.GPU) { + if (model.getParallelLevel() > 0 && model.getDeviceType() == ModelConfig.DeviceType.GPU) { /** * Current capacity check for LMI is based on single node. TODO: multiple nodes check * will be based on --proc-per-node + numCores. diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java index d944e9592d..383fbdfafb 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkLoadManager.java @@ -211,14 +211,17 @@ private void addThreads( int gpuId = -1; if (maxGpu > 0) { - if (model.isHasCfgDeviceIds() || model.getParallelLevel() > 1) { + if (model.isHasCfgDeviceIds() || model.getParallelLevel() > 0) { gpuId = model.getGpuCounter() .getAndAccumulate( maxGpu, (prev, maxGpuId) -> - (prev + model.getParallelLevel()) % maxGpuId); - if (model.getParallelLevel() == 1) { + (prev + model.getParallelLevel() > 0 + ? model.getParallelLevel() + : 1) + % maxGpuId); + if (model.getParallelLevel() == 0) { gpuId = model.getDeviceIds().get(gpuId); } } else { @@ -230,7 +233,7 @@ private void addThreads( BatchAggregator aggregator = new BatchAggregator(model); int currentPort = - model.getParallelLevel() > 1 + model.getParallelLevel() > 0 ? configManager.isDebug() ? distributionPort.get() : distributionPort.getAndAdd(model.getParallelLevel()) diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java index 73e481321a..4ee74e88ad 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerLifeCycle.java @@ -115,9 +115,9 @@ public void startWorker(int port, String deviceIds) modelPath.getAbsolutePath(), model.getModelArchive().getManifest().getModel().getHandler()))); - if (model.getParallelLevel() > 1) { + if (model.getParallelLevel() > 0) { attachRunner(argl, envp, port, deviceIds); - } else if (model.getParallelLevel() == 1) { + } else if (model.getParallelLevel() == 0) { argl.add(EnvironmentUtils.getPythonRunTime(model)); } @@ -153,7 +153,7 @@ public void startWorker(int port, String deviceIds) argl.add(configManager.getMetricsConfigPath()); try { - latch = new CountDownLatch(model.getParallelLevel()); + latch = new CountDownLatch(model.getParallelLevel() > 0 ? model.getParallelLevel() : 1); String[] args = argl.toArray(new String[argl.size()]); String[] envs = envp.toArray(new String[envp.size()]); diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java index 793a2c1d1b..edc6863a21 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java @@ -99,7 +99,9 @@ public WorkerThread( this.listener = listener; startTime = System.currentTimeMillis(); lifeCycle = new WorkerLifeCycle(configManager, model); - replies = new ArrayBlockingQueue<>(model.getParallelLevel()); + replies = + new ArrayBlockingQueue<>( + model.getParallelLevel() > 0 ? model.getParallelLevel() : 1); this.workerThreadTimeMetric = MetricCache.getInstance().getMetricFrontend("WorkerThreadTime"); this.workerLoadTimeMetric = MetricCache.getInstance().getMetricFrontend("WorkerLoadTime"); @@ -192,10 +194,10 @@ public void run() { || ((req.getCommand() == WorkerCommands.PREDICT || req.getCommand() == WorkerCommands.STREAMPREDICT) - && model.getParallelLevel() > 1 + && model.getParallelLevel() > 0 && model.getParallelType() != ModelConfig.ParallelType.PP) - ? model.getParallelLevel() + ? model.getParallelLevel() > 0 ? model.getParallelLevel() : 1 : 1; for (int i = 0; backendChannel.size() > 0 && i < repeats; i++) { backendChannel.get(i).writeAndFlush(req).sync(); @@ -301,7 +303,10 @@ public void run() { // WorkerThread is running in thread pool, the thread will be assigned to next // Runnable once this worker is finished. If currentThread keep holding the reference // of the thread, currentThread.interrupt() might kill next worker. - for (int i = 0; backendChannel.size() > 0 && i < model.getParallelLevel(); i++) { + for (int i = 0; + backendChannel.size() > 0 + && i < (model.getParallelLevel() > 0 ? model.getParallelLevel() : 1); + i++) { backendChannel.get(i).disconnect(); } currentThread.set(null); @@ -342,7 +347,7 @@ private void connect() throws WorkerInitializationException, InterruptedExceptio String modelName = model.getModelName(); String modelVersion = model.getVersion(); setState(WorkerState.WORKER_STARTED, HttpURLConnection.HTTP_OK); - final int parallelLevel = model.getParallelLevel(); + final int parallelLevel = model.getParallelLevel() > 0 ? model.getParallelLevel() : 1; final CountDownLatch latch = new CountDownLatch(parallelLevel); final int responseBufferSize = configManager.getMaxResponseSize(); try { @@ -445,7 +450,10 @@ public int getPid() { public void shutdown() { running.set(false); setState(WorkerState.WORKER_SCALED_DOWN, HttpURLConnection.HTTP_OK); - for (int i = 0; backendChannel.size() > 0 && i < model.getParallelLevel(); i++) { + for (int i = 0; + backendChannel.size() > 0 + && i < (model.getParallelLevel() > 0 ? model.getParallelLevel() : 1); + i++) { if (backendChannel.get(i) != null) { backendChannel.get(i).close(); } @@ -518,7 +526,7 @@ public void retry() { private String getDeviceIds() { List deviceIds; - if (gpuId == -1 || model.getParallelLevel() == 1) { + if (gpuId == -1 || model.getParallelLevel() == 0) { return null; } else if (model.isHasCfgDeviceIds()) { return model.getDeviceIds().subList(gpuId, gpuId + model.getParallelLevel()).stream()