Skip to content

Commit

Permalink
[TOOLS] Attached hash values to function signature in source.cu (#459)
Browse files Browse the repository at this point in the history
Suffix the kernel function names in each IRModule with the task's hash code so that the corresponding function can be identified during the profiling stage
  • Loading branch information
ZichuWu authored Sep 23, 2024
1 parent 08de3dd commit 9f3e2b9
Show file tree
Hide file tree
Showing 9 changed files with 104 additions and 6 deletions.
4 changes: 2 additions & 2 deletions python/hidet/drivers/build_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import os
import json
import shutil
from hashlib import sha256
from typing import List, Optional, Tuple
from tqdm import tqdm

Expand Down Expand Up @@ -145,6 +144,7 @@ def launch(arg: meta.types(param_types)):
ir_module.add_function(get_output_shape.name, get_output_shape)
ir_module.object_files.extend([os.path.join(object_path, 'lib.o') for object_path in objects_path_list])
task_ir_module = ir_module
task_ir_module.task = task

# add assertions to the launch function
if len(task.assertions) > 0:
Expand Down Expand Up @@ -243,7 +243,7 @@ def build_task(task: Task, target='cuda', load=True) -> Optional[CompiledTask]:
else:
# check on-disk cache
config_str = f'{target}_space_{space_level}'
task_hash = sha256(task_string.encode()).hexdigest()[:16]
task_hash = task.calculate_hash()
task_dir = os.path.join(op_cache_dir, config_str, task.name, task_hash)
lib_path = os.path.join(task_dir, 'lib.so')
version_path = os.path.join(task_dir, 'version.txt')
Expand Down
15 changes: 12 additions & 3 deletions python/hidet/graph/ops/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ def launch(x: dtype[shape], y: dtype[shape]):
attrs.func_kind = 'public'
_all_reduce(x, y, size, dtype, str_to_nccl_op(self.op), self.comm_id)

return [script_module.ir_module()]
ir_module: IRModule = script_module.ir_module()
ir_module.task = self

return [ir_module]


class AllReduceOp(Operator):
Expand Down Expand Up @@ -87,7 +90,10 @@ def launch(x: dtype[shape], y: dtype[out_shape]):
attrs.func_kind = 'public'
_all_gather(x, y, size, dtype, self.comm_id)

return [script_module.ir_module()]
ir_module: IRModule = script_module.ir_module()
ir_module.task = self

return [ir_module]


class AllGatherOp(Operator):
Expand Down Expand Up @@ -132,7 +138,10 @@ def launch(x: dtype[shape], y: dtype[shape[1:]]):
attrs.func_kind = 'public'
_reduce_scatter(x, y, size, dtype, str_to_nccl_op(self.op), self.comm_id)

return [script_module.ir_module()]
ir_module: IRModule = script_module.ir_module()
ir_module.task = self

return [ir_module]


class ReduceScatterOp(Operator):
Expand Down
3 changes: 3 additions & 0 deletions python/hidet/graph/ops/fusion/fused_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,9 @@ def implement(self, target: Union[Target, str], working_dir: str) -> List[IRModu
if hasattr(anchor_module, '_tuning_kwargs'):
setattr(fused_module, '_tuning_kwargs', getattr(anchor_module, '_tuning_kwargs'))

for fused_module in fused_modules:
fused_module.task = self

return fused_modules


Expand Down
5 changes: 4 additions & 1 deletion python/hidet/graph/ops/transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@ def launch(x: dtype[shape], y: dtype[shape]):

memcpy_async(y, x, count=nbytes, kind=kind)

return [script_module.ir_module()]
ir_module: IRModule = script_module.ir_module()
ir_module.task = self

return [ir_module]


class TransferOp(Operator):
Expand Down
3 changes: 3 additions & 0 deletions python/hidet/ir/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(
linking_dirs: List[str] = None,
linking_libs: List[str] = None,
object_files: List[str] = None,
task: 'Task' = None,
):
# the functions defined in this module
self.functions: Dict[str, Function] = functions if functions else {}
Expand All @@ -51,6 +52,7 @@ def __init__(
self.linking_dirs: List[str] = linking_dirs if linking_dirs else [] # -I flags
self.linking_libs: List[str] = linking_libs if linking_libs else [] # -l flags
self.object_files: List[str] = object_files if object_files else [] # .o files
self.task = task

assert all(isinstance(func, Function) for func in self.functions.values()) and all(
isinstance(var, Var) for var in self.global_vars.values()
Expand Down Expand Up @@ -84,6 +86,7 @@ def copy(self):
linking_dirs=self.linking_dirs.copy(),
linking_libs=self.linking_libs.copy(),
object_files=self.object_files.copy(),
task=self.task,
)

def reset_funcs(self, functions: Dict[str, Function] = None, global_vars: Dict[str, Var] = None):
Expand Down
7 changes: 7 additions & 0 deletions python/hidet/ir/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import os
import enum
import pickle
from hashlib import sha256
from hidet.ir.node import Node
from hidet.ir.type import FuncType, VoidType
from hidet.ir.expr import Expr, Var, SymbolVar, var, is_constant
Expand Down Expand Up @@ -298,6 +299,9 @@ def implement(self, target: Union[Target, str], working_dir: str) -> List[IRModu
'Expect the `implement` method to return an IRModule or List[IRModule], got {}'.format(ir_modules)
)

for ir_module in ir_modules:
ir_module.task = self

return ir_modules

def implement_cuda(self, working_dir: str) -> Union[IRModule, List[IRModule]]:
Expand Down Expand Up @@ -334,6 +338,9 @@ def load(fname: str) -> Task:
with open(fname, 'rb') as f:
return pickle.load(f)

def calculate_hash(self, len: int = 16) -> str:
    """Return a truncated SHA-256 hex digest of this object's string form.

    Parameters
    ----------
    len: int
        Number of leading hexadecimal characters of the digest to keep.
        Note that this parameter name shadows the builtin ``len`` inside
        the method body.

    Returns
    -------
    ret: str
        The first ``len`` characters of ``sha256(str(self))`` in hex.
    """
    text = str(self)
    hex_digest = sha256(text.encode()).hexdigest()
    return hex_digest[:len]


def save_task(task: Task, fname: str):
task.save(fname)
Expand Down
2 changes: 2 additions & 0 deletions python/hidet/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from .base import Pass, FunctionPass, SequencePass, RepeatFunctionPass, PassContext
from .instruments import PassInstrument, SaveIRInstrument, ProfileInstrument

from .attach_hash_to_signature import attach_hash_to_signature
from .unify_global_objects import unify_global_objects_pass
from .flatten_tensor_slice import flatten_tensor_slice_pass
from .flatten_tensor_index import flatten_tensor_index_pass
Expand Down Expand Up @@ -71,6 +72,7 @@ def lower(ir_module: IRModule) -> IRModule:

transforms = [
# necessary passes
attach_hash_to_signature(),
unify_global_objects_pass(),
generate_launch_func_pass(),
flatten_tensor_slice_pass(),
Expand Down
69 changes: 69 additions & 0 deletions python/hidet/transforms/attach_hash_to_signature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from hidet.ir.module import IRModule
from hidet.ir.functors import IRRewriter
from hidet.ir.stmt import LaunchKernelStmt
from hidet.transforms import Pass


class AttachHashToSignatureRewriter(IRRewriter):
    """Rewriter that retargets kernel launches from ``old_name`` to ``new_name``.

    Only ``LaunchKernelStmt`` nodes are handled specially; everything else
    is left to the base ``IRRewriter``.
    """

    def __init__(self, old_name: str, new_name: str, use_memo=True):
        # Names are stored before delegating to the base constructor,
        # matching the original initialization order.
        self.old_name = old_name
        self.new_name = new_name
        super().__init__(use_memo)

    def visit_LaunchKernelStmt(self, stmt: LaunchKernelStmt):
        # The launched function variable is renamed in place when it refers
        # to the old kernel name; the base visitor then processes the stmt.
        launched = stmt.func_var
        if launched.name == self.old_name:
            launched.name = self.new_name
        return super().visit_LaunchKernelStmt(stmt)


class AttachHashToSignature(Pass):
    """Pass that suffixes kernel function names with a short hash of the module's task.

    Every function whose kind is ``'cuda_kernel'`` or ``'cpu_kernel'`` is
    renamed to ``<name>_<hash>``, where ``<hash>`` is
    ``ir_module.task.calculate_hash(4)``. Modules without a task, or without
    kernel functions, are returned unchanged.
    """

    def process_module(self, ir_module: IRModule) -> IRModule:
        """Return a module whose kernel functions carry the task hash suffix.

        Parameters
        ----------
        ir_module: IRModule
            The module to process. Its ``task`` attribute supplies the hash.

        Returns
        -------
        ret: IRModule
            ``ir_module`` itself when nothing needs renaming, otherwise a copy
            with renamed kernel functions (and matching global vars / launch
            statements).
        """
        if ir_module.task is None:
            logging.warning("A IRModule without task is detected. Designated function hash cannot be append")
            return ir_module

        kernel_names = [
            func.name for func in ir_module.functions.values() if func.kind in ('cuda_kernel', 'cpu_kernel')
        ]
        if not kernel_names:
            return ir_module

        # The hash depends only on the task, so compute it once instead of
        # once per kernel function.
        task_hash = ir_module.task.calculate_hash(4)
        modify_name = {old_name: f'{old_name}_{task_hash}' for old_name in kernel_names}

        new_ir_module = ir_module.copy()

        for old_name, new_name in modify_name.items():
            func = new_ir_module.functions.pop(old_name)
            func.name = new_name
            new_ir_module.functions[new_name] = func

            # Check the module that is actually being mutated, so the
            # membership test and the pop below always agree.
            if old_name in new_ir_module.global_vars:
                global_var = new_ir_module.global_vars.pop(old_name)
                global_var.name = new_name
                new_ir_module.global_vars[new_name] = global_var

        # Launch statements only need retargeting when a public launch
        # function exists in the module.
        has_public_launch = any(
            func.name.startswith('launch') for func in new_ir_module.functions.values() if func.kind == 'public'
        )
        if has_public_launch:
            for old_name, new_name in modify_name.items():
                rewriter = AttachHashToSignatureRewriter(old_name, new_name)
                new_ir_module = rewriter.visit(new_ir_module)

        return new_ir_module


def attach_hash_to_signature() -> Pass:
    """Create a new ``AttachHashToSignature`` pass instance."""
    hash_pass = AttachHashToSignature()
    return hash_pass
2 changes: 2 additions & 0 deletions tests/ir/parser/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import pytest
from hidet.ir.tools.ir_dumper import astext2, parse
from hidet.ir.expr import symbol_var
from hidet.transforms.attach_hash_to_signature import attach_hash_to_signature
from hidet.transforms.unify_global_objects import unify_global_objects_pass
from hidet.transforms.flatten_tensor_slice import flatten_tensor_slice_pass
from hidet.transforms.flatten_tensor_index import flatten_tensor_index_pass
Expand Down Expand Up @@ -87,6 +88,7 @@ def get_attn_task():
def generate_ir_modules():
transforms = [
lambda x: x,
attach_hash_to_signature(),
unify_global_objects_pass(),
generate_launch_func_pass(),
flatten_tensor_slice_pass(),
Expand Down

0 comments on commit 9f3e2b9

Please sign in to comment.