From f1ca91d4e401096d04e962c982d62b1f2669c9f5 Mon Sep 17 00:00:00 2001 From: Tristan Konolige Date: Wed, 25 Aug 2021 18:25:29 -0700 Subject: [PATCH] [GRAPH EXECUTOR,VM] Add benchmarking function to graph executor and vm (#8807) * [GRAPH EXECUTOR,VM] Add benchmarking function to graph executor and vm This new benchmarking function is just a convenience function for calling time_evaluator on the underlying module. Hopefully this should make it easier for users to get good benchmarks of their code. * formatting * import order * more test, more comments, more precision * fix tests * add seconds descriptions to doc --- python/tvm/contrib/graph_executor.py | 59 +++++++++++++++ python/tvm/driver/tvmc/model.py | 27 ++----- python/tvm/driver/tvmc/runner.py | 8 +- python/tvm/runtime/module.py | 75 +++++++++++++++++-- python/tvm/runtime/vm.py | 64 ++++++++++++++++ src/runtime/rpc/rpc_module.cc | 5 +- tests/python/driver/tvmc/test_model.py | 3 +- tests/python/driver/tvmc/test_runner.py | 5 +- .../relay/test_backend_graph_executor.py | 26 +++++++ tests/python/relay/test_vm.py | 26 +++++++ tests/python/unittest/test_runtime_measure.py | 11 +++ tutorials/auto_scheduler/tune_network_arm.py | 6 +- tutorials/auto_scheduler/tune_network_cuda.py | 4 +- tutorials/auto_scheduler/tune_network_mali.py | 6 +- tutorials/auto_scheduler/tune_network_x86.py | 4 +- tutorials/autotvm/tune_relay_arm.py | 7 +- tutorials/autotvm/tune_relay_cuda.py | 7 +- tutorials/autotvm/tune_relay_mobile_gpu.py | 7 +- tutorials/autotvm/tune_relay_x86.py | 6 +- tutorials/frontend/deploy_model_on_android.py | 4 +- tutorials/frontend/deploy_prequantized.py | 4 +- .../frontend/deploy_prequantized_tflite.py | 4 +- tutorials/frontend/deploy_sparse.py | 7 +- 23 files changed, 283 insertions(+), 92 deletions(-) diff --git a/python/tvm/contrib/graph_executor.py b/python/tvm/contrib/graph_executor.py index f9d1b9734d45..2e8ff1d62421 100644 --- a/python/tvm/contrib/graph_executor.py +++ b/python/tvm/contrib/graph_executor.py @@ -320,3 +320,62 @@ def __getitem__(self, key): The key to the module. """ return self.module[key] + + def benchmark(self, device, func_name="run", repeat=5, number=5, min_repeat_ms=None, **kwargs): + """Calculate runtime of a function by repeatedly calling it. + + Use this function to get an accurate measurement of the runtime of a function. The function + is run multiple times in order to account for variability in measurements, processor speed + or other external factors. Mean, median, standard deviation, min and max runtime are all + reported. On GPUs, CUDA and ROCm specifically, special on-device timers are used so that + synchonization and data transfer operations are not counted towards the runtime. This allows + for fair comparison of runtimes across different functions and models. + + The benchmarking loop looks approximately like so: + + .. code-block:: python + + for r in range(repeat): + time_start = now() + for n in range(number): + func_name() + time_end = now() + total_times.append((time_end - time_start)/number) + + + Parameters + ---------- + func_name : str + The function to benchmark + + repeat : int + Number of times to run the outer loop of the timing code (see above). The output will + contain `repeat` number of datapoints. + + number : int + Number of times to run the inner loop of the timing code. This inner loop is run in + between the timer starting and stopping. 
In order to amortize any timing overhead, + `number` should be increased when the runtime of the function is small (less than a 1/10 + of a millisecond). + + min_repeat_ms : Optional[float] + If set, the inner loop will be run until it takes longer than `min_repeat_ms` + milliseconds. This can be used to ensure that the function is run enough to get an + accurate measurement. + + kwargs : Dict[str, Object] + Named arguments to the function. These are cached before running timing code, so that + data transfer costs are not counted in the runtime. + + Returns + ------- + timing_results : BenchmarkResult + Runtimes of the function. Use `.mean` to access the mean runtime, use `.results` to + access the individual runtimes (in seconds). + """ + min_repeat_ms = 0 if min_repeat_ms is None else min_repeat_ms + if kwargs: + self.set_input(**kwargs) + return self.module.time_evaluator( + func_name, device, repeat=repeat, number=number, min_repeat_ms=min_repeat_ms + )() diff --git a/python/tvm/driver/tvmc/model.py b/python/tvm/driver/tvmc/model.py index a9516e1e2c42..48bb052124ee 100644 --- a/python/tvm/driver/tvmc/model.py +++ b/python/tvm/driver/tvmc/model.py @@ -46,7 +46,7 @@ import os import tarfile import json -from typing import Optional, Union, List, Dict, Callable, TextIO +from typing import Optional, Union, Dict, Callable, TextIO import numpy as np import tvm @@ -54,6 +54,7 @@ from tvm import relay from tvm.contrib import utils from tvm.relay.backend.executor_factory import GraphExecutorFactoryModule +from tvm.runtime.module import BenchmarkResult try: from tvm.micro import export_model_library_format @@ -371,14 +372,14 @@ def import_package(self, package_path: str): class TVMCResult(object): """A class that stores the results of tvmc.run and provides helper utilities.""" - def __init__(self, outputs: Dict[str, np.ndarray], times: List[float]): + def __init__(self, outputs: Dict[str, np.ndarray], times: BenchmarkResult): """Create a convenience wrapper around the output of tvmc.run Parameters ---------- outputs : dict Outputs dictionary mapping the name of the output to its numpy value. - times : list of float + times : BenchmarkResult The execution times measured by the time evaluator in seconds to produce outputs. """ self.outputs = outputs @@ -390,29 +391,15 @@ def format_times(self): This has the effect of producing a small table that looks like: .. code-block:: Execution time summary: - mean (ms) max (ms) min (ms) std (ms) - 0.14310 0.16161 0.12933 0.01004 + mean (ms) median (ms) max (ms) min (ms) std (ms) + 0.14310 0.14310 0.16161 0.12933 0.01004 Returns ------- str A formatted string containing the statistics. """ - - # timestamps - mean_ts = np.mean(self.times) * 1000 - std_ts = np.std(self.times) * 1000 - max_ts = np.max(self.times) * 1000 - min_ts = np.min(self.times) * 1000 - - header = "Execution time summary:\n{0:^10} {1:^10} {2:^10} {3:^10}".format( - "mean (ms)", "max (ms)", "min (ms)", "std (ms)" - ) - stats = "{0:^10.2f} {1:^10.2f} {2:^10.2f} {3:^10.2f}".format( - mean_ts, max_ts, min_ts, std_ts - ) - - return "%s\n%s\n" % (header, stats) + return str(self.times) def get_output(self, name: str): """A helper function to grab one of the outputs by name. 
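With this change, `TVMCResult.format_times` simply delegates to `BenchmarkResult.__str__`, which is also where the new "median (ms)" column comes from. A minimal sketch of the resulting behaviour, using made-up timings in seconds:

    from tvm.runtime.module import BenchmarkResult

    # Illustrative runtimes in seconds; in practice these come from
    # module.benchmark() / tvmc.run rather than being constructed by hand.
    times = BenchmarkResult([0.12933, 0.14310, 0.16161])
    print(times)       # prints the "Execution time summary" table in milliseconds
    print(times.mean)  # individual statistics remain available in seconds
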
diff --git a/python/tvm/driver/tvmc/runner.py b/python/tvm/driver/tvmc/runner.py index 8515bc9b053c..489604d79cf4 100644 --- a/python/tvm/driver/tvmc/runner.py +++ b/python/tvm/driver/tvmc/runner.py @@ -421,12 +421,8 @@ def run_module( # This print is intentional print(report) - # create the module time evaluator (returns a function) - timer = module.module.time_evaluator("run", dev, number=number, repeat=repeat) - # call the evaluator function to invoke the module and save execution times - prof_result = timer() - # collect a list of execution times from the profiling results - times = prof_result.results + # call the benchmarking function of the executor + times = module.benchmark(dev, number=number, repeat=repeat) logger.debug("Collecting the output tensors.") num_outputs = module.get_num_outputs() diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index 8107ab5b87d2..25a57bbb1c36 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -20,7 +20,8 @@ import os import ctypes import struct -from collections import namedtuple +from typing import Sequence +import numpy as np import tvm._ffi from tvm._ffi.base import _LIB, check_call, c_str, string_types, _RUNTIME_ONLY @@ -30,8 +31,69 @@ from . import _ffi_api -# profile result of time evaluator -ProfileResult = namedtuple("ProfileResult", ["mean", "results"]) +class BenchmarkResult: + """Runtimes from benchmarking""" + + def __init__(self, results: Sequence[float]): + """Construct a new BenchmarkResult from a sequence of runtimes. + + Parameters + ---------- + results : Sequence[float] + Raw times from benchmarking + + Attributes + ---------- + min : float + Minimum runtime in seconds of all results. + mean : float + Mean runtime in seconds of all results. If py:meth:`Module.time_evaluator` or + `benchmark` is called with `number` > 0, then each result is already the mean of a + `number` of runtimes, so this becomes the mean of means. + median : float + Median runtime in seconds of all results. If py:meth:`Module.time_evaluator` is called + with `number` > 0, then each result is already the mean of a `number` of runtimes, so + this becomes the median of means. + max : float + Maximum runtime in seconds of all results. If py:meth:`Module.time_evaluator` is called + with `number` > 0, then each result is already the mean of a `number` of runtimes, so + this becomes the maximum of those means. + std : float + Standard deviation in seconds of runtimes. If py:meth:`Module.time_evaluator` is called + with `number` > 0, then each result is already the mean of a `number` of runtimes, so + this becomes the standard deviation of means. + results : Sequence[float] + The collected runtimes (in seconds). This may be a series of mean runtimes if + py:meth:`Module.time_evaluator` or `benchmark` was run with `number` > 1. 
+ """ + self.results = results + self.mean = np.mean(self.results) + self.std = np.std(self.results) + self.median = np.median(self.results) + self.min = np.min(self.results) + self.max = np.max(self.results) + + def __repr__(self): + return "BenchmarkResult(min={}, mean={}, median={}, max={}, std={}, results={})".format( + self.min, self.mean, self.median, self.max, self.std, self.results + ) + + def __str__(self): + return """Execution time summary: +{:^12} {:^12} {:^12} {:^12} {:^12} +{:^12.4f} {:^12.4f} {:^12.4f} {:^12.4f} {:^12.4f} + """.format( + "mean (ms)", + "median (ms)", + "max (ms)", + "min (ms)", + "std (ms)", + self.mean * 1000, + self.median * 1000, + self.max * 1000, + self.min * 1000, + self.std * 1000, + ) class Module(object): @@ -209,7 +271,7 @@ def time_evaluator(self, func_name, dev, number=10, repeat=1, min_repeat_ms=0, f Returns ------- ftimer : function - The function that takes same argument as func and returns a ProfileResult. + The function that takes same argument as func and returns a BenchmarkResult. The ProfileResult reports `repeat` time costs in seconds. """ try: @@ -230,12 +292,11 @@ def evaluator(*args): blob = feval(*args) fmt = "@" + ("d" * repeat) results = struct.unpack(fmt, blob) - mean = sum(results) / float(repeat) - return ProfileResult(mean=mean, results=results) + return BenchmarkResult(results) return evaluator except NameError: - raise NameError("time_evaluate is only supported when RPC is enabled") + raise NameError("time_evaluator is only supported when RPC is enabled") def _collect_from_import_tree(self, filter_func): """Helper function to collect modules from the tree matching a filter_func, then return it. diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py index 2f133e1a422d..aeb651cb5ae4 100644 --- a/python/tvm/runtime/vm.py +++ b/python/tvm/runtime/vm.py @@ -507,3 +507,67 @@ def get_input_index(self, input_name, func_name="main"): The input index. -1 will be returned if the given input name is not found. """ return self._get_input_index(input_name, func_name) + + def benchmark( + self, device, *args, func_name="main", repeat=5, number=5, min_repeat_ms=None, **kwargs + ): + """Calculate runtime of a function by repeatedly calling it. + + Use this function to get an accurate measurement of the runtime of a function. The function + is run multiple times in order to account for variability in measurements, processor speed + or other external factors. Mean, median, standard deviation, min and max runtime are all + reported. On GPUs, CUDA and ROCm specifically, special on-device timers are used so that + synchonization and data transfer operations are not counted towards the runtime. This allows + for fair comparison of runtimes across different functions and models. + + The benchmarking loop looks approximately like so: + + .. code-block:: python + + for r in range(repeat): + time_start = now() + for n in range(number): + func_name() + time_end = now() + total_times.append((time_end - time_start)/number) + + + Parameters + ---------- + func_name : str + The function to benchmark + + repeat : int + Number of times to run the outer loop of the timing code (see above). The output will + contain `repeat` number of datapoints. + + number : int + Number of times to run the inner loop of the timing code. This inner loop is run in + between the timer starting and stopping. In order to amortize any timing overhead, + `number` should be increased when the runtime of the function is small (less than a 1/10 + of a millisecond). 
+ + min_repeat_ms : Optional[float] + If set, the inner loop will be run until it takes longer than `min_repeat_ms` + milliseconds. This can be used to ensure that the function is run enough to get an + accurate measurement. + + args : Sequence[Object] + Arguments to the function. These are cached before running timing code, so that data + transfer costs are not counted in the runtime. + + kwargs : Dict[str, Object] + Named arguments to the function. These are cached like `args`. + + Returns + ------- + timing_results : BenchmarkResult + Runtimes of the function. Use `.mean` to access the mean runtime, use `.results` to + access the individual runtimes (in seconds). + """ + min_repeat_ms = 0 if min_repeat_ms is None else min_repeat_ms + if args or kwargs: + self.set_input(func_name, *args, **kwargs) + return self.module.time_evaluator( + "invoke", device, repeat=repeat, number=number, min_repeat_ms=min_repeat_ms + )(func_name) diff --git a/src/runtime/rpc/rpc_module.cc b/src/runtime/rpc/rpc_module.cc index 7272269680c5..b9ed54e73508 100644 --- a/src/runtime/rpc/rpc_module.cc +++ b/src/runtime/rpc/rpc_module.cc @@ -417,8 +417,9 @@ TVM_REGISTER_GLOBAL("runtime.RPCTimeEvaluator") << "Cannot find " << f_preproc_name << " in the global function"; f_preproc = *pf_preproc; } - return WrapTimeEvaluator(m.GetFunction(name, false), dev, number, repeat, min_repeat_ms, - f_preproc); + PackedFunc pf = m.GetFunction(name, false); + CHECK(pf != nullptr) << "Cannot find " << name << " in the global registry"; + return WrapTimeEvaluator(pf, dev, number, repeat, min_repeat_ms, f_preproc); } } else { auto* pf = runtime::Registry::Get(name); diff --git a/tests/python/driver/tvmc/test_model.py b/tests/python/driver/tvmc/test_model.py index f5a28d419cbb..fd2637a85f1f 100644 --- a/tests/python/driver/tvmc/test_model.py +++ b/tests/python/driver/tvmc/test_model.py @@ -21,6 +21,7 @@ from tvm.driver import tvmc from tvm.driver.tvmc.model import TVMCModel, TVMCPackage, TVMCResult +from tvm.runtime.module import BenchmarkResult def test_tvmc_workflow(keras_simple): @@ -35,7 +36,7 @@ def test_tvmc_workflow(keras_simple): assert type(result) is TVMCResult assert path.exists(tuning_records) assert type(result.outputs) is dict - assert type(result.times) is tuple + assert type(result.times) is BenchmarkResult assert "output_0" in result.outputs.keys() diff --git a/tests/python/driver/tvmc/test_runner.py b/tests/python/driver/tvmc/test_runner.py index 7acb376baba6..2ce363ab5911 100644 --- a/tests/python/driver/tvmc/test_runner.py +++ b/tests/python/driver/tvmc/test_runner.py @@ -20,6 +20,7 @@ from tvm.driver import tvmc from tvm.driver.tvmc.model import TVMCResult from tvm.driver.tvmc.result_utils import get_top_results +from tvm.runtime.module import BenchmarkResult def test_generate_tensor_data_zeros(): @@ -52,7 +53,7 @@ def test_generate_tensor_data__type_unknown(): def test_format_times__contains_header(): - fake_result = TVMCResult(outputs=None, times=[0.6, 1.2, 0.12, 0.42]) + fake_result = TVMCResult(outputs=None, times=BenchmarkResult([0.6, 1.2, 0.12, 0.42])) sut = fake_result.format_times() assert "std (ms)" in sut @@ -101,5 +102,5 @@ def test_run_tflite_module__with_profile__valid_input( tiger_cat_mobilenet_id in top_5_ids ), "tiger cat is expected in the top-5 for mobilenet v1" assert type(result.outputs) is dict - assert type(result.times) is tuple + assert type(result.times) is BenchmarkResult assert "output_0" in result.outputs.keys() diff --git a/tests/python/relay/test_backend_graph_executor.py 
b/tests/python/relay/test_backend_graph_executor.py index c6f2748e9ec8..9e212527838e 100644 --- a/tests/python/relay/test_backend_graph_executor.py +++ b/tests/python/relay/test_backend_graph_executor.py @@ -16,6 +16,7 @@ # under the License. import numpy as np import pytest +from unittest.mock import patch import tvm import json @@ -23,6 +24,7 @@ from tvm.contrib import graph_executor from tvm.relay.op import add import tvm.testing +from tvm.relay.testing import mlp # @tq, @jr should we put this in testing ns? def check_rts(expr, args, expected_result, mod=None): @@ -322,5 +324,29 @@ def test_graph_executor_api(): assert mod.get_input_index("Invalid") == -1 +@tvm.testing.requires_llvm +def test_benchmark(): + mod, params = mlp.get_workload(1) + lib = relay.build(mod, target="llvm", params=params) + exe = graph_executor.create(lib.get_graph_json(), lib.lib, tvm.cpu()) + data = tvm.nd.array(np.random.rand(1, 1, 28, 28).astype("float32")) + result = exe.benchmark(tvm.cpu(), data=data, func_name="run", repeat=2, number=1) + assert result.mean == result.median + assert result.mean > 0 + assert len(result.results) == 2 + + with patch.object( + tvm.runtime.module.Module, + "time_evaluator", + return_value=lambda: tvm.runtime.module.BenchmarkResult([1, 2, 2, 5]), + ) as method: + result = exe.benchmark(tvm.cpu(), data=data, func_name="run", repeat=2, number=1) + assert result.mean == 2.5 + assert result.median == 2.0 + assert result.max == 5 + assert result.min == 1 + assert result.std == 1.5 + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index 7ae7e0eabeee..c7043481ee3d 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -17,6 +17,7 @@ import numpy as np import pytest import time +from unittest.mock import patch import tvm from tvm import runtime @@ -30,6 +31,7 @@ from tvm import rpc import tvm.testing from tvm.relay.transform import InferType +from tvm.relay.testing import mlp def check_result(args, expected_result, mod=None): @@ -955,5 +957,29 @@ def test_get_input_index(): assert vm_factory.get_input_index("invalid") == -1 +@tvm.testing.requires_llvm +def test_benchmark(): + mod, params = mlp.get_workload(1) + lib = vm.compile(mod, target="llvm", params=params) + exe = runtime.vm.VirtualMachine(lib, tvm.cpu()) + data = tvm.nd.array(np.random.rand(1, 1, 28, 28).astype("float32")) + result = exe.benchmark(tvm.cpu(), data, func_name="main", repeat=2, number=1) + assert result.mean == result.median + assert result.mean > 0 + assert len(result.results) == 2 + + with patch.object( + tvm.runtime.module.Module, + "time_evaluator", + return_value=lambda x: tvm.runtime.module.BenchmarkResult([1, 2, 2, 5]), + ) as method: + result = exe.benchmark(tvm.cpu(), data, func_name="main", repeat=2, number=1) + assert result.mean == 2.5 + assert result.median == 2.0 + assert result.max == 5 + assert result.min == 1 + assert result.std == 1.5 + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/unittest/test_runtime_measure.py b/tests/python/unittest/test_runtime_measure.py index 0d02f910a44c..8955b03241a2 100644 --- a/tests/python/unittest/test_runtime_measure.py +++ b/tests/python/unittest/test_runtime_measure.py @@ -20,6 +20,7 @@ import tvm from tvm import te from tvm.contrib.utils import tempdir +from tvm.runtime.module import BenchmarkResult def test_min_repeat_ms(): @@ -56,5 +57,15 @@ def my_debug(filename): assert ct > 10 + 2 +def test_benchmark_result(): + r = 
BenchmarkResult([1, 2, 2, 5]) + assert r.mean == 2.5 + assert r.median == 2.0 + assert r.min == 1 + assert r.max == 5 + assert r.std == 1.5 + + if __name__ == "__main__": test_min_repeat_ms() + test_benchmark_result() diff --git a/tutorials/auto_scheduler/tune_network_arm.py b/tutorials/auto_scheduler/tune_network_arm.py index 5b0931405212..1619a55dc7e9 100644 --- a/tutorials/auto_scheduler/tune_network_arm.py +++ b/tutorials/auto_scheduler/tune_network_arm.py @@ -349,11 +349,7 @@ def tune_and_evaluate(): # Evaluate print("Evaluate inference time cost...") - ftimer = module.module.time_evaluator("run", dev, repeat=3, min_repeat_ms=500) - prof_res = np.array(ftimer().results) * 1e3 # convert to millisecond - print( - "Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res)) - ) + print(module.benchmark(dev, repeat=3, min_repeat_ms=500)) # We do not run the tuning in our webpage server since the server doesn't have a Raspberry Pi, diff --git a/tutorials/auto_scheduler/tune_network_cuda.py b/tutorials/auto_scheduler/tune_network_cuda.py index 7b5619c671be..08c15264e3c1 100644 --- a/tutorials/auto_scheduler/tune_network_cuda.py +++ b/tutorials/auto_scheduler/tune_network_cuda.py @@ -288,9 +288,7 @@ def run_tuning(): # Evaluate print("Evaluate inference time cost...") -ftimer = module.module.time_evaluator("run", dev, repeat=3, min_repeat_ms=500) -prof_res = np.array(ftimer().results) * 1e3 # convert to millisecond -print("Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res))) +print(module.benchmark(dev, repeat=3, min_repeat_ms=500)) ################################################################# diff --git a/tutorials/auto_scheduler/tune_network_mali.py b/tutorials/auto_scheduler/tune_network_mali.py index 8275f96806b8..2d1e51520952 100644 --- a/tutorials/auto_scheduler/tune_network_mali.py +++ b/tutorials/auto_scheduler/tune_network_mali.py @@ -264,11 +264,7 @@ def tune_and_evaluate(): # Evaluate print("Evaluate inference time cost...") - ftimer = module.module.time_evaluator("run", dev, repeat=3, min_repeat_ms=500) - prof_res = np.array(ftimer().results) * 1e3 # convert to millisecond - print( - "Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res)) - ) + print(module.benchmark(dev, repeat=3, min_repeat_ms=500)) # We do not run the tuning in our webpage server since server doesn't have mali gpu. 
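In the tutorial updates above, the object being printed is the `BenchmarkResult` returned by `module.benchmark()`, so the hand-rolled millisecond statistics that were removed are still recoverable when custom reporting is needed. A sketch, assuming `module` and `dev` are set up as in the tutorial:

    # The per-repeat runtimes (in seconds) and the summary statistics remain
    # accessible on the returned BenchmarkResult.
    result = module.benchmark(dev, repeat=3, min_repeat_ms=500)
    prof_res_ms = [t * 1e3 for t in result.results]  # per-repeat times in ms
    print(
        "Mean inference time (std dev): %.2f ms (%.2f ms)"
        % (result.mean * 1e3, result.std * 1e3)
    )
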
diff --git a/tutorials/auto_scheduler/tune_network_x86.py b/tutorials/auto_scheduler/tune_network_x86.py index 76068fa79605..6cb8d6f14cb9 100644 --- a/tutorials/auto_scheduler/tune_network_x86.py +++ b/tutorials/auto_scheduler/tune_network_x86.py @@ -322,9 +322,7 @@ def run_tuning(): # Evaluate print("Evaluate inference time cost...") -ftimer = module.module.time_evaluator("run", dev, repeat=3, min_repeat_ms=500) -prof_res = np.array(ftimer().results) * 1e3 # convert to millisecond -print("Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res))) +print(module.benchmark(dev, repeat=3, min_repeat_ms=500)) ################################################################# diff --git a/tutorials/autotvm/tune_relay_arm.py b/tutorials/autotvm/tune_relay_arm.py index debf8b8ecf60..f072c5ddac93 100644 --- a/tutorials/autotvm/tune_relay_arm.py +++ b/tutorials/autotvm/tune_relay_arm.py @@ -359,12 +359,7 @@ def tune_and_evaluate(tuning_opt): # evaluate print("Evaluate inference time cost...") - ftimer = module.module.time_evaluator("run", dev, number=1, repeat=10) - prof_res = np.array(ftimer().results) * 1000 # convert to millisecond - print( - "Mean inference time (std dev): %.2f ms (%.2f ms)" - % (np.mean(prof_res), np.std(prof_res)) - ) + print(module.benchmark(dev, number=1, repeat=10)) # We do not run the tuning in our webpage server since it takes too long. diff --git a/tutorials/autotvm/tune_relay_cuda.py b/tutorials/autotvm/tune_relay_cuda.py index 65991cc83454..b2af2e13f4fe 100644 --- a/tutorials/autotvm/tune_relay_cuda.py +++ b/tutorials/autotvm/tune_relay_cuda.py @@ -244,12 +244,7 @@ def tune_and_evaluate(tuning_opt): # evaluate print("Evaluate inference time cost...") - ftimer = module.module.time_evaluator("run", dev, number=1, repeat=600) - prof_res = np.array(ftimer().results) * 1000 # convert to millisecond - print( - "Mean inference time (std dev): %.2f ms (%.2f ms)" - % (np.mean(prof_res), np.std(prof_res)) - ) + print(module.benchmark(dev, number=1, repeat=600)) # We do not run the tuning in our webpage server since it takes too long. diff --git a/tutorials/autotvm/tune_relay_mobile_gpu.py b/tutorials/autotvm/tune_relay_mobile_gpu.py index 790c2ff2c2b9..d3f4ec62fafc 100644 --- a/tutorials/autotvm/tune_relay_mobile_gpu.py +++ b/tutorials/autotvm/tune_relay_mobile_gpu.py @@ -352,12 +352,7 @@ def tune_and_evaluate(tuning_opt): # evaluate print("Evaluate inference time cost...") - ftimer = module.module.time_evaluator("run", dev, number=1, repeat=30) - prof_res = np.array(ftimer().results) * 1000 # convert to millisecond - print( - "Mean inference time (std dev): %.2f ms (%.2f ms)" - % (np.mean(prof_res), np.std(prof_res)) - ) + print(module.benchmark(dev, number=1, repeat=30)) # We do not run the tuning in our webpage server since it takes too long. 
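Because `benchmark()` forwards keyword arguments to `set_input` before the timer starts, inputs can also be supplied in the same call without the host-to-device copy being timed. A sketch, assuming `module` and `dev` from the tutorial above; the input name "data" and the shape are illustrative only:

    import numpy as np

    data = np.random.uniform(size=(1, 3, 224, 224)).astype("float32")
    # The input is copied to the device via set_input() before timing begins,
    # so the transfer cost is excluded from the reported runtimes.
    print(module.benchmark(dev, data=data, number=1, repeat=30))
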
diff --git a/tutorials/autotvm/tune_relay_x86.py b/tutorials/autotvm/tune_relay_x86.py index 6b497ae9c0bd..771220bb3314 100644 --- a/tutorials/autotvm/tune_relay_x86.py +++ b/tutorials/autotvm/tune_relay_x86.py @@ -203,11 +203,7 @@ def evaluate_performance(lib, data_shape): # evaluate print("Evaluate inference time cost...") - ftimer = module.module.time_evaluator("run", dev, number=100, repeat=3) - prof_res = np.array(ftimer().results) * 1000 # convert to millisecond - print( - "Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res)) - ) + print(module.benchmark(dev, number=100, repeat=3)) def tune_and_evaluate(tuning_opt): diff --git a/tutorials/frontend/deploy_model_on_android.py b/tutorials/frontend/deploy_model_on_android.py index f435befb8250..c7b610d5d503 100644 --- a/tutorials/frontend/deploy_model_on_android.py +++ b/tutorials/frontend/deploy_model_on_android.py @@ -332,9 +332,7 @@ def transform_image(image): print("TVM prediction top-1: {}".format(synset[top1])) print("Evaluate inference time cost...") -ftimer = module.module.time_evaluator("run", dev, number=1, repeat=10) -prof_res = np.array(ftimer().results) * 1000 # convert to millisecond -print("Mean inference time (std dev): %.2f ms (%.2f ms)" % (np.mean(prof_res), np.std(prof_res))) +print(module.benchmark(dev, number=1, repeat=10)) ###################################################################### # Sample Output diff --git a/tutorials/frontend/deploy_prequantized.py b/tutorials/frontend/deploy_prequantized.py index a59655222278..11a9e3e3eee8 100644 --- a/tutorials/frontend/deploy_prequantized.py +++ b/tutorials/frontend/deploy_prequantized.py @@ -199,9 +199,7 @@ def quantize_model(model, inp): # Here we give an example of how to measure performance of TVM compiled models. n_repeat = 100 # should be bigger to make the measurement more accurate dev = tvm.cpu(0) -ftimer = rt_mod.module.time_evaluator("run", dev, number=1, repeat=n_repeat) -prof_res = np.array(ftimer().results) * 1e3 -print("Elapsed average ms:", np.mean(prof_res)) +print(rt_mod.benchmark(dev, number=1, repeat=n_repeat)) ###################################################################### # .. note:: diff --git a/tutorials/frontend/deploy_prequantized_tflite.py b/tutorials/frontend/deploy_prequantized_tflite.py index e3934e9b250f..7bbb06bdf801 100644 --- a/tutorials/frontend/deploy_prequantized_tflite.py +++ b/tutorials/frontend/deploy_prequantized_tflite.py @@ -232,9 +232,7 @@ def run_tvm(lib): # Here we give an example of how to measure performance of TVM compiled models. n_repeat = 100 # should be bigger to make the measurement more accurate dev = tvm.cpu(0) -ftimer = rt_mod.module.time_evaluator("run", dev, number=1, repeat=n_repeat) -prof_res = np.array(ftimer().results) * 1e3 -print("Elapsed average ms:", np.mean(prof_res)) +print(rt_mod.benchmark(dev, number=1, repeat=n_repeat)) ###################################################################### # .. 
note:: diff --git a/tutorials/frontend/deploy_sparse.py b/tutorials/frontend/deploy_sparse.py index f0af12b709e2..768a697f45cf 100644 --- a/tutorials/frontend/deploy_sparse.py +++ b/tutorials/frontend/deploy_sparse.py @@ -233,12 +233,7 @@ def run_relay_graph(mod, params, shape_dict, target, dev): m.run() tvm_output = m.get_output(0) - ftimer = m.module.time_evaluator("run", dev, repeat=5, number=5) - prof_res = np.array(ftimer().results) * 1000 - print( - "%-20s %-19s (%s)" - % ("Runtime:", "%.2f ms" % np.mean(prof_res), "%.2f ms" % np.std(prof_res)) - ) + print(m.benchmark(dev, repeat=5, number=5)) return tvm_output
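
Taken together, the two new entry points can be exercised end to end as in the sketch below, which mirrors the added tests; the `mlp` workload, "llvm" target, and repeat/number values are examples only:

    import numpy as np
    import tvm
    from tvm import relay, runtime
    from tvm.contrib import graph_executor
    from tvm.relay.testing import mlp

    mod, params = mlp.get_workload(1)
    data = tvm.nd.array(np.random.rand(1, 1, 28, 28).astype("float32"))
    dev = tvm.cpu()

    # GraphExecutor.benchmark: inputs are passed as keyword arguments and are
    # cached on the device before the timing loop starts.
    lib = relay.build(mod, target="llvm", params=params)
    gmod = graph_executor.create(lib.get_graph_json(), lib.lib, dev)
    print(gmod.benchmark(dev, data=data, repeat=3, number=10))

    # VirtualMachine.benchmark: positional arguments are forwarded to the named
    # function ("main" by default) via set_input before timing.
    vm_exec = relay.vm.compile(mod, target="llvm", params=params)
    vm_mod = runtime.vm.VirtualMachine(vm_exec, dev)
    print(vm_mod.benchmark(dev, data, func_name="main", repeat=3, number=10))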