diff --git a/benchmark/benchmark_matmul.py b/benchmark/benchmark_matmul.py index 73925f7a5..022315fbf 100644 --- a/benchmark/benchmark_matmul.py +++ b/benchmark/benchmark_matmul.py @@ -46,33 +46,29 @@ def get_configs(M, N, K, with_roller=False): thread numbers, and other parameters to explore during autotuning. """ if with_roller: - from bitblas.base.utils import get_roller_hints_from_func - from bitblas.ops.general_matmul.tirscript import matmul_select_implementation - from bitblas.base.arch import CUDA - from bitblas.base.roller.rasterization import NoRasterization + from tilelang.carver.template import MatmulTemplate + from tilelang.carver.arch import CUDA + from tilelang.carver.roller.rasterization import NoRasterization arch = CUDA("cuda") topk = 20 - # Simple TIR Compute Expression - ir_module = matmul_select_implementation( + carve_template = MatmulTemplate( M=M, N=N, K=K, in_dtype="float16", out_dtype="float16", accum_dtype="float16", - ) + ).with_arch(arch) - roller_hints = get_roller_hints_from_func( - ir_module, - arch, - topk, - tensorcore_only=True, - allow_gemv=True, - ) + func = carve_template.equivalent_function() + assert func is not None, "Function is None" + + roller_hints = carve_template.recommend_hints(topk=topk) if roller_hints is None: raise ValueError("No Roller Hints Found for TensorCore Scheduling") + configs = [] for hint in roller_hints: config = {} diff --git a/tilelang/carver/roller/policy/default.py b/tilelang/carver/roller/policy/default.py index daaa1cf7a..71f8811a1 100644 --- a/tilelang/carver/roller/policy/default.py +++ b/tilelang/carver/roller/policy/default.py @@ -94,7 +94,10 @@ def emit_config(self, topk: int) -> List[Hint]: self._expand_reduce_axis(td) for codegen_dicts in self.assign_block_size(td): - results.append(codegen_dicts) + if isinstance(codegen_dicts, dict) and len(codegen_dicts) == 1: + results.append(list(codegen_dicts.values())[0]) + else: + results.append(codegen_dicts) if len(results) >= topk: break if len(results) >= topk: diff --git a/tilelang/carver/utils.py b/tilelang/carver/utils.py index 7bc01eaeb..60ebcce2f 100644 --- a/tilelang/carver/utils.py +++ b/tilelang/carver/utils.py @@ -44,6 +44,7 @@ def get_roller_hints_from_func(func_or_module: Union[tir.PrimFunc, IRModule], assert func is not None, "The function should not be None" + roller_hints = None if tensorcore_only: try: tensorized_func, tags = get_tensorized_func_and_tags( @@ -53,9 +54,9 @@ def get_roller_hints_from_func(func_or_module: Union[tir.PrimFunc, IRModule], tags = None if tags and tensorized_func: policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags) - return policy.emit_config(topk) + roller_hints = policy.emit_config(topk) else: - return None + roller_hints = None else: policy = DefaultPolicy.from_prim_func(func=func, arch=arch) tensorized_func = None @@ -67,7 +68,10 @@ def get_roller_hints_from_func(func_or_module: Union[tir.PrimFunc, IRModule], tags = None if tags and tensorized_func: policy = TensorCorePolicy.from_prim_func(func=tensorized_func, arch=arch, tags=tags) - return policy.emit_config(topk) + roller_hints = policy.emit_config(topk) + else: + roller_hints = None + return roller_hints def get_roller_hints_from_output_nodes(