[CUDA GRAPH] Support Cuda Stream in the Wrap Function (apache#21)
* Fix an issue when loading from the database.

* Support CUDA stream

* Refactor CUDA kernel launch string formatting in CUDASourceWrapper class
LeiWang1999 authored Apr 23, 2024
1 parent 273d05c commit b2fc5af
Showing 3 changed files with 22 additions and 12 deletions.
python/bitblas/cache/operator.py (12 changes: 7 additions & 5 deletions)
@@ -130,7 +130,7 @@ def _load_operators_from_arch_path(self, arch_path, target):
             self._load_operator(config_path, target)
 
     def _load_operator(self, config_path, target):
-        mapping, config, rt_mod, lib_name = None, None, None, None
+        mapping, config, rt_mod, src_name, lib_name = None, None, None, None, None
         for file in os.listdir(config_path):
             full_path = os.path.join(config_path, file)
             if file == "mapping.json":
@@ -143,16 +143,18 @@ def _load_operator(self, config_path, target):
                 rt_mod = tvm.runtime.load_module(full_path)
             elif file == "wrapper_compiled.so":
                 lib_name = full_path
+            elif file == "wrapper_source.cu":
+                src_name = full_path
 
         if mapping and config and rt_mod:
-            self._instantiate_and_add_operator(mapping, config, rt_mod, lib_name, target)
+            self._instantiate_and_add_operator(mapping, config, rt_mod, src_name, lib_name, target)
 
-    def _instantiate_and_add_operator(self, mapping, config, rt_mod, lib_name, target):
+    def _instantiate_and_add_operator(self, mapping, config, rt_mod, src_name, lib_name, target):
         config_cls = getattr(bitblas, mapping["config_type"])
         operator_cls = getattr(bitblas, mapping["operator_type"])
         op_inst = operator_cls(
-            config=config_cls(**config), target=target, enable_tuning=False, from_database=False)
-        op_inst.update_runtime_module(rt_mod, lib_name=lib_name)
+            config=config_cls(**config), target=target, enable_tuning=False, from_database=True)
+        op_inst.update_runtime_module(rt_mod, src_name=src_name, lib_name=lib_name)
         self.add(config_cls(**config), op_inst)
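
For context, a hedged sketch of how a runtime module reloaded from a cached operator directory might be re-attached through the extended update_runtime_module signature used above. The artifact name rt_mod.tar is illustrative (the diff does not show which file tvm.runtime.load_module reads); wrapper_source.cu and wrapper_compiled.so are the names the loader looks for:

import tvm

def reattach_cached_module(op_inst, cached_dir):
    # Reload the TVM runtime module saved in the cache directory
    # (the exact artifact file name is not shown in this diff).
    rt_mod = tvm.runtime.load_module(f"{cached_dir}/rt_mod.tar")
    # Forward both the generated CUDA wrapper source (newly tracked in this
    # commit) and the compiled wrapper library, mirroring
    # _instantiate_and_add_operator above.
    op_inst.update_runtime_module(
        rt_mod,
        src_name=f"{cached_dir}/wrapper_source.cu",
        lib_name=f"{cached_dir}/wrapper_compiled.so",
    )
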
python/bitblas/ops/operator.py (6 changes: 5 additions & 1 deletion)
@@ -58,6 +58,7 @@ def __init__(self, name, config: OperatorConfig, target: Target = None):
             1  # todo(lei): should be analyzed from the prim_func.
         )
         self.wrapper = None
+        self.src_name = None
         self.lib_name = None
         self.lib = None
@@ -131,6 +132,7 @@ def tvm_callback_cuda_postproc(code, _):
                 self.arch)
             wrapper.compile_lib()
             self.wrapper = wrapper
+            self.src_name = self.wrapper.src_name
             self.lib_name = self.wrapper.lib_name
             self.lib = self.wrapper.load_lib()
             self.lib.init()
@@ -291,11 +293,13 @@ def __call__(self, *args: Any) -> Any:
     def update_func(self, func: PrimFunc):
         self.prim_func_mod["main"] = func
 
-    def update_runtime_module(self, rt_mod, lib_name=None):
+    def update_runtime_module(self, rt_mod, src_name=None, lib_name=None):
         self.rt_mod = rt_mod
         self.time_evaluator = rt_mod.time_evaluator(rt_mod.entry_name, self.arch.device, number=10)
         self.function_handle = rt_mod.get_function(rt_mod.entry_name).handle
         self.torch_func = to_pytorch_func(rt_mod)
+        if src_name is not None:
+            self.src_name = src_name
         if lib_name is not None:
             self.lib_name = lib_name
             self.lib = ctypes.CDLL(lib_name)
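
Since the compiled wrapper is opened with ctypes.CDLL, a caller can hand the kernel a specific CUDA stream through the new trailing parameter of the generated call entry point (added in python/bitblas/wrapper/general.py below). A hedged sketch, assuming a matmul-style wrapper with three half-precision tensor arguments; the tensor list and library path are illustrative, and only the trailing stream argument (defaulting to 0, the legacy default stream) comes from this commit:

import ctypes
import torch

lib = ctypes.CDLL("./wrapper_compiled.so")  # path as stored by the operator cache
lib.init()

A = torch.empty(1024, 1024, dtype=torch.float16, device="cuda")
B = torch.empty(1024, 1024, dtype=torch.float16, device="cuda")
C = torch.empty(1024, 1024, dtype=torch.float16, device="cuda")

# Raw cudaStream_t of the current PyTorch stream; passing 0 instead would
# fall back to the default stream, matching the generated "stream=0" default.
stream = torch.cuda.current_stream().cuda_stream
lib.call(
    ctypes.c_void_p(A.data_ptr()),
    ctypes.c_void_p(B.data_ptr()),
    ctypes.c_void_p(C.data_ptr()),
    ctypes.c_void_p(stream),
)
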
python/bitblas/wrapper/general.py (16 changes: 10 additions & 6 deletions)
@@ -247,6 +247,7 @@ def update_lib_code(self, code: str):
         for dyn_sym in dynamic_symbolic_set:
             function_args.append({"name": dyn_sym, "type": "int"})
 
+        function_args.append({"name": "stream=0", "type": "cudaStream_t"},)
         # Format the function arguments for declaration
         def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args])
 
@@ -284,8 +285,8 @@ def legalize_c(p):
         # Determine the shared memory size, defaulting to 0 if not specified
         smem_str = 0 if self.dynamic_smem_buf is None else self.dynamic_smem_buf
         # Format the CUDA kernel launch string
-        call_str = "{}<<<{}, {}, {}>>>({});".format(function_name, grid_str, block_str, smem_str,
-                                                    call_args)
+        call_str = "{}<<<{}, {}, {}, stream>>>({});".format(function_name, grid_str, block_str,
+                                                            smem_str, call_args)
         # Create the host function wrapper for the CUDA kernel
         host_func = """
 extern "C" void call({}) {{
@@ -351,6 +352,8 @@ def create_dispatch_func(self, code, function_informations):
         for dyn_sym in dynamic_symbolic_set:
             function_args.append({"name": dyn_sym, "type": "int"})
 
+        function_args.append({"name": "stream=0", "type": "cudaStream_t"},)
+
         # Format the argument definitions for function declaration
         def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args])
 
@@ -401,7 +404,7 @@ def legalize_c(p):
             (symbolic,) = list(dynamic_symbolic_set)
             range_str = opt_shapes[symbolic]
             if last_range == 0:
-                call_str = "if ({} <= {}) {{\n\t\t\t {}<<<{}, {}, {}>>>({}); \n\t\t}}\n".format(
+                call_str = "if ({} <= {}) {{\n\t\t\t {}<<<{}, {}, {}, stream>>>({}); \n\t\t}}\n".format(
                     symbolic,
                     range_str,
                     function_name,
@@ ... @@
                     call_args,
                 )
             else:
-                call_str = "\t\telse if ({} <= {}) {{\n\t\t\t {}<<<{}, {}, {}>>>({}); \n\t\t}}\n".format(
+                call_str = "\t\telse if ({} <= {}) {{\n\t\t\t {}<<<{}, {}, {}, stream>>>({}); \n\t\t}}\n".format(
                     symbolic,
                     range_str,
                     function_name,
@@ ... @@
                     call_args,
                 )
             if last_range == num_items - 1:
-                call_str += ("\t\telse {{\n\t\t\t {}<<<{}, {}, {}>>>({}); \n\t\t}}\n".format(
-                    function_name, grid_str, block_str, smem_str, call_args))
+                call_str += (
+                    "\t\telse {{\n\t\t\t {}<<<{}, {}, {}, stream>>>({}); \n\t\t}}\n".format(
+                        function_name, grid_str, block_str, smem_str, call_args))
             last_range += 1
             _call_str += call_str
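
To make the effect of the two changes above concrete, here is a small, self-contained sketch that replays the same string templates outside the wrapper class. The kernel name, argument list, and launch geometry are made up; the trailing stream=0 parameter and the <<<grid, block, smem, stream>>> launch form mirror the code in this diff:

function_name = "matmul_kernel"  # hypothetical kernel symbol
function_args = [
    {"name": "A", "type": "half*"},
    {"name": "B", "type": "half*"},
    {"name": "C", "type": "half*"},
    {"name": "m", "type": "int"},  # a dynamic symbolic dimension
]
# The trailing stream argument appended by update_lib_code / create_dispatch_func.
function_args.append({"name": "stream=0", "type": "cudaStream_t"},)

def_args = ", ".join([f"{arg['type']} {arg['name']}" for arg in function_args])
call_args = ", ".join(arg["name"] for arg in function_args[:-1])  # kernel args exclude the stream
grid_str, block_str, smem_str = "dim3(128, 1, 1)", "dim3(128, 1, 1)", 0

call_str = "{}<<<{}, {}, {}, stream>>>({});".format(function_name, grid_str, block_str,
                                                    smem_str, call_args)
host_func = 'extern "C" void call({}) {{\n    {}\n}}'.format(def_args, call_str)
print(host_func)
# extern "C" void call(half* A, half* B, half* C, int m, cudaStream_t stream=0) {
#     matmul_kernel<<<dim3(128, 1, 1), dim3(128, 1, 1), 0, stream>>>(A, B, C, m);
# }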
