From dea63a1b2a3d030b1b33392830e4545741071b1b Mon Sep 17 00:00:00 2001
From: Frank Liu
Date: Sun, 1 Aug 2021 17:27:38 -0700
Subject: [PATCH] [djl-bench] Run benchmark on multiple GPUs

Change-Id: Ie27e090699695526df42fefd01db2aba78dc3f73
---
 .../ai/djl/benchmark/AbstractBenchmark.java   | 18 +++++++++++-------
 .../main/java/ai/djl/benchmark/Arguments.java | 20 +++++++++++++++++++-
 .../main/java/ai/djl/benchmark/Benchmark.java |  5 ++++-
 .../djl/benchmark/MultithreadedBenchmark.java | 20 ++++++++++++++------
 4 files changed, 48 insertions(+), 15 deletions(-)

diff --git a/extensions/benchmark/src/main/java/ai/djl/benchmark/AbstractBenchmark.java b/extensions/benchmark/src/main/java/ai/djl/benchmark/AbstractBenchmark.java
index 244105fc223b..d5df510b4be0 100644
--- a/extensions/benchmark/src/main/java/ai/djl/benchmark/AbstractBenchmark.java
+++ b/extensions/benchmark/src/main/java/ai/djl/benchmark/AbstractBenchmark.java
@@ -12,6 +12,7 @@
  */
 package ai.djl.benchmark;
 
+import ai.djl.Device;
 import ai.djl.ModelException;
 import ai.djl.engine.Engine;
 import ai.djl.metric.Metrics;
@@ -235,7 +236,7 @@ public final boolean runBenchmark(String[] args) {
         return false;
     }
 
-    protected ZooModel<Void, float[]> loadModel(Arguments arguments, Metrics metrics)
+    protected ZooModel<Void, float[]> loadModel(Arguments arguments, Metrics metrics, Device device)
             throws ModelException, IOException {
         long begin = System.nanoTime();
         String artifactId = arguments.getArtifactId();
@@ -248,6 +249,7 @@ protected ZooModel<Void, float[]> loadModel(Arguments arguments, Metrics metrics
                         .optModelUrls(arguments.getModelUrls())
                         .optModelName(arguments.getModelName())
                         .optEngine(arguments.getEngine())
+                        .optDevice(device)
                         .optFilters(arguments.getCriteria())
                         .optArtifactId(artifactId)
                         .optTranslator(translator)
@@ -255,12 +257,14 @@ protected ZooModel<Void, float[]> loadModel(Arguments arguments, Metrics metrics
                         .build();
 
         ZooModel<Void, float[]> model = criteria.loadModel();
-        long delta = System.nanoTime() - begin;
-        logger.info(
-                "Model {} loaded in: {} ms.",
-                model.getName(),
-                String.format("%.3f", delta / 1_000_000f));
-        metrics.addMetric("LoadModel", delta);
+        if (device == Device.cpu() || device == Device.gpu()) {
+            long delta = System.nanoTime() - begin;
+            logger.info(
+                    "Model {} loaded in: {} ms.",
+                    model.getName(),
+                    String.format("%.3f", delta / 1_000_000f));
+            metrics.addMetric("LoadModel", delta);
+        }
         return model;
     }
 
diff --git a/extensions/benchmark/src/main/java/ai/djl/benchmark/Arguments.java b/extensions/benchmark/src/main/java/ai/djl/benchmark/Arguments.java
index 6f7e31a5ecc4..d79717eb2ed3 100644
--- a/extensions/benchmark/src/main/java/ai/djl/benchmark/Arguments.java
+++ b/extensions/benchmark/src/main/java/ai/djl/benchmark/Arguments.java
@@ -44,6 +44,7 @@ public class Arguments {
     private int duration;
     private int iteration;
     private int threads;
+    private int maxGpus;
     private int delay;
     private PairList<DataType, Shape> inputShapes;
 
@@ -60,7 +61,7 @@ public class Arguments {
             try {
                 modelUrls = path.toUri().toURL().toExternalForm();
             } catch (IOException e) {
-                throw new IllegalArgumentException("Invalid model-path: " + modelUrls, e);
+                throw new IllegalArgumentException("Invalid model-path: " + modelPath, e);
             }
         } else if (cmd.hasOption("model-url")) {
             modelUrls = cmd.getOptionValue("model-url");
@@ -94,6 +95,12 @@ public class Arguments {
         } else {
            threads = Runtime.getRuntime().availableProcessors() * 2 - 1;
         }
+        if (cmd.hasOption("gpus")) {
+            maxGpus = Integer.parseInt(cmd.getOptionValue("gpus"));
+            if (maxGpus <= 0) {
+                maxGpus = Integer.MAX_VALUE;
+            }
+        }
         if (cmd.hasOption("criteria")) {
             Type type = new TypeToken<Map<String, String>>() {}.getType();
             criteria = JsonUtils.GSON.fromJson(cmd.getOptionValue("criteria"), type);
@@ -226,6 +233,13 @@ static Options getOptions() {
                         .argName("NUMBER_THREADS")
                         .desc("Number of inference threads.")
                         .build());
+        options.addOption(
+                Option.builder("g")
+                        .longOpt("gpus")
+                        .hasArg()
+                        .argName("NUMBER_GPUS")
+                        .desc("Number of GPUs to run multithreaded inference.")
+                        .build());
         options.addOption(
                 Option.builder("l")
                         .longOpt("delay")
@@ -283,6 +297,10 @@ int getThreads() {
         return threads;
     }
 
+    int getMaxGpus() {
+        return maxGpus;
+    }
+
     String getOutputDir() {
         if (outputDir == null) {
             outputDir = "build";
diff --git a/extensions/benchmark/src/main/java/ai/djl/benchmark/Benchmark.java b/extensions/benchmark/src/main/java/ai/djl/benchmark/Benchmark.java
index 6cedd9eb0ebb..91cebbcd30c4 100644
--- a/extensions/benchmark/src/main/java/ai/djl/benchmark/Benchmark.java
+++ b/extensions/benchmark/src/main/java/ai/djl/benchmark/Benchmark.java
@@ -12,7 +12,9 @@
  */
 package ai.djl.benchmark;
 
+import ai.djl.Device;
 import ai.djl.ModelException;
+import ai.djl.engine.Engine;
 import ai.djl.inference.Predictor;
 import ai.djl.metric.Metrics;
 import ai.djl.repository.zoo.ZooModel;
@@ -56,7 +58,8 @@ public static void main(String[] args) {
     @Override
     public float[] predict(Arguments arguments, Metrics metrics, int iteration)
             throws IOException, ModelException, TranslateException {
-        try (ZooModel<Void, float[]> model = loadModel(arguments, metrics)) {
+        Device device = Engine.getEngine(arguments.getEngine()).defaultDevice();
+        try (ZooModel<Void, float[]> model = loadModel(arguments, metrics, device)) {
             float[] predictResult = null;
             try (Predictor<Void, float[]> predictor = model.newPredictor()) {
                 predictor.setMetrics(metrics); // Let predictor collect metrics
diff --git a/extensions/benchmark/src/main/java/ai/djl/benchmark/MultithreadedBenchmark.java b/extensions/benchmark/src/main/java/ai/djl/benchmark/MultithreadedBenchmark.java
index 0c12ce093ba5..7294cb63ac99 100644
--- a/extensions/benchmark/src/main/java/ai/djl/benchmark/MultithreadedBenchmark.java
+++ b/extensions/benchmark/src/main/java/ai/djl/benchmark/MultithreadedBenchmark.java
@@ -12,7 +12,9 @@
  */
 package ai.djl.benchmark;
 
+import ai.djl.Device;
 import ai.djl.ModelException;
+import ai.djl.engine.Engine;
 import ai.djl.inference.Predictor;
 import ai.djl.metric.Metrics;
 import ai.djl.repository.zoo.ZooModel;
@@ -41,16 +43,22 @@ public float[] predict(Arguments arguments, Metrics metrics, int iteration)
 
         MemoryTrainingListener.collectMemoryInfo(metrics); // Measure memory before loading model
 
-        ZooModel<Void, float[]> model = loadModel(arguments, metrics);
-
-        int numOfThreads = arguments.getThreads();
+        Engine engine = Engine.getEngine(arguments.getEngine());
+        Device[] devices = engine.getDevices(arguments.getMaxGpus());
+        int numOfThreads = arguments.getThreads() * devices.length;
         int delay = arguments.getDelay();
         AtomicInteger counter = new AtomicInteger(iteration);
         logger.info("Multithreading inference with {} threads.", numOfThreads);
 
+        List<ZooModel<Void, float[]>> models = new ArrayList<>(devices.length);
         List<PredictorCallable> callables = new ArrayList<>(numOfThreads);
-        for (int i = 0; i < numOfThreads; ++i) {
-            callables.add(new PredictorCallable(model, metrics, counter, i, i == 0));
+        for (Device device : devices) {
+            ZooModel<Void, float[]> model = loadModel(arguments, metrics, device);
+            models.add(model);
+
+            for (int i = 0; i < arguments.getThreads(); ++i) {
+                callables.add(new PredictorCallable(model, metrics, counter, i, i == 0));
+            }
         }
 
         float[] result = null;
@@ -89,7 +97,7 @@ public float[] predict(Arguments arguments, Metrics metrics, int iteration)
             executorService.shutdown();
         }
 
-        model.close();
+        models.forEach(ZooModel::close);
         if (successThreads != numOfThreads) {
             logger.error("Only {}/{} threads finished.", successThreads, numOfThreads);
             return null;
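
Reviewer note: the intent of the MultithreadedBenchmark change is one model copy per device returned by Engine.getDevices(maxGpus), with the --threads value applied per device, so the total worker count is threads * devices.length. Below is a minimal standalone sketch of that fan-out, not part of the patch; the class name, the hard-coded option values, and the println standing in for real prediction work are illustrative assumptions only.

    import ai.djl.Device;
    import ai.djl.engine.Engine;

    import java.util.ArrayList;
    import java.util.List;

    public final class GpuFanOutSketch {

        private GpuFanOutSketch() {}

        public static void main(String[] args) {
            int maxGpus = 2; // stand-in for the new -g/--gpus option
            int threadsPerDevice = 4; // stand-in for the existing -t/--threads option

            // Ask the default engine which devices it can see, capped at maxGpus.
            Device[] devices = Engine.getInstance().getDevices(maxGpus);
            int totalThreads = threadsPerDevice * devices.length;

            // One model copy per device; threadsPerDevice workers share each copy.
            List<Runnable> tasks = new ArrayList<>(totalThreads);
            for (Device device : devices) {
                for (int i = 0; i < threadsPerDevice; ++i) {
                    tasks.add(() -> System.out.println("would run inference on " + device));
                }
            }
            System.out.println(tasks.size() + " tasks across " + devices.length + " device(s)");
            tasks.forEach(Runnable::run);
        }
    }

With two visible GPUs this builds 8 tasks, matching the numOfThreads total that the benchmark logs and compares successThreads against; on a CPU-only host getDevices() typically returns a single cpu() device, so the same code path is exercised without GPUs.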