 import itertools
 import tilelang
 import tilelang.language as T
-from tilelang.autotuner import AutoTuner
-from tilelang.carver.template import ConvTemplate
-from tilelang.carver.arch import CUDA
-from tilelang.carver.arch import CDNA
-from tilelang.carver.roller.rasterization import NoRasterization


 def check_hopper():
@@ -30,149 +25,36 @@ def main(A, B):
     return main


-def get_configs(N, C, H, W, F, K, S, D, P, with_roller=False, topk=15):
-    if with_roller:
-        arch = CDNA("hip") if torch.version.hip is not None else CUDA("cuda")
-        carve_template = ConvTemplate(
-            N=N,
-            C=C,
-            H=H,
-            W=W,
-            F=F,
-            K=K,
-            S=S,
-            D=D,
-            P=P,
-            in_dtype="float16",
-            out_dtype="float16",
-            accum_dtype="float",
-        ).with_arch(arch)
-
-        func = carve_template.equivalent_function()
-        assert func is not None, "Function is None"
-        roller_hints = carve_template.recommend_hints(topk=topk)
-        if roller_hints is None:
-            raise ValueError("No Roller Hints Found for TensorCore Scheduling")
-        configs = []
-        for hint in roller_hints:
-            config = {}
-            block_m, block_n = hint.block
-            warp_m, warp_n = hint.warp
-            # block_rows, block_cols represents warp partitioning
-            block_rows, block_cols = block_m // warp_m, block_n // warp_n
-            config["block_M"] = block_m
-            config["block_N"] = block_n
-            config["block_K"] = hint.rstep[0]
-            config["num_stages"] = hint.pipeline_stage if hint.pipeline_stage > 1 else 0
-            config["thread_num"] = block_rows * block_cols * 32
-            config["enable_rasteration"] = hint.rasterization_plan is not NoRasterization
-            configs.append(config)
-    else:
-        block_M = [64, 128, 256]
-        block_N = [64, 128, 256]
-        block_K = [32, 64]
-        num_stages = [0, 1, 2, 3]
-        thread_num = [128, 256]
-        enable_rasterization = [True, False]
-        _configs = list(
-            itertools.product(
-                block_M,
-                block_N,
-                block_K,
-                num_stages,
-                thread_num,
-                enable_rasterization,
-            ))
-
-        configs = [
-            {
-                "block_M": c[0],
-                "block_N": c[1],
-                "block_K": c[2],
-                "num_stages": c[3],
-                "thread_num": c[4],
-                "enable_rasteration": c[5],  # keep param name for backward-compat
-            } for c in _configs
-        ]
+def get_configs():
+    block_M = [64, 128, 256]
+    block_N = [64, 128, 256]
+    block_K = [32, 64]
+    num_stages = [0, 1, 2, 3]
+    thread_num = [128, 256]
+    enable_rasterization = [True, False]
+    _configs = list(
+        itertools.product(
+            block_M,
+            block_N,
+            block_K,
+            num_stages,
+            thread_num,
+            enable_rasterization,
+        ))
+
+    configs = [
+        {
+            "block_M": c[0],
+            "block_N": c[1],
+            "block_K": c[2],
+            "num_stages": c[3],
+            "thread_num": c[4],
+            "enable_rasteration": c[5],  # keep param name for backward-compat
+        } for c in _configs
+    ]
     return configs


-def get_best_config(N, C, H, W, F, K, S, D, P, ref_prog, with_roller=False):
-
-    @tilelang.jit(out_idx=[2])
-    def kernel(
-        block_M=None,
-        block_N=None,
-        block_K=None,
-        num_stages=None,
-        thread_num=None,
-        enable_rasteration=None,
-    ):
-        dtype = "float16"
-        accum_dtype = "float"
-        KH, KW = K, K
-        OH = (H + 2 * P - D * (K - 1) - 1) // S + 1
-        OW = (W + 2 * P - D * (K - 1) - 1) // S + 1
-        is_hopper = check_hopper()
-
-        @T.prim_func
-        def main(
-                data: T.Tensor((N, H, W, C), dtype),
-                kernel: T.Tensor((KH, KW, C, F), dtype),
-                out: T.Tensor((N, OH, OW, F), dtype),
-        ):
-            with T.Kernel(
-                    T.ceildiv(F, block_N), T.ceildiv(N * OH * OW, block_M),
-                    threads=thread_num) as (bx, by):
-                data_shared = T.alloc_shared((block_M, block_K), dtype)
-                kernel_shared = T.alloc_shared((block_K, block_N), dtype)
-                out_local = T.alloc_fragment((block_M, block_N), accum_dtype)
-                out_shared = T.alloc_shared((block_M, block_N), dtype)
-
-                kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
-                out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)
-
-                T.annotate_layout({
-                    out_shared: tilelang.layout.make_swizzled_layout(out_shared),
-                    data_shared: tilelang.layout.make_swizzled_layout(data_shared),
-                    kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
-                })
-
-                T.clear(out_local)
-                for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
-                    if is_hopper:
-                        T.c2d_im2col(data, data_shared, by, k_iter, KH, S, D, P)
-                    else:
-                        for i, j in T.Parallel(block_M, block_K):
-                            k = k_iter * block_K + j
-                            m = by * block_M + i
-                            access_h = m % (OH * OW) // OW * S + k // (KW * C) * D - P
-                            access_w = m % OW * S + k // C % KW * D - P
-                            in_bound = ((access_h >= 0) and (access_w >= 0) and (access_h < H) and
-                                        (access_w < W))
-                            data_shared[i, j] = T.if_then_else(
-                                in_bound, data[m // (OH * OW), access_h, access_w, k % C], 0)
-                    T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
-                    T.gemm(data_shared, kernel_shared, out_local)
-
-                T.copy(out_local, out_shared)
-                T.copy(out_shared, out_flat[by * block_M, bx * block_N])
-
-        return main
-
-    autotuner = AutoTuner.from_kernel(
-        kernel=kernel, configs=get_configs(N, C, H, W, F, K, S, D, P,
-                                           with_roller)).set_compile_args(
-                                               out_idx=[2],
-                                               target="auto",
-                                           ).set_profile_args(
-                                               supply_type=tilelang.TensorSupplyType.Integer,
-                                               ref_prog=ref_prog,
-                                               skip_check=False,
-                                           )
-    return autotuner.run(warmup=3, rep=20)
-
-
 def get_heuristic_config() -> dict:
     # Get CUDA device properties
     if not torch.cuda.is_available():
@@ -210,6 +92,7 @@ def get_heuristic_config() -> dict:
     }


+@tilelang.autotune(configs=get_configs())
 @tilelang.jit(out_idx=[2])
 def convolution(N,
                 C,
@@ -252,11 +135,10 @@ def main(
             kernel_flat = T.Tensor((KH * KW * C, F), dtype, kernel.data)
             out_flat = T.Tensor((N * OH * OW, F), dtype, out.data)

-            T.annotate_layout({
-                out_shared: tilelang.layout.make_swizzled_layout(out_shared),
-                data_shared: tilelang.layout.make_swizzled_layout(data_shared),
-                kernel_shared: tilelang.layout.make_swizzled_layout(kernel_shared),
-            })
+            if is_hopper:
+                T.annotate_layout({
+                    out_shared: tilelang.layout.make_swizzled_layout(out_shared),
+                })

             T.clear(out_local)
             for k_iter in T.Pipelined(T.ceildiv(KH * KW * C, block_K), num_stages=num_stages):
@@ -275,8 +157,11 @@ def main(
                 T.copy(kernel_flat[k_iter * block_K, bx * block_N], kernel_shared)
                 T.gemm(data_shared, kernel_shared, out_local)

-            T.copy(out_local, out_shared)
-            T.copy(out_shared, out_flat[by * block_M, bx * block_N])
+            if is_hopper:
+                T.copy(out_local, out_shared)
+                T.copy(out_shared, out_flat[by * block_M, bx * block_N])
+            else:
+                T.copy(out_local, out_flat[by * block_M, bx * block_N])

     return main

@@ -296,9 +181,7 @@ def main(n: int = 128,
     ref_prog = ref_program(S, P, D)

     if use_autotune:
-        result = get_best_config(N, C, H, W, F, K, S, D, P, ref_prog, with_roller)
-        print(result.config)
-        kernel = result.kernel
+        kernel = convolution(N, C, H, W, F, K, S, D, P)
     else:
         config = get_heuristic_config()
         kernel = convolution(N, C, H, W, F, K, S, D, P, **config)
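
For context, the sketch below shows how the refactored entry point is driven after this change: with @tilelang.autotune(configs=get_configs()) stacked on @tilelang.jit(out_idx=[2]), calling convolution(N, C, H, W, F, K, S, D, P) runs the tuning search and returns a compiled kernel whose output tensor is allocated and returned by the call. The concrete shapes, the random torch inputs, and the tolerance check against ref_program are illustrative assumptions, not part of the diff.

import torch

# Hypothetical smoke test for the autotuned path (shapes chosen for illustration).
# data is NHWC float16, weights are (KH, KW, C, F) float16, as declared in `main`.
N, C, H, W, F, K, S, D, P = 8, 64, 32, 32, 64, 3, 1, 1, 1

# Calling the decorated function triggers the search over get_configs()
# and returns a compiled kernel for the best configuration found.
kernel = convolution(N, C, H, W, F, K, S, D, P)

a = torch.randn(N, H, W, C, device="cuda", dtype=torch.float16)
b = torch.randn(K, K, C, F, device="cuda", dtype=torch.float16)

# out_idx=[2] means the kernel allocates and returns the output tensor itself.
out = kernel(a, b)

# Compare against the same PyTorch reference the example already uses.
torch.testing.assert_close(out, ref_program(S, P, D)(a, b), rtol=1e-2, atol=1e-2)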