From bef4c88e4836496a0fa73fc34d6768a65fb83c21 Mon Sep 17 00:00:00 2001 From: Wojciech Uss Date: Fri, 25 Sep 2020 05:09:42 +0200 Subject: [PATCH] Add an option to set number of warmup iterations --- .../fluid/inference/tests/api/tester_helper.h | 24 +++++++++++++++---- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/inference/tests/api/tester_helper.h b/paddle/fluid/inference/tests/api/tester_helper.h index 723e989be8de8..252bca2d5522e 100644 --- a/paddle/fluid/inference/tests/api/tester_helper.h +++ b/paddle/fluid/inference/tests/api/tester_helper.h @@ -65,6 +65,7 @@ DEFINE_bool(zero_copy, false, "Use ZeroCopy to speedup Feed/Fetch."); DEFINE_bool(warmup, false, "Use warmup to calculate elapsed_time more accurately. " "To reduce CI time, it sets false in default."); +DEFINE_int32(warmup_iters, 1, "Number of batches to process during warmup."); DEFINE_bool(enable_profile, false, "Turn on profiler for fluid"); DEFINE_int32(cpu_num_threads, 1, "Number of threads for each paddle instance."); @@ -364,15 +365,28 @@ void PredictionWarmUp(PaddlePredictor *predictor, if (FLAGS_zero_copy) { ConvertPaddleTensorToZeroCopyTensor(predictor, inputs[0]); } - outputs->resize(1); + int iterations = 1; + if (FLAGS_warmup_iters > 1) + iterations = std::min(FLAGS_warmup_iters, static_cast(inputs.size())); + outputs->resize(iterations); Timer warmup_timer; - warmup_timer.tic(); + double elapsed_time = 0; if (!FLAGS_zero_copy) { - predictor->Run(inputs[0], &(*outputs)[0], batch_size); + for (int i = 0; i < iterations; ++i) { + warmup_timer.tic(); + predictor->Run(inputs[i], &(*outputs)[i], batch_size); + elapsed_time += warmup_timer.toc(); + } } else { - predictor->ZeroCopyRun(); + for (int i = 0; i < iterations; ++i) { + warmup_timer.tic(); + predictor->ZeroCopyRun(); + elapsed_time += warmup_timer.toc(); + } } - PrintTime(batch_size, 1, num_threads, tid, warmup_timer.toc(), 1, data_type); + auto batch_latency = elapsed_time / iterations; + PrintTime(batch_size, 1, num_threads, tid, batch_latency, iterations, + data_type); if (FLAGS_enable_profile) { paddle::platform::ResetProfiler(); }