diff --git a/python/hidet/drivers/build_task.py b/python/hidet/drivers/build_task.py
index 1f18b5a8f..8aa0cd2a6 100644
--- a/python/hidet/drivers/build_task.py
+++ b/python/hidet/drivers/build_task.py
@@ -188,8 +188,17 @@ def get_signature(t: TensorNode, device: str) -> TensorSignature:
             device=device, dtype=t.type.dtype.name, shape=[int(v) if is_constant(v) else str(v) for v in t.shape]
         )
 
+    # extract the task name
+    from hidet.graph.ops.fusion.fused_operator import FusedTask
+
+    if isinstance(task, FusedTask):
+        task_name = 'fused_{}'.format(task.attrs['fused_ops'].replace(' ', '_'))
+    else:
+        task_name = task.name
+
     # generate meta data
     meta = TaskMetaData(
+        name=task_name,
         symbols=[v.name for v in task.symbols],
         inputs=[get_signature(t, input_device) for t in task.inputs],
         outputs=[get_signature(t, output_device) for t in task.outputs],
diff --git a/python/hidet/ir/builders/stmt_builder.py b/python/hidet/ir/builders/stmt_builder.py
index 2dc1fbb27..9903892d5 100644
--- a/python/hidet/ir/builders/stmt_builder.py
+++ b/python/hidet/ir/builders/stmt_builder.py
@@ -9,7 +9,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Union, Optional, Sequence
+from typing import Union, Sequence
 from hidet.ir.stmt import Stmt, ForStmt, IfStmt, EvaluateStmt, SeqStmt, LetStmt, ForMappingStmt, ForStmtAttr
 from hidet.ir.expr import Expr, Var, var, convert
 
@@ -65,10 +65,10 @@ def lets(self, bind_vars: Sequence[Union[str, Var]], values: Sequence[Union[int,
         seq_let_stmt = LetStmt(bind_vars, bind_values, body=1)
         return StmtScope(self, stmts=seq_let_stmt, ret=bind_vars)
 
-    def for_loop(self, v: Union[str, Var], extent: Union[int, Expr], unroll: Optional[bool] = None) -> StmtScope:
+    def for_loop(self, v: Union[str, Var], extent: Union[int, Expr], attr: str = '.') -> StmtScope:
         if isinstance(v, str):
             v = var(v)
-        return StmtScope(self, stmts=ForStmt(v, extent, attr=ForStmtAttr(unroll)), ret=v)
+        return StmtScope(self, stmts=ForStmt(v, extent, attr=ForStmtAttr.parse(attr, num_loops=1)[0]), ret=v)
 
     def if_then(self, cond: Union[bool, Expr]) -> StmtScope:
         return StmtScope(self, stmts=[IfStmt(cond)], ret=None)
diff --git a/python/hidet/ir/primitives/cuda/mma.py b/python/hidet/ir/primitives/cuda/mma.py
index 2f574639d..f927dbd48 100644
--- a/python/hidet/ir/primitives/cuda/mma.py
+++ b/python/hidet/ir/primitives/cuda/mma.py
@@ -357,8 +357,8 @@ def _print_segment(mapping: TaskMapping, dtype: DataType, addr: Expr, worker_id:
     if msg:
         with sb.if_then(worker_id == 0):
             sb += printf(f'{msg}\\n')
-    with sb.for_loop('i', mapping.task_shape[0], unroll=False) as i:
-        with sb.for_loop('j', mapping.task_shape[1], unroll=False) as j:
+    with sb.for_loop('i', mapping.task_shape[0]) as i:
+        with sb.for_loop('j', mapping.task_shape[1]) as j:
             p = var('p', int32)
             sb += DeclareStmt(p, int32(0))
             with sb.for_mapping(['ii', 'jj'], mapping, worker_id) as (ii, jj):
diff --git a/python/hidet/ir/schedulers/cpu/scheduler.py b/python/hidet/ir/schedulers/cpu/scheduler.py
index 027bd2561..9089c288c 100644
--- a/python/hidet/ir/schedulers/cpu/scheduler.py
+++ b/python/hidet/ir/schedulers/cpu/scheduler.py
@@ -13,17 +13,16 @@
 
 from hidet.ir.builders import FunctionBuilder
 from hidet.ir.compute import TensorNode, GridCompute
-from hidet.ir.expr import Var, convert, call
+from hidet.ir.expr import Var, call
 from hidet.ir.tools import rewrite
 from hidet.ir.stmt import Stmt, BufferStoreStmt, EvaluateStmt
 from hidet.ir.schedulers.base import AutoScheduler, ComputeExprLower
+from hidet.ir.mapping import row_spatial
+from hidet.utils.py import prod
 
 
 class CpuAutoScheduler(AutoScheduler):
     def schedule_grid_compute(self, node: GridCompute, tensor_map: Dict[TensorNode, Var]) -> Stmt:
-        # pylint: disable=too-many-locals, import-outside-toplevel, unnecessary-comprehension
-        from hidet.ir.mapping import row_repeat, TaskMapping
-
         params, param_map, call_args = self.grid_compute_params_and_args(node, tensor_map)
 
         if self.task is not None:
@@ -35,16 +34,16 @@ def schedule_grid_compute(self, node: GridCompute, tensor_map: Dict[TensorNode,
             # set function parameters
             fb.extend_params(params)
 
-            mapping: TaskMapping = row_repeat(*node.shape)
             iter_names = [f'i{i}' for i in range(len(node.shape))]
-            with fb.for_mapping(iter_names, mapping, convert(0)) as task_index:
-                out_param: Var = param_map[node]
-                compute_lower = ComputeExprLower(node.value, param_map=param_map)
-                stmts, value = compute_lower.lower()
-                rmap = {axis: axis_value for axis, axis_value in zip(node.axes, task_index)}
-                stmts, value = rewrite([stmts, value], rmap)
-                fb += stmts
-                fb += BufferStoreStmt(out_param, task_index, value)
+            with fb.for_loop('w', extent=prod(node.shape), attr='p') as w:
+                with fb.for_mapping(iter_names, row_spatial(*node.shape), worker=w) as task_index:
+                    out_param: Var = param_map[node]
+                    compute_lower = ComputeExprLower(node.value, param_map=param_map)
+                    stmts, value = compute_lower.lower()
+                    rmap = {axis: axis_value for axis, axis_value in zip(node.axes, task_index)}
+                    stmts, value = rewrite([stmts, value], rmap)
+                    fb += stmts
+                    fb += BufferStoreStmt(out_param, task_index, value)
 
         func = fb.get()
         func_var = self.add_function(func)
diff --git a/python/hidet/runtime/compiled_graph.py b/python/hidet/runtime/compiled_graph.py
index 0b5eeee53..dd86f0e03 100644
--- a/python/hidet/runtime/compiled_graph.py
+++ b/python/hidet/runtime/compiled_graph.py
@@ -28,7 +28,8 @@
 from hidet.runtime.compiled_task import CompiledTask, TensorSignature, _check_inputs
 from hidet.runtime.storage import Storage
 from hidet.ffi import runtime_api
-from hidet.utils import prod
+from hidet.utils.py import prod, median
+from hidet.utils.trace_utils import TraceEventEmitter
 
 ModelExecutionHook = Callable[[int, List['Tensor'], List['Tensor']], None]
 
@@ -97,6 +98,7 @@ def __init__(
         self.constant_outputs: List[Union[None, Tensor]] = []
 
         # runtime state
+        self.working_dir: str = hidet.utils.cache_file('graphs', self.meta.graph_hash)
         self.dispatch_table_path = hidet.utils.cache_file('graphs', self.meta.graph_hash, 'dispatch_table.txt')
         self.dispatch_table: Dict[Tuple[int, ...], Array] = {}
         self.cuda_workspace: Optional[Storage] = None
@@ -258,23 +260,49 @@ def _run_slow_path(self, inputs, symbol_dims: Tuple[int, ...]):
             index2tensor[exe.inputs_index[i]] = inputs[i]
         for i in range(len(self.weights)):
             index2tensor[exe.weights_index[i]] = self.weights[i]
+
         best_candidates = [-1 for _ in range(len(self.compiled_tasks))]
+        trace_emitter = TraceEventEmitter({'graph': self.graph_string})
 
        for inst in exe.instructions:
+            # prepare inputs and kernel
             node_inputs = [index2tensor[i] for i in inst.inputs]
             node_kernel: CompiledTask = self.compiled_tasks[inst.task_idx]
+
+            # run the kernel
             node_outputs = node_kernel.run_async(node_inputs)
+
+            # record outputs
             for i, output_index in enumerate(inst.outputs):
                 index2tensor[output_index] = node_outputs[i]
 
+            # record best candidate for this kernel
             best_candidates[inst.task_idx] = node_kernel.pick_best_candidate(node_inputs, node_outputs)
+            # record trace events
+            trace_emitter.append(
+                name=node_kernel.meta_data.name,
+                duration_us=int(median(node_kernel.profile(*node_inputs, *node_outputs)) * 1000),
+                args={
+                    'name': node_kernel.meta_data.name,
+                    'inputs': ['{}{}'.format(x.dtype, x.shape) for x in node_kernel.meta_data.inputs],
+                    'outputs': ['{}{}'.format(x.dtype, x.shape) for x in node_kernel.meta_data.outputs],
+                },
+            )
+
+            # free tensors that are no longer needed
             for idx in inst.free:
                 del index2tensor[idx]
 
         outputs = [index2tensor[i] for i in exe.outputs_index]
 
+        # update the dispatch table
         self._update_symbol_table(symbol_dims, best_candidates)
 
+        # save the trace
+        trace_filename = 'trace{}.json'.format('_'.join(str(x) for x in symbol_dims))
+        with open(os.path.join(self.working_dir, trace_filename), 'w') as f:
+            trace_emitter.save(f)
+
         return outputs
 
     def run_async(self, inputs):
diff --git a/python/hidet/runtime/compiled_task.py b/python/hidet/runtime/compiled_task.py
index 720ce564a..021c8196e 100644
--- a/python/hidet/runtime/compiled_task.py
+++ b/python/hidet/runtime/compiled_task.py
@@ -31,6 +31,7 @@ class TensorSignature:
 
 @dataclass
 class TaskMetaData:
+    name: str
     symbols: List[str]
     inputs: List[TensorSignature]
     outputs: List[TensorSignature]
diff --git a/python/hidet/utils/py.py b/python/hidet/utils/py.py
index 780a2539b..f0113e586 100644
--- a/python/hidet/utils/py.py
+++ b/python/hidet/utils/py.py
@@ -37,6 +37,14 @@ def prod(seq: Iterable):
     return c
 
 
+def median(seq: Iterable):
+    seq = list(seq)
+    if len(seq) == 0:
+        return None
+    else:
+        return sorted(seq)[len(seq) // 2]
+
+
 def clip(
     x: Union[int, float], low: Optional[Union[int, float]], high: Optional[Union[int, float]]
 ) -> Union[int, float]:
diff --git a/python/hidet/utils/trace_utils.py b/python/hidet/utils/trace_utils.py
new file mode 100644
index 000000000..eab5f08d5
--- /dev/null
+++ b/python/hidet/utils/trace_utils.py
@@ -0,0 +1,55 @@
+from typing import Any, Dict, List
+from dataclasses import dataclass, asdict
+import json
+
+
+@dataclass
+class Event:
+    name: str
+    cat: str
+    ph: str
+    ts: int
+    pid: int
+    tid: int
+    args: Dict[str, Any]
+
+
+@dataclass
+class TraceEvents:
+    traceEvents: List[Event]
+    displayTimeUnit: str = 'ms'
+    otherData: Dict[str, Any] = None
+
+
+class TraceEventEmitter:
+    def __init__(self, other_data: Dict[str, Any] = None):
+        self.events: List[Event] = []
+        self.otherData: Dict[str, Any] = other_data if other_data is not None else {}
+
+        self.current_ts = 0
+
+    def append(self, name: str, duration_us: int, args: Dict[str, Any] = None):
+        self.events.append(
+            Event(
+                name=name, cat='kernel', ph='B', ts=self.current_ts, pid=0, tid=0, args=args if args is not None else {}
+            )
+        )
+        self.current_ts += duration_us
+        self.events.append(
+            Event(
+                name=name, cat='kernel', ph='E', ts=self.current_ts, pid=0, tid=0, args=args if args is not None else {}
+            )
+        )
+
+    def export(self):
+        return asdict(TraceEvents(traceEvents=self.events, otherData=self.otherData))
+
+    def save(self, f):
+        json.dump(self.export(), f)
+
+
+if __name__ == '__main__':
+    emitter = TraceEventEmitter()
+    emitter.append('test', 1000)
+    with open('test.json', 'w') as ff:
+        emitter.save(ff)
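
For reference: TraceEventEmitter writes JSON in Chrome's Trace Event Format (paired 'B'/'E' events on pid 0 / tid 0, timestamps in microseconds), so the trace*.json files that _run_slow_path drops next to dispatch_table.txt should open directly in chrome://tracing or https://ui.perfetto.dev. Below is a minimal usage sketch, assuming this patch is applied; the kernel name, shapes, and output path are made up for illustration.

import json

from hidet.utils.trace_utils import TraceEventEmitter

# One 'B'/'E' event pair is emitted per append() call; current_ts advances by
# duration_us each time, so consecutive kernels are laid out back-to-back.
emitter = TraceEventEmitter(other_data={'graph': 'demo graph'})
emitter.append(
    name='fused_matmul_relu',  # hypothetical kernel name
    duration_us=1500,
    args={'inputs': ['float32[1, 64]'], 'outputs': ['float32[1, 64]']},
)

with open('demo_trace.json', 'w') as f:
    emitter.save(f)

with open('demo_trace.json') as f:
    trace = json.load(f)
print([e['ph'] for e in trace['traceEvents']])  # ['B', 'E']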