diff --git a/python/hidet/graph/flow_graph.py b/python/hidet/graph/flow_graph.py
index cbb8dec69..8d7e642a4 100644
--- a/python/hidet/graph/flow_graph.py
+++ b/python/hidet/graph/flow_graph.py
@@ -25,6 +25,7 @@ from hidet.ir.task import Task
 from hidet.graph.tensor import Tensor, zeros_like, randn_like
 from hidet.graph.operator import Operator, SymbolVar
+from hidet.utils.benchmark import do_bench
 
 logger = logging.getLogger(__name__)
 
@@ -395,7 +396,7 @@ def f_run(inputs: List[Tensor]) -> List[Tensor]:
         return CudaGraph(f_create_inputs, f_run, ref_objs=[self])
 
     def latency(
-        self, warmup=1, number=3, repeat=3, median=True, dummy_inputs: Optional[Sequence[Tensor]] = None
-    ) -> Union[float, List[float]]:
+        self, warmup=25, repeat=100, dummy_inputs: Optional[Sequence[Tensor]] = None
+    ) -> float:
         """Measure the latency of the flow graph.
 
@@ -404,15 +405,9 @@
         warmup: int
-            The number of warmup runs.
+            The target amount of time, in milliseconds, to spend on warm-up runs.
 
-        number: int
-            The number of runs to measure the latency.
-
         repeat: int
-            The number of times to repeat the measurement.
+            The target amount of time, in milliseconds, to spend on timed runs.
 
-        median: bool
-            Whether to return the median latency.
-
         dummy_inputs: Optional[Sequence[Tensor]]
             The dummy inputs to run the flow graph. If not given, automatic generated dummy inputs would be
             used.
@@ -421,26 +416,12 @@
-        ret: Union[float, List[float]]
-            The measured latency in milliseconds.
+        ret: float
+            The median latency, in milliseconds.
         """
-        import time
-        import numpy as np
 
         if dummy_inputs is None:
             dummy_inputs = self.dummy_inputs()
-        for _ in range(warmup):
-            self.forward(dummy_inputs)
-        results = []
-        for _ in range(repeat):
-            hidet.cuda.synchronize()
-            t1 = time.time()
-            for _ in range(number):
-                self.forward(dummy_inputs)
-            hidet.cuda.synchronize()
-            t2 = time.time()
-            results.append((t2 - t1) * 1000 / number)
-        if median:
-            return float(np.median(results))
-        else:
-            return results
+
+        # do_bench returns the (0.2, 0.5, 0.8) latency quantiles; index 1 is the median
+        return do_bench(lambda: self.forward(dummy_inputs), warmup=warmup, rep=repeat)[1]
 
     @staticmethod
     def _analyze(outputs: List[Tensor]) -> Tuple[List[Tensor], List[Operator], Dict[Tensor, int]]:
diff --git a/python/hidet/utils/benchmark/bench.py b/python/hidet/utils/benchmark/bench.py
index 0d6afd1d6..72731d1b4 100644
--- a/python/hidet/utils/benchmark/bench.py
+++ b/python/hidet/utils/benchmark/bench.py
@@ -33,47 +33,40 @@ def do_bench(fn, warmup=25, rep=100, percentiles=(0.2, 0.5, 0.8)):
     """
     # Estimate the runtime of the function
-    import torch
+    import hidet
 
     fn()
-    torch.cuda.synchronize()
-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
+    hidet.cuda.synchronize()
+    start_event = hidet.cuda.Event(enable_timing=True)
+    end_event = hidet.cuda.Event(enable_timing=True)
     start_event.record()
     for _ in range(5):
         fn()
     end_event.record()
-    torch.cuda.synchronize()
-    estimate_ms = start_event.elapsed_time(end_event) / 5
-    # compute number of warmup and repeat
+    hidet.cuda.synchronize()
+    estimate_ms = end_event.elapsed_time(start_event) / 5
     n_warmup = max(1, int(warmup / estimate_ms))
     n_repeat = max(1, int(rep / estimate_ms))
-    # We maintain a buffer of 256 MB that we clear
-    # before each kernel call to make sure that the L2
-    # doesn't contain any input data before the run
-    start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
-    end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)]
-    cache = torch.empty(int(256e6), dtype=torch.int8, device='cuda')
+    start_event = [hidet.cuda.Event(enable_timing=True) for i in range(n_repeat)]
+    end_event = [hidet.cuda.Event(enable_timing=True) for i in range(n_repeat)]
+
     # Warm-up
     for _ in range(n_warmup):
         fn()
     # Benchmark
     for i in range(n_repeat):
-        # we clear the L2 cache before each run
-        cache.zero_()
-        # record time of `fn`
         start_event[i].record()
         fn()
         end_event[i].record()
     # Record clocks
-    torch.cuda.synchronize()
-    times = torch.tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)])
+    hidet.cuda.synchronize()
+    times = np.array([e.elapsed_time(s) for s, e in zip(start_event, end_event)])
     if percentiles:
-        percentiles = torch.quantile(times, torch.tensor(percentiles)).tolist()
+        percentiles = np.quantile(times, percentiles)
         return tuple(percentiles)
     else:
-        return torch.mean(times).item()
+        return np.mean(times).item()
 
 
 def benchmark_func(run_func, warmup=1, number=5, repeat=5, median=True) -> Union[List[float], float]:
diff --git a/scripts/regression/op_performance.py b/scripts/regression/op_performance.py
index 047ae02d7..21012347d 100644
--- a/scripts/regression/op_performance.py
+++ b/scripts/regression/op_performance.py
@@ -15,7 +15,7 @@ def bench_matmul(m, n, k, dtype):
     c = hidet.ops.matmul(a, b)
     g = hidet.trace_from(c, [a, b])
     g = hidet.graph.optimize(g)
-    return g.latency(warmup=10, number=5, repeat=100)
+    return g.latency(warmup=25, repeat=100)
 
 def bench_fmha(sq, skv, d):
     hidet.option.search_space(2)
@@ -25,7 +25,7 @@ def bench_fmha(sq, skv, d):
     o = hidet.ops.attention(q, k, v)
     g = hidet.trace_from(o, [q, k, v])
     g = hidet.graph.optimize(g)
-    return g.latency(warmup=10, number=5, repeat=100)
+    return g.latency(warmup=25, repeat=100)
 
 def matmul_regression() -> ResultGroup:
     regression_data = load_regression_data()
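
Usage note: after this patch, `warmup` and `repeat` are time budgets in milliseconds rather than run counts. `do_bench` first times five calls of `fn` to obtain `estimate_ms`, then executes roughly `warmup / estimate_ms` warm-up runs and `rep / estimate_ms` timed runs, and `FlowGraph.latency` returns the median (0.5 quantile) of the timed runs. Below is a minimal sketch of the new call path; the matmul shapes, dtype, and device are illustrative and not taken from the patch:

    import hidet
    from hidet.utils.benchmark import do_bench

    # build and optimize a small graph (illustrative 1024x1024 fp16 matmul)
    a = hidet.symbol([1024, 1024], dtype='float16', device='cuda')
    b = hidet.symbol([1024, 1024], dtype='float16', device='cuda')
    g = hidet.graph.optimize(hidet.trace_from(hidet.ops.matmul(a, b), [a, b]))

    # ~25 ms of warm-up and ~100 ms of timed runs; returns the median latency in ms
    median_ms = g.latency(warmup=25, repeat=100)

    # do_bench itself returns the (0.2, 0.5, 0.8) latency quantiles of the timed runs
    inputs = g.dummy_inputs()
    p20, p50, p80 = do_bench(lambda: g.forward(inputs), warmup=25, rep=100)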