Skip to content

Commit 1917771

Browse files
Authored commit: [Refactor] Remove BitBLAS Import in Benchmark (#150)
1 parent cf40574 commit 1917771

File tree

2 files changed

+10
-27
lines changed

2 files changed

+10
-27
lines changed

benchmark/benchmark_matmul.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -149,13 +149,6 @@ def matmul(M, N, K, with_roller):
149149
# - A reference program for correctness verification
150150
# - The "tvm" profiler backend
151151
# - HIP as the compilation target (modify as needed for your hardware)
152-
if with_roller:
153-
# check out bitblas is installed
154-
try:
155-
import bitblas # noqa: F401
156-
except ImportError as e:
157-
raise ImportError(
158-
"BitBlas is not installed. Please install it via 'pip install bitblas'.") from e
159152

160153
@autotune(
161154
configs=get_configs(M, N, K, with_roller),

testing/python/autotune/test_tilelang_autotune.py

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -49,33 +49,30 @@ def get_configs(M, N, K, with_roller=False):
4949
thread numbers, and other parameters to explore during autotuning.
5050
"""
5151
if with_roller:
52-
from bitblas.base.utils import get_roller_hints_from_func
53-
from bitblas.ops.general_matmul.tirscript import matmul_select_implementation
54-
from bitblas.base.arch import CUDA
55-
from bitblas.base.roller.rasterization import NoRasterization
52+
from tilelang.carver.template import MatmulTemplate
53+
from tilelang.carver.arch import CUDA
54+
from tilelang.carver.roller.rasterization import NoRasterization
5655
arch = CUDA("cuda")
5756
topk = 20
5857

5958
# Simple TIR Compute Expression
60-
ir_module = matmul_select_implementation(
59+
carve_template = MatmulTemplate(
6160
M=M,
6261
N=N,
6362
K=K,
6463
in_dtype="float16",
6564
out_dtype="float16",
6665
accum_dtype="float16",
67-
)
66+
).with_arch(arch)
6867

69-
roller_hints = get_roller_hints_from_func(
70-
ir_module,
71-
arch,
72-
topk,
73-
tensorcore_only=True,
74-
allow_gemv=True,
75-
)
68+
func = carve_template.equivalent_function()
69+
assert func is not None, "Function is None"
70+
71+
roller_hints = carve_template.recommend_hints(topk=topk)
7672

7773
if roller_hints is None:
7874
raise ValueError("No Roller Hints Found for TensorCore Scheduling")
75+
7976
configs = []
8077
for hint in roller_hints:
8178
config = {}
@@ -156,13 +153,6 @@ def matmul(M, N, K, with_roller):
156153
# - A reference program for correctness verification
157154
# - The "tvm" profiler backend
158155
# - HIP as the compilation target (modify as needed for your hardware)
159-
if with_roller:
160-
# check out bitblas is installed
161-
try:
162-
import bitblas # noqa: F401
163-
except ImportError as e:
164-
raise ImportError(
165-
"BitBlas is not installed. Please install it via 'pip install bitblas'.") from e
166156

167157
@autotune(
168158
configs=get_configs(M, N, K, with_roller),

0 commit comments

Comments (0)