@@ -49,33 +49,30 @@ def get_configs(M, N, K, with_roller=False):
4949 thread numbers, and other parameters to explore during autotuning.
5050 """
5151 if with_roller :
52- from bitblas .base .utils import get_roller_hints_from_func
53- from bitblas .ops .general_matmul .tirscript import matmul_select_implementation
54- from bitblas .base .arch import CUDA
55- from bitblas .base .roller .rasterization import NoRasterization
52+ from tilelang .carver .template import MatmulTemplate
53+ from tilelang .carver .arch import CUDA
54+ from tilelang .carver .roller .rasterization import NoRasterization
5655 arch = CUDA ("cuda" )
5756 topk = 20
5857
5958 # Simple TIR Compute Expression
60- ir_module = matmul_select_implementation (
59+ carve_template = MatmulTemplate (
6160 M = M ,
6261 N = N ,
6362 K = K ,
6463 in_dtype = "float16" ,
6564 out_dtype = "float16" ,
6665 accum_dtype = "float16" ,
67- )
66+ ). with_arch ( arch )
6867
69- roller_hints = get_roller_hints_from_func (
70- ir_module ,
71- arch ,
72- topk ,
73- tensorcore_only = True ,
74- allow_gemv = True ,
75- )
68+ func = carve_template .equivalent_function ()
69+ assert func is not None , "Function is None"
70+
71+ roller_hints = carve_template .recommend_hints (topk = topk )
7672
7773 if roller_hints is None :
7874 raise ValueError ("No Roller Hints Found for TensorCore Scheduling" )
75+
7976 configs = []
8077 for hint in roller_hints :
8178 config = {}
@@ -156,13 +153,6 @@ def matmul(M, N, K, with_roller):
156153 # - A reference program for correctness verification
157154 # - The "tvm" profiler backend
158155 # - HIP as the compilation target (modify as needed for your hardware)
159- if with_roller :
160- # check out bitblas is installed
161- try :
162- import bitblas # noqa: F401
163- except ImportError as e :
164- raise ImportError (
165- "BitBlas is not installed. Please install it via 'pip install bitblas'." ) from e
166156
167157 @autotune (
168158 configs = get_configs (M , N , K , with_roller ),
0 commit comments