Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 10 additions & 14 deletions benchmark/benchmark_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,33 +46,29 @@ def get_configs(M, N, K, with_roller=False):
thread numbers, and other parameters to explore during autotuning.
"""
if with_roller:
from bitblas.base.utils import get_roller_hints_from_func
from bitblas.ops.general_matmul.tirscript import matmul_select_implementation
from bitblas.base.arch import CUDA
from bitblas.base.roller.rasterization import NoRasterization
from tilelang.carver.template import MatmulTemplate
from tilelang.carver.arch import CUDA
from tilelang.carver.roller.rasterization import NoRasterization
arch = CUDA("cuda")
topk = 20

# Simple TIR Compute Expression
ir_module = matmul_select_implementation(
carve_template = MatmulTemplate(
M=M,
N=N,
K=K,
in_dtype="float16",
out_dtype="float16",
accum_dtype="float16",
)
).with_arch(arch)

roller_hints = get_roller_hints_from_func(
ir_module,
arch,
topk,
tensorcore_only=True,
allow_gemv=True,
)
func = carve_template.equivalent_function()
assert func is not None, "Function is None"

roller_hints = carve_template.recommend_hints(topk=topk)

if roller_hints is None:
raise ValueError("No Roller Hints Found for TensorCore Scheduling")

configs = []
for hint in roller_hints:
config = {}
Expand Down
5 changes: 4 additions & 1 deletion tilelang/carver/roller/policy/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,10 @@ def emit_config(self, topk: int) -> List[Hint]:

self._expand_reduce_axis(td)
for codegen_dicts in self.assign_block_size(td):
results.append(codegen_dicts)
if isinstance(codegen_dicts, dict) and len(codegen_dicts) == 1:
results.append(list(codegen_dicts.values())[0])
else:
results.append(codegen_dicts)
if len(results) >= topk:
break
if len(results) >= topk:
Expand Down
10 changes: 7 additions & 3 deletions tilelang/carver/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def get_roller_hints_from_func(func_or_module: Union[tir.PrimFunc, IRModule],

assert func is not None, "The function should not be None"

roller_hints = None
if tensorcore_only:
try:
tensorized_func, tags = get_tensorized_func_and_tags(
Expand All @@ -53,9 +54,9 @@ def get_roller_hints_from_func(func_or_module: Union[tir.PrimFunc, IRModule],
tags = None
if tags and tensorized_func:
policy = TensorCorePolicy(func=tensorized_func, arch=arch, tags=tags)
return policy.emit_config(topk)
roller_hints = policy.emit_config(topk)
else:
return None
roller_hints = None
else:
policy = DefaultPolicy.from_prim_func(func=func, arch=arch)
tensorized_func = None
Expand All @@ -67,7 +68,10 @@ def get_roller_hints_from_func(func_or_module: Union[tir.PrimFunc, IRModule],
tags = None
if tags and tensorized_func:
policy = TensorCorePolicy.from_prim_func(func=tensorized_func, arch=arch, tags=tags)
return policy.emit_config(topk)
roller_hints = policy.emit_config(topk)
else:
roller_hints = None
return roller_hints


def get_roller_hints_from_output_nodes(
Expand Down
Loading