diff --git a/python/hidet/graph/graph_utils/instruments/benchmark_instrument.py b/python/hidet/graph/graph_utils/instruments/benchmark_instrument.py index 0229bc1b3..a8f1756dd 100644 --- a/python/hidet/graph/graph_utils/instruments/benchmark_instrument.py +++ b/python/hidet/graph/graph_utils/instruments/benchmark_instrument.py @@ -13,7 +13,7 @@ import os import numpy as np -from hidet.runtime import CompiledModule +from hidet.runtime import CompiledTask from hidet.graph.flow_graph import FlowGraph, Operator, Tensor, GraphForwardInstrument @@ -57,10 +57,11 @@ def after_operator(self, op: Operator, inputs: List[Tensor], outputs: List[Tenso if not self.benchmarking: return - task_func: CompiledModule = op.compiled_task + task_func: CompiledTask = op.compiled_task latency: List[float] = task_func.profile( *inputs, *outputs, warmup=self.warmup, number=self.number, repeat=self.repeat ) + self.latency_list.append((op, float(np.median(latency)), float(np.std(latency)))) def after_graph(self, graph: FlowGraph, inputs: List[Tensor], outputs: List[Tensor]) -> None: diff --git a/python/hidet/runtime/compiled_task.py b/python/hidet/runtime/compiled_task.py index 48a0a55d7..720ce564a 100644 --- a/python/hidet/runtime/compiled_task.py +++ b/python/hidet/runtime/compiled_task.py @@ -182,6 +182,13 @@ def run_async(self, inputs): return outputs + def profile(self, *args, warmup=1, number=2, repeat=10): + num_outputs = len(self.meta_data.outputs) + inputs = args[:num_outputs] + outputs = args[num_outputs:] + candidate = self.candidates[self.pick_best_candidate(inputs, outputs)] + return candidate.profile(*args, warmup=warmup, number=number, repeat=repeat) + def load_compiled_task(compiled_task_dir: str) -> CompiledTask: return CompiledTask(compiled_task_dir)