Skip to content

Commit

Permalink
[djl-bench] Add warmup to benchmark (#1152)
Browse files Browse the repository at this point in the history
Change-Id: I3f5e5ec3bc6d9c4f2c4ad09c6695775f50981ebd
  • Loading branch information
frankfliu authored Aug 9, 2021
1 parent a6f0cfc commit f5fffd2
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -113,16 +113,14 @@ public final boolean runBenchmark(String[] args) {
while (!duration.isNegative()) {
Metrics metrics = new Metrics(); // Reset Metrics for each test loop.
progressBar = new ProgressBar("Iteration", iteration);
long begin = System.currentTimeMillis();
float[] lastResult = predict(arguments, metrics, iteration);
if (lastResult == null) {
return false;
}

if (metrics.hasMetric("mt_start")) {
begin = metrics.getMetric("mt_start").get(0).getValue().longValue();
}
long totalTime = System.currentTimeMillis() - begin;
long begin = metrics.getMetric("start").get(0).getValue().longValue();
long end = metrics.getMetric("end").get(0).getValue().longValue();
long totalTime = end - begin;

if (lastResult.length > 3) {
logger.info(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,19 @@ public float[] predict(Arguments arguments, Metrics metrics, int iteration)
Device device = Engine.getEngine(arguments.getEngine()).defaultDevice();
try (ZooModel<Void, float[]> model = loadModel(arguments, metrics, device)) {
float[] predictResult = null;

try (Predictor<Void, float[]> predictor = model.newPredictor()) {
predictor.setMetrics(metrics); // Let predictor collect metrics
predictor.predict(null); // warmup

predictor.setMetrics(metrics); // Let predictor collect metrics
metrics.addMetric("start", System.currentTimeMillis(), "mills");
for (int i = 0; i < iteration; ++i) {
predictResult = predictor.predict(null);

progressBar.update(i);
MemoryTrainingListener.collectMemoryInfo(metrics);
}
metrics.addMetric("end", System.currentTimeMillis(), "mills");
}
return predictResult;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import ai.djl.metric.Metrics;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.listener.MemoryTrainingListener;
import ai.djl.translate.TranslateException;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
Expand All @@ -39,7 +40,7 @@ public class MultithreadedBenchmark extends AbstractBenchmark {
/** {@inheritDoc} */
@Override
public float[] predict(Arguments arguments, Metrics metrics, int iteration)
throws IOException, ModelException {
throws IOException, ModelException, TranslateException {

MemoryTrainingListener.collectMemoryInfo(metrics); // Measure memory before loading model

Expand Down Expand Up @@ -74,7 +75,11 @@ public float[] predict(Arguments arguments, Metrics metrics, int iteration)

int successThreads = 0;
try {
metrics.addMetric("mt_start", System.currentTimeMillis(), "mills");
for (PredictorCallable callable : callables) {
callable.warmup();
}

metrics.addMetric("start", System.currentTimeMillis(), "mills");
try {
List<Future<float[]>> futures;
if (delay > 0) {
Expand All @@ -96,6 +101,7 @@ public float[] predict(Arguments arguments, Metrics metrics, int iteration)
} catch (InterruptedException | ExecutionException e) {
logger.error("", e);
}
metrics.addMetric("end", System.currentTimeMillis(), "mills");
for (PredictorCallable callable : callables) {
callable.close();
}
Expand Down Expand Up @@ -170,6 +176,10 @@ public float[] call() throws Exception {
return result;
}

public void warmup() throws TranslateException {
predictor.predict(null);
}

public void close() {
predictor.close();
}
Expand Down

0 comments on commit f5fffd2

Please sign in to comment.