Skip to content

Commit

Permalink
[djl-bench] Run benchmark on multiple GPUs
Browse files (browse the repository at this point in the history)
Change-Id: Ie27e090699695526df42fefd01db2aba78dc3f73
  • Loading branch information
frankfliu committed Aug 4, 2021
1 parent 7b5bc04 commit dea63a1
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
*/
package ai.djl.benchmark;

import ai.djl.Device;
import ai.djl.ModelException;
import ai.djl.engine.Engine;
import ai.djl.metric.Metrics;
Expand Down Expand Up @@ -235,7 +236,7 @@ public final boolean runBenchmark(String[] args) {
return false;
}

protected ZooModel<Void, float[]> loadModel(Arguments arguments, Metrics metrics)
protected ZooModel<Void, float[]> loadModel(Arguments arguments, Metrics metrics, Device device)
throws ModelException, IOException {
long begin = System.nanoTime();
String artifactId = arguments.getArtifactId();
Expand All @@ -248,19 +249,22 @@ protected ZooModel<Void, float[]> loadModel(Arguments arguments, Metrics metrics
.optModelUrls(arguments.getModelUrls())
.optModelName(arguments.getModelName())
.optEngine(arguments.getEngine())
.optDevice(device)
.optFilters(arguments.getCriteria())
.optArtifactId(artifactId)
.optTranslator(translator)
.optProgress(new ProgressBar())
.build();

ZooModel<Void, float[]> model = criteria.loadModel();
long delta = System.nanoTime() - begin;
logger.info(
"Model {} loaded in: {} ms.",
model.getName(),
String.format("%.3f", delta / 1_000_000f));
metrics.addMetric("LoadModel", delta);
if (device == Device.cpu() || device == Device.gpu()) {
long delta = System.nanoTime() - begin;
logger.info(
"Model {} loaded in: {} ms.",
model.getName(),
String.format("%.3f", delta / 1_000_000f));
metrics.addMetric("LoadModel", delta);
}
return model;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ public class Arguments {
private int duration;
private int iteration;
private int threads;
private int maxGpus;
private int delay;
private PairList<DataType, Shape> inputShapes;

Expand All @@ -60,7 +61,7 @@ public class Arguments {
try {
modelUrls = path.toUri().toURL().toExternalForm();
} catch (IOException e) {
throw new IllegalArgumentException("Invalid model-path: " + modelUrls, e);
throw new IllegalArgumentException("Invalid model-path: " + modelPath, e);
}
} else if (cmd.hasOption("model-url")) {
modelUrls = cmd.getOptionValue("model-url");
Expand Down Expand Up @@ -94,6 +95,12 @@ public class Arguments {
} else {
threads = Runtime.getRuntime().availableProcessors() * 2 - 1;
}
if (cmd.hasOption("gpus")) {
maxGpus = Integer.parseInt(cmd.getOptionValue("gpus"));
if (maxGpus <= 0) {
maxGpus = Integer.MAX_VALUE;
}
}
if (cmd.hasOption("criteria")) {
Type type = new TypeToken<Map<String, String>>() {}.getType();
criteria = JsonUtils.GSON.fromJson(cmd.getOptionValue("criteria"), type);
Expand Down Expand Up @@ -226,6 +233,13 @@ static Options getOptions() {
.argName("NUMBER_THREADS")
.desc("Number of inference threads.")
.build());
options.addOption(
Option.builder("g")
.longOpt("gpus")
.hasArg()
.argName("NUMBER_GPUS")
.desc("Number of GPUS to run multithreading inference.")
.build());
options.addOption(
Option.builder("l")
.longOpt("delay")
Expand Down Expand Up @@ -283,6 +297,10 @@ int getThreads() {
return threads;
}

/**
 * Returns the GPU limit taken from the {@code gpus} command line option.
 *
 * <p>The option-parsing code stores {@link Integer#MAX_VALUE} when the supplied value is not
 * positive. NOTE(review): when the option is absent the field is never assigned in the visible
 * code, so this presumably returns 0 -- confirm against the full constructor.
 *
 * @return the maximum number of GPUs requested for the benchmark
 */
int getMaxGpus() {
    return this.maxGpus;
}

String getOutputDir() {
if (outputDir == null) {
outputDir = "build";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
*/
package ai.djl.benchmark;

import ai.djl.Device;
import ai.djl.ModelException;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.repository.zoo.ZooModel;
Expand Down Expand Up @@ -56,7 +58,8 @@ public static void main(String[] args) {
@Override
public float[] predict(Arguments arguments, Metrics metrics, int iteration)
throws IOException, ModelException, TranslateException {
try (ZooModel<Void, float[]> model = loadModel(arguments, metrics)) {
Device device = Engine.getEngine(arguments.getEngine()).defaultDevice();
try (ZooModel<Void, float[]> model = loadModel(arguments, metrics, device)) {
float[] predictResult = null;
try (Predictor<Void, float[]> predictor = model.newPredictor()) {
predictor.setMetrics(metrics); // Let predictor collect metrics
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
*/
package ai.djl.benchmark;

import ai.djl.Device;
import ai.djl.ModelException;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.metric.Metrics;
import ai.djl.repository.zoo.ZooModel;
Expand Down Expand Up @@ -41,16 +43,22 @@ public float[] predict(Arguments arguments, Metrics metrics, int iteration)

MemoryTrainingListener.collectMemoryInfo(metrics); // Measure memory before loading model

ZooModel<Void, float[]> model = loadModel(arguments, metrics);

int numOfThreads = arguments.getThreads();
Engine engine = Engine.getEngine(arguments.getEngine());
Device[] devices = engine.getDevices(arguments.getMaxGpus());
int numOfThreads = arguments.getThreads() * devices.length;
int delay = arguments.getDelay();
AtomicInteger counter = new AtomicInteger(iteration);
logger.info("Multithreading inference with {} threads.", numOfThreads);

List<ZooModel<Void, float[]>> models = new ArrayList<>(devices.length);
List<PredictorCallable> callables = new ArrayList<>(numOfThreads);
for (int i = 0; i < numOfThreads; ++i) {
callables.add(new PredictorCallable(model, metrics, counter, i, i == 0));
for (Device device : devices) {
ZooModel<Void, float[]> model = loadModel(arguments, metrics, device);
models.add(model);

for (int i = 0; i < numOfThreads; ++i) {
callables.add(new PredictorCallable(model, metrics, counter, i, i == 0));
}
}

float[] result = null;
Expand Down Expand Up @@ -89,7 +97,7 @@ public float[] predict(Arguments arguments, Metrics metrics, int iteration)
executorService.shutdown();
}

model.close();
models.forEach(ZooModel::close);
if (successThreads != numOfThreads) {
logger.error("Only {}/{} threads finished.", successThreads, numOfThreads);
return null;
Expand Down

0 comments on commit dea63a1

Please sign in to comment.