From bcb57b2631dae5c727da6caa7e50ce9fa789c4ea Mon Sep 17 00:00:00 2001 From: Max Hu Date: Mon, 4 Mar 2024 16:34:07 -0500 Subject: [PATCH] [Compilation] Check uniqueness of function names after regrouping --- python/hidet/drivers/build_module.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/python/hidet/drivers/build_module.py b/python/hidet/drivers/build_module.py index c692f5e47..0f07283f9 100644 --- a/python/hidet/drivers/build_module.py +++ b/python/hidet/drivers/build_module.py @@ -207,6 +207,23 @@ def regroup_modules(modules, size): else: return modules + # check if regrouped IRModules have unique function names + def check_function_singular(module_list: Union[Sequence[IRModule], Sequence[Sequence[IRModule]]]) -> bool: + if len(module_list) == 0 or isinstance(module_list[0], IRModule): + return True + name_set = set() + for modules in module_list: + for module in modules: + namespace_str = module.namespace + function_name_list = list(module.extern_functions.keys()) + list(module.functions.keys()) + for func_name in function_name_list: + func_str = namespace_str + '::' + func_name + if func_str in name_set: + return False + else: + name_set.add(func_str) + return True + # calculate the number of workers cpu_count = os.cpu_count() if hidet.option.compile_server.enabled(): @@ -229,6 +246,7 @@ def regroup_modules(modules, size): per_worker_jobs = 1 if len(ir_modules) < num_workers else len(ir_modules) // num_workers ir_modules_list = regroup_modules(ir_modules, per_worker_jobs) + assert check_function_singular(ir_modules_list), 'duplicate function names detected after regrouping candidates for batch compilation' jobs = [ (ir_modules, output_dir) for ir_modules, output_dir in zip(ir_modules_list, output_dirs[: len(ir_modules_list)])