From b2fc5afedb065b49c775ff834a70502335832a61 Mon Sep 17 00:00:00 2001
From: Lei Wang <34334180+LeiWang1999@users.noreply.github.com>
Date: Wed, 24 Apr 2024 01:51:21 +0800
Subject: [PATCH] [CUDA GRAPH] Support Cuda Stream in the Wrap Function (#21)

* issue fix from loading database.

* support cuda stream

* Refactor CUDA kernel launch string formatting in CUDASourceWrapper class
---
 python/bitblas/cache/operator.py  | 12 +++++++-----
 python/bitblas/ops/operator.py    |  6 +++++-
 python/bitblas/wrapper/general.py | 16 ++++++++++------
 3 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/python/bitblas/cache/operator.py b/python/bitblas/cache/operator.py
index d3638fb98bd9..75c67662d294 100644
--- a/python/bitblas/cache/operator.py
+++ b/python/bitblas/cache/operator.py
@@ -130,7 +130,7 @@ def _load_operators_from_arch_path(self, arch_path, target):
             self._load_operator(config_path, target)
 
     def _load_operator(self, config_path, target):
-        mapping, config, rt_mod, lib_name = None, None, None, None
+        mapping, config, rt_mod, src_name, lib_name = None, None, None, None, None
         for file in os.listdir(config_path):
             full_path = os.path.join(config_path, file)
             if file == "mapping.json":
@@ -143,16 +143,18 @@ def _load_operator(self, config_path, target):
                 rt_mod = tvm.runtime.load_module(full_path)
             elif file == "wrapper_compiled.so":
                 lib_name = full_path
+            elif file == "wrapper_source.cu":
+                src_name = full_path
 
         if mapping and config and rt_mod:
-            self._instantiate_and_add_operator(mapping, config, rt_mod, lib_name, target)
+            self._instantiate_and_add_operator(mapping, config, rt_mod, src_name, lib_name, target)
 
-    def _instantiate_and_add_operator(self, mapping, config, rt_mod, lib_name, target):
+    def _instantiate_and_add_operator(self, mapping, config, rt_mod, src_name, lib_name, target):
         config_cls = getattr(bitblas, mapping["config_type"])
         operator_cls = getattr(bitblas, mapping["operator_type"])
         op_inst = operator_cls(
-            config=config_cls(**config), target=target, enable_tuning=False, from_database=False)
-        op_inst.update_runtime_module(rt_mod, lib_name=lib_name)
+            config=config_cls(**config), target=target, enable_tuning=False, from_database=True)
+        op_inst.update_runtime_module(rt_mod, src_name=src_name, lib_name=lib_name)
         self.add(config_cls(**config), op_inst)
 
diff --git a/python/bitblas/ops/operator.py b/python/bitblas/ops/operator.py
index 9585fe4998e6..0290e2e28480 100644
--- a/python/bitblas/ops/operator.py
+++ b/python/bitblas/ops/operator.py
@@ -58,6 +58,7 @@ def __init__(self, name, config: OperatorConfig, target: Target = None):
             1  # todo(lei): should be analyzed from the prim_func.
         )
         self.wrapper = None
+        self.src_name = None
         self.lib_name = None
         self.lib = None
 
@@ -131,6 +132,7 @@ def tvm_callback_cuda_postproc(code, _):
                 self.arch)
             wrapper.compile_lib()
             self.wrapper = wrapper
+            self.src_name = self.wrapper.src_name
             self.lib_name = self.wrapper.lib_name
             self.lib = self.wrapper.load_lib()
             self.lib.init()
@@ -291,11 +293,13 @@ def __call__(self, *args: Any) -> Any:
     def update_func(self, func: PrimFunc):
         self.prim_func_mod["main"] = func
 
-    def update_runtime_module(self, rt_mod, lib_name=None):
+    def update_runtime_module(self, rt_mod, src_name=None, lib_name=None):
         self.rt_mod = rt_mod
         self.time_evaluator = rt_mod.time_evaluator(rt_mod.entry_name, self.arch.device, number=10)
         self.function_handle = rt_mod.get_function(rt_mod.entry_name).handle
         self.torch_func = to_pytorch_func(rt_mod)
+        if src_name is not None:
+            self.src_name = src_name
         if lib_name is not None:
             self.lib_name = lib_name
             self.lib = ctypes.CDLL(lib_name)
diff --git a/python/bitblas/wrapper/general.py b/python/bitblas/wrapper/general.py
index cca40cfb9286..f97405716a22 100644
--- a/python/bitblas/wrapper/general.py
+++ b/python/bitblas/wrapper/general.py
@@ -247,6 +247,7 @@ def update_lib_code(self, code: str):
         for dyn_sym in dynamic_symbolic_set:
             function_args.append({"name": dyn_sym, "type": "int"})
 
+        function_args.append({"name": "stream=0", "type": "cudaStream_t"},)
         # Format the function arguments for declaration
         def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args])
 
@@ -284,8 +285,8 @@ def legalize_c(p):
         # Determine the shared memory size, defaulting to 0 if not specified
         smem_str = 0 if self.dynamic_smem_buf is None else self.dynamic_smem_buf
         # Format the CUDA kernel launch string
-        call_str = "{}<<<{}, {}, {}>>>({});".format(function_name, grid_str, block_str, smem_str,
-                                                    call_args)
+        call_str = "{}<<<{}, {}, {}, stream>>>({});".format(function_name, grid_str, block_str,
+                                                            smem_str, call_args)
         # Create the host function wrapper for the CUDA kernel
         host_func = """
 extern "C" void call({}) {{
@@ -351,6 +352,8 @@ def create_dispatch_func(self, code, function_informations):
         for dyn_sym in dynamic_symbolic_set:
             function_args.append({"name": dyn_sym, "type": "int"})
 
+        function_args.append({"name": "stream=0", "type": "cudaStream_t"},)
+
         # Format the argument definitions for function declaration
         def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args])
 
@@ -401,7 +404,7 @@ def legalize_c(p):
             (symbolic,) = list(dynamic_symbolic_set)
             range_str = opt_shapes[symbolic]
             if last_range == 0:
-                call_str = "if ({} <= {}) {{\n\t\t\t {}<<<{}, {}, {}>>>({}); \n\t\t}}\n".format(
+                call_str = "if ({} <= {}) {{\n\t\t\t {}<<<{}, {}, {}, stream>>>({}); \n\t\t}}\n".format(
                     symbolic,
                     range_str,
                     function_name,
@@ -411,7 +414,7 @@ def legalize_c(p):
                     call_args,
                 )
             else:
-                call_str = "\t\telse if ({} <= {}) {{\n\t\t\t {}<<<{}, {}, {}>>>({}); \n\t\t}}\n".format(
+                call_str = "\t\telse if ({} <= {}) {{\n\t\t\t {}<<<{}, {}, {}, stream>>>({}); \n\t\t}}\n".format(
                     symbolic,
                     range_str,
                     function_name,
@@ -421,8 +424,9 @@ def legalize_c(p):
                     call_args,
                 )
             if last_range == num_items - 1:
-                call_str += ("\t\telse {{\n\t\t\t {}<<<{}, {}, {}>>>({}); \n\t\t}}\n".format(
-                    function_name, grid_str, block_str, smem_str, call_args))
+                call_str += (
+                    "\t\telse {{\n\t\t\t {}<<<{}, {}, {}, stream>>>({}); \n\t\t}}\n".format(
+                        function_name, grid_str, block_str, smem_str, call_args))
             last_range += 1
             _call_str += call_str
 
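
After this patch, the host wrapper emitted by update_lib_code launches the kernel on a caller-supplied stream: the trailing cudaStream_t argument defaults to 0, so existing call sites keep the legacy default-stream behaviour. A minimal sketch of the emitted shape follows; the kernel body, argument list, and launch configuration here are illustrative placeholders, not the actual generated code.

#include <cuda_runtime.h>

// Stand-in for the TVM-generated kernel; the real kernel and its
// parameters come from the compiled operator.
__global__ void example_kernel(const float* A, const float* B, float* C, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) C[i] = A[i] + B[i];
}

// Shape of the generated host wrapper after this change: the grid, block,
// and dynamic shared-memory configuration are baked in, and the launch is
// issued on the stream passed by the caller (default stream 0 when omitted).
extern "C" void call(const float* A, const float* B, float* C, int n,
                     cudaStream_t stream = 0) {
  example_kernel<<<dim3((n + 127) / 128), dim3(128), 0, stream>>>(A, B, C, n);
}

Because the stream parameter is appended after all tensor and dynamic-shape arguments and defaults to 0, callers written before this change remain source-compatible, while callers that want to overlap work or capture the launch into a CUDA graph can pass their own cudaStream_t.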