Commit 151f912 (1 parent: ba56e06)

[Dev] Implement test case for tilelang transformations (#53)

* implement jit test case
* [Dev] implement auto tune test case for matrix multiplication
* Implement test for legalize memory access and vectorized loop
* lint fix

File tree

4 files changed: +512 -2 lines changed

Lines changed: 297 additions & 0 deletions
@@ -0,0 +1,297 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import itertools
import logging

import tilelang as tl
import tilelang.testing
import tilelang.language as T
from tilelang.autotuner import autotune, jit

# Configure logger
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def ref_program(A, B):
    """
    A reference matrix multiplication program, used to compare performance.

    Parameters
    ----------
    A : numpy.ndarray
        The matrix with shape (M, K).
    B : numpy.ndarray
        The matrix with shape (N, K).

    Returns
    -------
    np.ndarray
        The result of A @ B.T, shape (M, N).
    """
    return A @ B.T

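# Editor's illustrative sketch (not part of the original commit): ref_program
# is plain NumPy, so its contract is easy to sanity-check by hand. With the
# hypothetical shapes below, A is (4, 8), B is (6, 8), and A @ B.T is (4, 6):
#
#   import numpy as np
#   A = np.ones((4, 8), dtype="float16")
#   B = np.ones((6, 8), dtype="float16")
#   assert ref_program(A, B).shape == (4, 6)
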
def get_configs(M, N, K, with_roller=False):
    """
    Generate a list of configuration dictionaries that will be used for tuning.

    Parameters
    ----------
    M : int
        The dimension M of the matrix multiplication.
    N : int
        The dimension N of the matrix multiplication.
    K : int
        The dimension K of the matrix multiplication.
    with_roller : bool
        Whether to enable the BitBLAS roller to deduce search spaces.

    Returns
    -------
    list of dict
        Each configuration dict includes various block sizes, pipeline stages,
        thread numbers, and other parameters to explore during autotuning.
    """
    if with_roller:
        from bitblas.base.utils import get_roller_hints_from_func
        from bitblas.ops.general_matmul.tirscript import matmul_select_implementation
        from bitblas.base.arch import CUDA
        from bitblas.base.roller.rasterization import NoRasterization
        arch = CUDA("cuda")
        topk = 20

        # Simple TIR compute expression
        ir_module = matmul_select_implementation(
            M=M,
            N=N,
            K=K,
            in_dtype="float16",
            out_dtype="float16",
            accum_dtype="float16",
        )

        roller_hints = get_roller_hints_from_func(
            ir_module,
            arch,
            topk,
            tensorcore_only=True,
            allow_gemv=True,
        )

        if roller_hints is None:
            raise ValueError("No Roller Hints Found for TensorCore Scheduling")
        configs = []
        for hint in roller_hints:
            config = {}
            block_m, block_n = hint.block
            warp_m, warp_n = hint.warp
            config["block_M"] = block_m
            config["block_N"] = block_n
            config["block_K"] = hint.rstep[0]
            config["num_stages"] = 0
            config["thread_num"] = (block_m * block_n) // (warp_m * warp_n) * 32
            config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
            configs.append(config)
        for config in configs:
            print(config)
    else:
        block_M = [64]
        block_N = [64]
        block_K = [32]
        num_stages = [0, 1]
        thread_num = [128]
        enable_rasterization = [False]

        _configs = list(
            itertools.product(
                block_M,
                block_N,
                block_K,
                num_stages,
                thread_num,
                enable_rasterization,
            ))

        configs = [
            {
                "block_M": c[0],
                "block_N": c[1],
                "block_K": c[2],
                "num_stages": c[3],
                "thread_num": c[4],
                "enable_rasteration": c[5],  # keep param name for backward-compat
            } for c in _configs
        ]
    return configs

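# Editor's illustrative sketch (not part of the original commit): with the
# default (with_roller=False) lists above, itertools.product yields
# 1 * 1 * 1 * 2 * 1 * 1 = 2 candidate configs, differing only in num_stages:
#
#   {"block_M": 64, "block_N": 64, "block_K": 32, "num_stages": 0,
#    "thread_num": 128, "enable_rasteration": False}
#   {"block_M": 64, "block_N": 64, "block_K": 32, "num_stages": 1,
#    "thread_num": 128, "enable_rasteration": False}
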
def matmul(M, N, K, with_roller):
    """
    Create an autotuned matrix multiplication kernel for matrices of shape:
      - A: (M, K)
      - B: (N, K)
      - C: (M, N)

    Parameters
    ----------
    M : int
        The dimension M of the matrix multiplication.
    N : int
        The dimension N of the matrix multiplication.
    K : int
        The dimension K of the matrix multiplication.
    with_roller : bool
        Whether to use the BitBLAS roller to generate the tuning configs.

    Returns
    -------
    (best_latency, best_config, ref_latency)
        best_latency : float
            The best latency found among the tuned configurations.
        best_config : dict
            The parameter configuration that yielded best_latency.
        ref_latency : float
            The baseline latency of the reference program (for computing speedup).
    """

    # Decorate the kernel with autotune & jit, specifying:
    # - Tuning config list
    # - Profiling keys
    # - Warmup and repetition counts for better measurement
    # - A reference program for correctness verification
    # - The profiler backend and compilation target
    #   (both set to "auto" here; adjust as needed for your hardware)
    if with_roller:
        # Check that bitblas is installed
        try:
            import bitblas  # noqa: F401
        except ImportError as e:
            raise ImportError(
                "BitBlas is not installed. Please install it via 'pip install bitblas'.") from e

    @autotune(
        configs=get_configs(M, N, K, with_roller),
        keys=[
            "block_M",
            "block_N",
            "block_K",
            "num_stages",
            "thread_num",
            "enable_rasteration",
        ],
        warmup=3,
        rep=5,
    )
    @jit(
        out_idx=[2],
        supply_type=tl.TensorSupplyType.Integer,
        ref_prog=ref_program,
        skip_check=True,
        profiler="auto",
        target="auto",
    )
    def kernel(
            block_M=None,
            block_N=None,
            block_K=None,
            num_stages=None,
            thread_num=None,
            enable_rasteration=None,
    ):
        """
        The actual kernel to compute C = A @ B^T.

        Parameters
        ----------
        block_M : int
            Block size in M dimension.
        block_N : int
            Block size in N dimension.
        block_K : int
            Block size in K dimension.
        num_stages : int
            Number of pipelined stages (for asynchronous load).
        thread_num : int
            Number of threads to use per block.
        enable_rasteration : bool
            Whether to enable rasterization (swizzling) optimization.

        Returns
        -------
        Function
            A TVM Tensor Language function (T.prim_func) that computes matmul.
        """
        # Use half-precision for input data to reduce memory bandwidth,
        # accumulate in float for better numerical accuracy
        dtype = "float16"
        accum_dtype = "float"

        @T.prim_func
        def main(
                A: T.Buffer((M, K), dtype),
                B: T.Buffer((N, K), dtype),
                C: T.Buffer((M, N), dtype),
        ):
            """
            The compiled TVM function for block-level matrix multiplication.

            - We divide the entire (M, N) domain into blocks of shape
              (block_M, block_N).
            - Each block has its own allocated shared memory for sub-blocks
              of A and B.
            - The partial results go into C_local, and then we copy them back
              to global memory C.
            """
            # Bind x-dimension to block index in N,
            # y-dimension to block index in M.
            with T.Kernel(
                    T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=thread_num) as (bx, by):

                # Allocate shared memory for A sub-block of shape (block_M, block_K)
                A_shared = T.alloc_shared((block_M, block_K), dtype)
                # Allocate shared memory for B sub-block of shape (block_N, block_K)
                B_shared = T.alloc_shared((block_N, block_K), dtype)
                # Allocate a local fragment for intermediate accumulation
                C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

                # Enable (or disable) swizzling optimization
                T.use_swizzle(panel_size=10, enable=enable_rasteration)

                # Clear out the accumulation buffer
                T.clear(C_local)

                # Loop over sub-blocks in K dimension, pipelined by num_stages
                for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                    # Load a sub-block of A from global memory into A_shared
                    T.copy(
                        A[by * block_M, k * block_K],
                        A_shared,
                    )
                    # Load a sub-block of B from global memory into B_shared
                    T.copy(
                        B[bx * block_N, k * block_K],
                        B_shared,
                    )
                    # Perform a partial matrix multiplication:
                    #   C_local += A_shared @ B_shared^T
                    T.gemm(
                        A_shared,
                        B_shared,
                        C_local,
                        transpose_B=True,
                    )
                # Write back the results from C_local to the global memory C
                T.copy(C_local, C[by * block_M, bx * block_N])

        return main

    return kernel()

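# Editor's illustrative sketch (not part of the original commit): per the
# docstring above, matmul() runs the autotuner and is documented to yield
# (best_latency, best_config, ref_latency). Assuming that return shape, a
# caller could report the tuning outcome roughly like this (the 1024 sizes
# are arbitrary illustrative values):
#
#   best_latency, best_config, ref_latency = matmul(1024, 1024, 1024, with_roller=False)
#   print(f"best config: {best_config}")
#   print(f"speedup over reference: {ref_latency / best_latency:.2f}x")
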
def test_autotune_get_configs():
    get_configs(8192, 8192, 8192, with_roller=False)


def test_autotune_matmul():
    matmul(8192, 8192, 8192, with_roller=False)


if __name__ == "__main__":
    tilelang.testing.main()
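# Editor's note (not part of the original commit): running this file directly
# executes tilelang.testing.main() from the __main__ guard above, and the
# test_* functions are also the kind picked up by pytest-style discovery;
# both tests exercise only the default with_roller=False path.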
