Skip to content

Commit

Permalink
[CPU][Scheduler] Use multi-threads for auto-scheduler (hidet-org#341)
Browse files Browse the repository at this point in the history
Also add the operator trace.
  • Loading branch information
yaoyaoding authored Aug 3, 2023
1 parent 56197c9 commit e9bfa99
Show file tree
Hide file tree
Showing 8 changed files with 119 additions and 19 deletions.
9 changes: 9 additions & 0 deletions python/hidet/drivers/build_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,17 @@ def get_signature(t: TensorNode, device: str) -> TensorSignature:
device=device, dtype=t.type.dtype.name, shape=[int(v) if is_constant(v) else str(v) for v in t.shape]
)

# extract the task name
from hidet.graph.ops.fusion.fused_operator import FusedTask

if isinstance(task, FusedTask):
task_name = 'fused_{}'.format(task.attrs['fused_ops'].replace(' ', '_'))
else:
task_name = task.name

# generate meta data
meta = TaskMetaData(
name=task_name,
symbols=[v.name for v in task.symbols],
inputs=[get_signature(t, input_device) for t in task.inputs],
outputs=[get_signature(t, output_device) for t in task.outputs],
Expand Down
6 changes: 3 additions & 3 deletions python/hidet/ir/builders/stmt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union, Optional, Sequence
from typing import Union, Sequence

from hidet.ir.stmt import Stmt, ForStmt, IfStmt, EvaluateStmt, SeqStmt, LetStmt, ForMappingStmt, ForStmtAttr
from hidet.ir.expr import Expr, Var, var, convert
Expand Down Expand Up @@ -65,10 +65,10 @@ def lets(self, bind_vars: Sequence[Union[str, Var]], values: Sequence[Union[int,
seq_let_stmt = LetStmt(bind_vars, bind_values, body=1)
return StmtScope(self, stmts=seq_let_stmt, ret=bind_vars)

def for_loop(self, v: Union[str, Var], extent: Union[int, Expr], unroll: Optional[bool] = None) -> StmtScope:
def for_loop(self, v: Union[str, Var], extent: Union[int, Expr], attr: str = '.') -> StmtScope:
if isinstance(v, str):
v = var(v)
return StmtScope(self, stmts=ForStmt(v, extent, attr=ForStmtAttr(unroll)), ret=v)
return StmtScope(self, stmts=ForStmt(v, extent, attr=ForStmtAttr.parse(attr, num_loops=1)[0]), ret=v)

def if_then(self, cond: Union[bool, Expr]) -> StmtScope:
    """Start an ``if`` scope guarded by ``cond``.

    Returns a ``StmtScope`` bound to this builder whose statement list holds
    a single ``IfStmt(cond)`` and whose ``ret`` is ``None``.
    """
    if_stmt = IfStmt(cond)
    return StmtScope(self, stmts=[if_stmt], ret=None)
Expand Down
4 changes: 2 additions & 2 deletions python/hidet/ir/primitives/cuda/mma.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,8 +357,8 @@ def _print_segment(mapping: TaskMapping, dtype: DataType, addr: Expr, worker_id:
if msg:
with sb.if_then(worker_id == 0):
sb += printf(f'{msg}\\n')
with sb.for_loop('i', mapping.task_shape[0], unroll=False) as i:
with sb.for_loop('j', mapping.task_shape[1], unroll=False) as j:
with sb.for_loop('i', mapping.task_shape[0]) as i:
with sb.for_loop('j', mapping.task_shape[1]) as j:
p = var('p', int32)
sb += DeclareStmt(p, int32(0))
with sb.for_mapping(['ii', 'jj'], mapping, worker_id) as (ii, jj):
Expand Down
25 changes: 12 additions & 13 deletions python/hidet/ir/schedulers/cpu/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,16 @@

from hidet.ir.builders import FunctionBuilder
from hidet.ir.compute import TensorNode, GridCompute
from hidet.ir.expr import Var, convert, call
from hidet.ir.expr import Var, call
from hidet.ir.tools import rewrite
from hidet.ir.stmt import Stmt, BufferStoreStmt, EvaluateStmt
from hidet.ir.schedulers.base import AutoScheduler, ComputeExprLower
from hidet.ir.mapping import row_spatial
from hidet.utils.py import prod


class CpuAutoScheduler(AutoScheduler):
def schedule_grid_compute(self, node: GridCompute, tensor_map: Dict[TensorNode, Var]) -> Stmt:
# pylint: disable=too-many-locals, import-outside-toplevel, unnecessary-comprehension
from hidet.ir.mapping import row_repeat, TaskMapping

params, param_map, call_args = self.grid_compute_params_and_args(node, tensor_map)

if self.task is not None:
Expand All @@ -35,16 +34,16 @@ def schedule_grid_compute(self, node: GridCompute, tensor_map: Dict[TensorNode,
# set function parameters
fb.extend_params(params)

mapping: TaskMapping = row_repeat(*node.shape)
iter_names = [f'i{i}' for i in range(len(node.shape))]
with fb.for_mapping(iter_names, mapping, convert(0)) as task_index:
out_param: Var = param_map[node]
compute_lower = ComputeExprLower(node.value, param_map=param_map)
stmts, value = compute_lower.lower()
rmap = {axis: axis_value for axis, axis_value in zip(node.axes, task_index)}
stmts, value = rewrite([stmts, value], rmap)
fb += stmts
fb += BufferStoreStmt(out_param, task_index, value)
with fb.for_loop('w', extent=prod(node.shape), attr='p') as w:
with fb.for_mapping(iter_names, row_spatial(*node.shape), worker=w) as task_index:
out_param: Var = param_map[node]
compute_lower = ComputeExprLower(node.value, param_map=param_map)
stmts, value = compute_lower.lower()
rmap = {axis: axis_value for axis, axis_value in zip(node.axes, task_index)}
stmts, value = rewrite([stmts, value], rmap)
fb += stmts
fb += BufferStoreStmt(out_param, task_index, value)
func = fb.get()
func_var = self.add_function(func)

Expand Down
30 changes: 29 additions & 1 deletion python/hidet/runtime/compiled_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@
from hidet.runtime.compiled_task import CompiledTask, TensorSignature, _check_inputs
from hidet.runtime.storage import Storage
from hidet.ffi import runtime_api
from hidet.utils import prod
from hidet.utils.py import prod, median
from hidet.utils.trace_utils import TraceEventEmitter

ModelExecutionHook = Callable[[int, List['Tensor'], List['Tensor']], None]

Expand Down Expand Up @@ -97,6 +98,7 @@ def __init__(
self.constant_outputs: List[Union[None, Tensor]] = []

# runtime state
self.working_dir: str = hidet.utils.cache_file('graphs', self.meta.graph_hash)
self.dispatch_table_path = hidet.utils.cache_file('graphs', self.meta.graph_hash, 'dispatch_table.txt')
self.dispatch_table: Dict[Tuple[int, ...], Array] = {}
self.cuda_workspace: Optional[Storage] = None
Expand Down Expand Up @@ -258,23 +260,49 @@ def _run_slow_path(self, inputs, symbol_dims: Tuple[int, ...]):
index2tensor[exe.inputs_index[i]] = inputs[i]
for i in range(len(self.weights)):
index2tensor[exe.weights_index[i]] = self.weights[i]

best_candidates = [-1 for _ in range(len(self.compiled_tasks))]
trace_emitter = TraceEventEmitter({'graph': self.graph_string})
for inst in exe.instructions:
# prepare inputs and kernel
node_inputs = [index2tensor[i] for i in inst.inputs]
node_kernel: CompiledTask = self.compiled_tasks[inst.task_idx]

# run the kernel
node_outputs = node_kernel.run_async(node_inputs)

# record outputs
for i, output_index in enumerate(inst.outputs):
index2tensor[output_index] = node_outputs[i]

# record best candidate for this kernel
best_candidates[inst.task_idx] = node_kernel.pick_best_candidate(node_inputs, node_outputs)

# record trace events
trace_emitter.append(
name=node_kernel.meta_data.name,
duration_us=int(median(node_kernel.profile(*node_inputs, *node_outputs)) * 1000),
args={
'name': node_kernel.meta_data.name,
'inputs': ['{}{}'.format(x.dtype, x.shape) for x in node_kernel.meta_data.inputs],
'outputs': ['{}{}'.format(x.dtype, x.shape) for x in node_kernel.meta_data.outputs],
},
)

# free tensors that are no longer needed
for idx in inst.free:
del index2tensor[idx]

outputs = [index2tensor[i] for i in exe.outputs_index]

# update the dispatch table
self._update_symbol_table(symbol_dims, best_candidates)

# save the trace
trace_filename = 'trace{}.json'.format('_'.join(str(x) for x in symbol_dims))
with open(os.path.join(self.working_dir, trace_filename), 'w') as f:
trace_emitter.save(f)

return outputs

def run_async(self, inputs):
Expand Down
1 change: 1 addition & 0 deletions python/hidet/runtime/compiled_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class TensorSignature:

@dataclass
class TaskMetaData:
name: str
symbols: List[str]
inputs: List[TensorSignature]
outputs: List[TensorSignature]
Expand Down
8 changes: 8 additions & 0 deletions python/hidet/utils/py.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@ def prod(seq: Iterable):
return c


def median(seq: Iterable):
    """Return the middle element of ``seq`` after sorting.

    No averaging is performed: for an even number of elements the upper of
    the two middle elements is returned, so the result is always one of the
    given values.

    Parameters
    ----------
    seq: Iterable
        The values to examine. Any iterable is accepted (it is consumed
        once); the values must be mutually comparable.

    Returns
    -------
    ret:
        The element at index ``len // 2`` of the sorted values, or ``None``
        when ``seq`` is empty.
    """
    # sorted() already materializes any iterable into a new list, so no
    # separate list() copy is needed beforehand.
    ordered = sorted(seq)
    if not ordered:
        return None
    return ordered[len(ordered) // 2]


def clip(
x: Union[int, float], low: Optional[Union[int, float]], high: Optional[Union[int, float]]
) -> Union[int, float]:
Expand Down
55 changes: 55 additions & 0 deletions python/hidet/utils/trace_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import json
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Optional


@dataclass
class Event:
    """A single trace record, stored as one entry of ``TraceEvents.traceEvents``.

    NOTE(review): the field names look like those of the Chrome Trace Event
    Format -- confirm against that specification before relying on it.
    """

    name: str  # label of the event
    cat: str  # category; TraceEventEmitter always passes 'kernel'
    ph: str  # phase marker; TraceEventEmitter emits 'B' followed by 'E'
    ts: int  # timestamp; TraceEventEmitter advances it by `duration_us` per event
    pid: int  # TraceEventEmitter always passes 0
    tid: int  # TraceEventEmitter always passes 0
    args: Dict[str, Any]  # free-form payload attached to the event


@dataclass
class TraceEvents:
    """Top-level container of a trace.

    ``TraceEventEmitter.export`` converts an instance with
    ``dataclasses.asdict``, so the attribute names below become the keys of
    the exported dict verbatim (hence the camelCase spelling).
    """

    traceEvents: List[Event]  # the recorded events, in the order they were appended
    displayTimeUnit: str = 'ms'
    # `None` is the default, so the annotation has to say Optional explicitly.
    otherData: Optional[Dict[str, Any]] = None


class TraceEventEmitter:
    """Accumulate begin/end event pairs on one timeline and export them as JSON.

    Events are laid out back to back: the first event begins at timestamp 0
    and every following event begins where the previous one ended.
    """

    def __init__(self, other_data: Optional[Dict[str, Any]] = None):
        """
        Parameters
        ----------
        other_data: Optional[Dict[str, Any]]
            Extra data exported under the ``otherData`` key. When omitted, an
            empty dict is used.
        """
        self.events: List[Event] = []
        self.otherData: Dict[str, Any] = other_data if other_data is not None else {}

        # running timestamp; the begin time of the next appended event
        self.current_ts = 0

    def append(self, name: str, duration_us: int, args: Optional[Dict[str, Any]] = None):
        """Record one event as a 'B' (begin) / 'E' (end) pair.

        Parameters
        ----------
        name: str
            The name stored on both events of the pair.
        duration_us: int
            The amount the running timestamp is advanced between the begin
            and the end event.
        args: Optional[Dict[str, Any]]
            Payload stored on both events of the pair. When omitted, an empty
            dict is used.
        """
        if args is None:
            args = {}
        begin_ts = self.current_ts
        self.current_ts += duration_us
        # Both events of the pair differ only in phase and timestamp.
        for phase, timestamp in (('B', begin_ts), ('E', self.current_ts)):
            self.events.append(Event(name=name, cat='kernel', ph=phase, ts=timestamp, pid=0, tid=0, args=args))

    def export(self):
        """Return the trace as a plain dict (dataclasses converted by ``asdict``)."""
        return asdict(TraceEvents(traceEvents=self.events, otherData=self.otherData))

    def save(self, f):
        """Write the exported trace as JSON to the writable text file object ``f``."""
        json.dump(self.export(), f)


if __name__ == '__main__':
    # Manual smoke test: record one event and dump the trace to test.json
    # in the current working directory.
    demo_emitter = TraceEventEmitter()
    demo_emitter.append('test', 1000)
    with open('test.json', 'w') as trace_file:
        demo_emitter.save(trace_file)

0 comments on commit e9bfa99

Please sign in to comment.