diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/DescribeModelResponse.java b/frontend/server/src/main/java/org/pytorch/serve/http/DescribeModelResponse.java index 68c352fad2..007f3a2eb2 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/DescribeModelResponse.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/DescribeModelResponse.java @@ -122,13 +122,21 @@ public void setWorkers(List workers) { } public void addWorker( - String id, long startTime, boolean isRunning, int gpuId, long memoryUsage) { + String id, + long startTime, + boolean isRunning, + int gpuId, + long memoryUsage, + int pid, + String gpuUsage) { Worker worker = new Worker(); worker.setId(id); worker.setStartTime(new Date(startTime)); worker.setStatus(isRunning ? "READY" : "UNLOADING"); - worker.setGpu(gpuId >= 0); worker.setMemoryUsage(memoryUsage); + worker.setPid(pid); + worker.setGpu(gpuId >= 0); + worker.setGpuUsage(gpuUsage); workers.add(worker); } @@ -145,11 +153,29 @@ public static final class Worker { private String id; private Date startTime; private String status; - private boolean gpu; private long memoryUsage; + private int pid; + private boolean gpu; + private String gpuUsage; public Worker() {} + public String getGpuUsage() { + return gpuUsage; + } + + public void setGpuUsage(String gpuUsage) { + this.gpuUsage = gpuUsage; + } + + public int getPid() { + return pid; + } + + public void setPid(int pid) { + this.pid = pid; + } + public String getId() { return id; } diff --git a/frontend/server/src/main/java/org/pytorch/serve/http/ManagementRequestHandler.java b/frontend/server/src/main/java/org/pytorch/serve/http/ManagementRequestHandler.java index 31559835cd..25cfe450ea 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/http/ManagementRequestHandler.java +++ b/frontend/server/src/main/java/org/pytorch/serve/http/ManagementRequestHandler.java @@ -176,7 +176,9 @@ private DescribeModelResponse createModelResponse( boolean isRunning = worker.isRunning(); int gpuId = worker.getGpuId(); long memory = worker.getMemory(); - resp.addWorker(workerId, startTime, isRunning, gpuId, memory); + int pid = worker.getPid(); + String gpuUsage = worker.getGpuUsage(); + resp.addWorker(workerId, startTime, isRunning, gpuId, memory, pid, gpuUsage); } return resp; diff --git a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java index d1fcf431c5..de3ff10cdc 100644 --- a/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java +++ b/frontend/server/src/main/java/org/pytorch/serve/wlm/WorkerThread.java @@ -10,8 +10,12 @@ import io.netty.channel.EventLoopGroup; import io.netty.channel.SimpleChannelInboundHandler; import io.netty.handler.codec.http.HttpResponseStatus; +import java.io.BufferedReader; import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; import java.net.SocketAddress; +import java.nio.charset.StandardCharsets; import java.util.UUID; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.CountDownLatch; @@ -76,6 +80,59 @@ public WorkerState getState() { return state; } + public String getGpuUsage() { + Process process; + StringBuffer gpuUsage = new StringBuffer(); + if (gpuId >= 0) { + try { + // TODO : add a generic code to capture gpu details for different devices instead of + // just NVIDIA + process = + Runtime.getRuntime() + .exec( + "nvidia-smi -i " + + gpuId + + " --query-gpu=utilization.gpu,utilization.memory,memory.used --format=csv"); + process.waitFor(); + int exitCode = process.exitValue(); + if (exitCode != 0) { + gpuUsage.append("failed to obtained gpu usage"); + InputStream error = process.getErrorStream(); + for (int i = 0; i < error.available(); i++) { + logger.error("" + error.read()); + } + return gpuUsage.toString(); + } + InputStream stdout = process.getInputStream(); + BufferedReader reader = + new BufferedReader(new InputStreamReader(stdout, StandardCharsets.UTF_8)); + String line; + String[] headers = new String[3]; + Boolean firstLine = true; + while ((line = reader.readLine()) != null) { + if (firstLine) { + headers = line.split(","); + firstLine = false; + } else { + String[] values = line.split(","); + StringBuffer sb = new StringBuffer("gpuId::" + gpuId + " "); + for (int i = 0; i < headers.length; i++) { + sb.append(headers[i] + "::" + values[i].strip()); + } + gpuUsage.append(sb.toString()); + } + } + } catch (Exception e) { + gpuUsage.append("failed to obtained gpu usage"); + logger.error("Exception Raised : " + e.toString()); + } + } else { + gpuUsage.append("N/A"); + } + + return gpuUsage.toString(); + } + public WorkerLifeCycle getLifeCycle() { return lifeCycle; }