 import itertools
-import math
-from einops import rearrange
 import tilelang
 from tilelang import language as T
 import torch
-from tilelang.autotuner import autotune
 from tilelang import tvm
-from utils import cal_cu_seqlen_ke_for_q, cal_cu_seqlen_ks_for_q
-
+from utils import cal_cu_seqlen_ke_for_q, cal_cu_seqlen_ks_for_q, assert_similar
 from typing import Tuple


 def ceil_to_ue8m0(x: torch.Tensor):
     assert x.view(-1).amax().item() > 0
     return torch.pow(2.0, torch.ceil(torch.log2(x.abs())))

-def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple[int], use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
+
+def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple[int],
+                                use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]:
     excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)])
     x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4)
     sf = x_amax / 448.0
     sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
     x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
     return x_scaled, sf.squeeze()

-def print_red_warning(message):
-    print(f"\033[31mWARNING: {message}\033[0m")
-
-
-def calc_sim(x, y, name="tensor"):
-    x, y = x.data.double(), y.data.double()
-    denominator = (x * x + y * y).sum()
-    if denominator == 0:
-        print_red_warning(f"{name} all zero")
-        return 1
-    sim = 2 * (x * y).sum() / denominator
-    return sim
-
-
-def assert_similar(x, y, eps=1e-8, name="tensor", raise_assert=True):
-    x_mask = torch.isfinite(x)
-    y_mask = torch.isfinite(y)
-    if not torch.all(x_mask == y_mask):
-        print_red_warning(f"{name} Error: isfinite mask mismatch")
-        if raise_assert:
-            assert False
-    if not torch.isclose(
-        x.masked_fill(x_mask, 0),
-        y.masked_fill(y_mask, 0),
-        rtol=0,
-        atol=0,
-        equal_nan=True,
-    ).all():
-        print_red_warning(f"{name} Error: nonfinite value mismatch")
-        if raise_assert:
-            assert False
-    x = x.masked_fill(~x_mask, 0)
-    y = y.masked_fill(~y_mask, 0)
-    sim = calc_sim(x, y, name)
-    diff = 1.0 - sim
-    if not (0 <= diff <= eps):
-        print_red_warning(f"{name} Error: {diff}")
-        if raise_assert:
-            assert False
-    return diff
-

 def get_configs():
     iter_params = dict(
@@ -72,13 +29,13 @@ def get_configs():
         threads=[128, 256],
         block_Q=[1, 2, 4],
     )
-    return [
-        {k: v for k, v in zip(iter_params, values)}
-        for values in itertools.product(*iter_params.values())
-    ]
+    return [{
+        k: v for k, v in zip(iter_params, values)
+    } for values in itertools.product(*iter_params.values())]


 class SupplyProg:
+
     def __init__(self):
         self.tensors_dict = {}

@@ -127,13 +84,13 @@ def mqa_attn_return_logits(

     @T.prim_func
     def mqa_attn_return_logits_kernel(
-        IndexQ: T.Tensor(index_q_shape, dtype),  # type: ignore
-        IndexK: T.Tensor(index_k_shape, dtype),  # type: ignore
-        IndexKScale: T.Tensor(index_k_scale_shape, accum_dtype),  # type: ignore
-        Logits: T.Tensor(logits_shape, accum_dtype),  # type: ignore
-        Weights: T.Tensor([seq_len, heads], accum_dtype),  # type: ignore
-        CuSeqLenKS: T.Tensor([seq_len], index_dtype),  # type: ignore
-        CuSeqLenKE: T.Tensor([seq_len], index_dtype),  # type: ignore
+            IndexQ: T.Tensor(index_q_shape, dtype),  # type: ignore
+            IndexK: T.Tensor(index_k_shape, dtype),  # type: ignore
+            IndexKScale: T.Tensor(index_k_scale_shape, accum_dtype),  # type: ignore
+            Logits: T.Tensor(logits_shape, accum_dtype),  # type: ignore
+            Weights: T.Tensor([seq_len, heads], accum_dtype),  # type: ignore
+            CuSeqLenKS: T.Tensor([seq_len], index_dtype),  # type: ignore
+            CuSeqLenKE: T.Tensor([seq_len], index_dtype),  # type: ignore
     ):
         with T.Kernel(T.ceildiv(seq_len, block_Q), threads=threads) as bx:

@@ -156,20 +113,17 @@ def mqa_attn_return_logits_kernel(
             cu_k_e_max[0] = -2147483648

             for bq_i in T.serial(block_Q):
-                cu_k_s_min[0] = T.min(
-                    cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i], seq_len_kv)
-                )
+                cu_k_s_min[0] = T.min(cu_k_s_min[0], T.min(CuSeqLenKS[seq_len_i + bq_i],
+                                                           seq_len_kv))
             for bq_i in T.serial(block_Q):
-                cu_k_e_max[0] = T.max(
-                    cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i], seq_len_kv)
-                )
+                cu_k_e_max[0] = T.max(cu_k_e_max[0], T.min(CuSeqLenKE[seq_len_i + bq_i],
+                                                           seq_len_kv))

             T.copy(IndexQ[seq_len_i * heads, 0], index_q_shared)
             T.copy(Weights[seq_len_i, 0], weights)

             for nbn_i in T.Pipelined(
-                T.ceildiv(cu_k_e_max[0] - cu_k_s_min[0], block_N), num_stages=num_stages
-            ):
+                    T.ceildiv(cu_k_e_max[0] - cu_k_s_min[0], block_N), num_stages=num_stages):
                 T.copy(IndexK[cu_k_s_min[0] + nbn_i * block_N, 0], index_k_shared)
                 T.copy(IndexKScale[cu_k_s_min[0] + nbn_i * block_N], index_k_scale_fragment)

@@ -183,16 +137,16 @@ def mqa_attn_return_logits_kernel(
                 )

                 for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads):
-                    s_reshaped[bn_i, bq_i, h_i] = (
-                        T.max(s[bn_i, bq_i * heads + h_i], 0) * weights[bq_i, h_i]
-                    ) * index_k_scale_fragment[bn_i]
+                    s_reshaped[bn_i, bq_i,
+                               h_i] = (T.max(s[bn_i, bq_i * heads + h_i], 0) *
+                                       weights[bq_i, h_i]) * index_k_scale_fragment[bn_i]

                 T.reduce_sum(s_reshaped, logits, dim=-1, clear=True)

                 for bq_i, bn_i in T.Parallel(block_Q, block_N):
                     Logits[seq_len_i + bq_i, cu_k_s_min[0] + nbn_i * block_N + bn_i] = (
-                        logits[bn_i, bq_i]
-                    )
+                        logits[bn_i, bq_i])
+
     return mqa_attn_return_logits_kernel


@@ -209,9 +163,9 @@ def clean_logits_(

     @T.prim_func
     def clean_logits_kernel(
-        Logits: T.Tensor([seq_len, seq_len_kv], dtype),  # type: ignore
-        CuSeqLenKS: T.Tensor([seq_len], indices_dtype),  # type: ignore
-        CuSeqLenKE: T.Tensor([seq_len], indices_dtype),  # type: ignore
+            Logits: T.Tensor([seq_len, seq_len_kv], dtype),  # type: ignore
+            CuSeqLenKS: T.Tensor([seq_len], indices_dtype),  # type: ignore
+            CuSeqLenKE: T.Tensor([seq_len], indices_dtype),  # type: ignore
     ):
         with T.Kernel(seq_len, threads=threads) as bx:
             tx = T.thread_binding(0, threads, thread="threadIdx.x")
@@ -229,17 +183,19 @@ def clean_logits_kernel(
     return clean_logits_kernel


-def mqa_attn_return_logits_interface(
-    q, kv, kv_scales, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits=True
-):
+def mqa_attn_return_logits_interface(q,
+                                     kv,
+                                     kv_scales,
+                                     weights,
+                                     cu_seqlen_ks,
+                                     cu_seqlen_ke,
+                                     clean_logits=True):
     seq_len, heads, index_dim = q.shape
     seq_len_kv = kv.shape[0]

     clean_logits_kernel = clean_logits_()

-    mqa_attn_return_logits_kernel = mqa_attn_return_logits(
-        heads=heads, index_dim=index_dim
-    )
+    mqa_attn_return_logits_kernel = mqa_attn_return_logits(heads=heads, index_dim=index_dim)
     logits = torch.empty([seq_len, seq_len_kv], device=q.device, dtype=torch.float32)
     mqa_attn_return_logits_kernel(
         q.view(seq_len * heads, index_dim),
@@ -273,33 +229,30 @@ def ref_fp8_mqa_logits(q: torch.Tensor, kv: torch.Tensor, weights: torch.Tensor,
     cost = mask.sum()
     return logits, cost

+
 if __name__ == "__main__":
     torch.manual_seed(0)
     S, SKV, H, HKV, D, kv_stride = 4096, 8192, 32, 1, 64, 1
-    q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16).to(
-        torch.bfloat16
-    )
-    kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16).to(
-        torch.bfloat16
-    )
+    q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16)
+    kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16)
     weights = torch.randn(S, H, device="cuda", dtype=torch.float32)
     p = (torch.randn(S, SKV, device="cuda", dtype=torch.float32) * 4).softmax(dim=-1)

-    def generate_random_cu_seqlens(
-        per_cp_seqlen, cp_size=4, cp_rank=3, kv_stride=1, average_q_len=512
-    ):
+    def generate_random_cu_seqlens(per_cp_seqlen,
+                                   cp_size=4,
+                                   cp_rank=3,
+                                   kv_stride=1,
+                                   average_q_len=512):
         total_seqlen = per_cp_seqlen * cp_size

-        cu_seqlens = torch.randint(
-            0, average_q_len * 2, (total_seqlen // average_q_len * 2,)
-        ).cuda()
+        cu_seqlens = torch.randint(0, average_q_len * 2,
+                                   (total_seqlen // average_q_len * 2,)).cuda()
         last_seq_id = torch.where(cu_seqlens.cumsum(0) >= total_seqlen)[0][0]
         cu_seqlens = cu_seqlens[:last_seq_id]

         if cu_seqlens.sum() < total_seqlen:
             cu_seqlens = torch.cat(
-                [cu_seqlens, torch.tensor([total_seqlen - cu_seqlens.sum()]).cuda()]
-            )
+                [cu_seqlens, torch.tensor([total_seqlen - cu_seqlens.sum()]).cuda()])

         total_seqlen_k = (cu_seqlens // kv_stride).sum()

@@ -328,75 +281,65 @@ def generate_random_cu_seqlens(

         assert per_cp_seqlen % 2 == 0
         per_chunk_seqlen = per_cp_seqlen // 2
-        slice_short = slice(
-            cp_rank * per_chunk_seqlen, (cp_rank + 1) * per_chunk_seqlen
-        )
+        slice_short = slice(cp_rank * per_chunk_seqlen, (cp_rank + 1) * per_chunk_seqlen)
         slice_long = slice(
             total_seqlen - (cp_rank + 1) * per_chunk_seqlen,
             total_seqlen - cp_rank * per_chunk_seqlen,
         )
-        ks = torch.cat(
-            [
-                cu_seqlens_ks_for_each_q[slice_short],
-                cu_seqlens_ks_for_each_q[slice_long],
-            ]
-        )
-        ke = torch.cat(
-            [
-                cu_seqlens_ke_for_each_q[slice_short],
-                cu_seqlens_ke_for_each_q[slice_long],
-            ]
-        )
+        ks = torch.cat([
+            cu_seqlens_ks_for_each_q[slice_short],
+            cu_seqlens_ks_for_each_q[slice_long],
+        ])
+        ke = torch.cat([
+            cu_seqlens_ke_for_each_q[slice_short],
+            cu_seqlens_ke_for_each_q[slice_long],
+        ])
         assert len(ks) == len(ke) == per_cp_seqlen
         return ks, ke

     ks, ke = generate_random_cu_seqlens(
-        per_cp_seqlen=S, cp_size=4, cp_rank=3, kv_stride=kv_stride, average_q_len=2048
-    )
+        per_cp_seqlen=S, cp_size=4, cp_rank=3, kv_stride=kv_stride, average_q_len=2048)

     logits_ref, cost_ref = ref_fp8_mqa_logits(
-        q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke
-    )
-
+        q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke)
+
     q_fp8 = q.to(torch.float8_e4m3fn)
-    kv_fp8, kv_scales = per_custom_dims_cast_to_fp8(kv, (0, ), False)
+    kv_fp8, kv_scales = per_custom_dims_cast_to_fp8(kv, (0,), False)

     logits_tl = mqa_attn_return_logits_interface(
-        q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke
-    )
-    diff = assert_similar(
-        logits_ref, logits_tl, eps=1e-14, name="logits", raise_assert=False
-    )
+        q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke)
+    diff = assert_similar(logits_ref, logits_tl, eps=1e-14, name="logits", raise_assert=False)

     original_diff = None
     for i in range(10):
         logits_tl = mqa_attn_return_logits_interface(
-            q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke
-        )
-        diff = assert_similar(
-            logits_ref, logits_tl, eps=1e-14, name="logits", raise_assert=False
-        )
+            q=q_fp8,
+            kv=kv_fp8,
+            kv_scales=kv_scales,
+            weights=weights,
+            cu_seqlen_ks=ks,
+            cu_seqlen_ke=ke)
+        diff = assert_similar(logits_ref, logits_tl, eps=1e-14, name="logits", raise_assert=False)
         if original_diff is None:
             original_diff = diff
         else:
             assert original_diff == diff

-    from tilelang.profiler import do_bench
-
+    from tilelang.profiler import do_bench

     def logits_fn():
         return mqa_attn_return_logits_interface(
-            q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke
-        )
-
-    with torch.profiler.profile(
-        activities=[torch.profiler.ProfilerActivity.CUDA]
-    ) as prof:
+            q=q_fp8,
+            kv=kv_fp8,
+            kv_scales=kv_scales,
+            weights=weights,
+            cu_seqlen_ks=ks,
+            cu_seqlen_ke=ke)
+
+    with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
         logits_fn()

-    print(
-        prof.key_averages().table(sort_by="cuda_time_total", max_name_column_width=50)
-    )
+    print(prof.key_averages().table(sort_by="cuda_time_total", max_name_column_width=50))

     logits_ms = do_bench(logits_fn, warmup=100, rep=100)
     logits_flops = 2 * cost_ref * H * D
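For reference, the `assert_similar` check used in the test above (which this change moves out of the file and imports from `utils`) scores agreement with the normalized similarity 2·sum(x·y) / sum(x·x + y·y), exactly as in the `calc_sim` helper removed in the first hunk. Below is a minimal standalone sketch of that metric, assuming the moved code in `utils` keeps the same formula; it is an illustration, not the repository's implementation.

```python
import torch

def calc_sim(x: torch.Tensor, y: torch.Tensor) -> float:
    # Normalized similarity: 1.0 when x == y, smaller as the tensors diverge.
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    if denominator == 0:
        return 1.0  # both tensors all zero: treated as identical
    return (2 * (x * y).sum() / denominator).item()

a = torch.randn(4, 8)
print(1.0 - calc_sim(a, a))         # ~0.0 for identical tensors
print(1.0 - calc_sim(a, a + 1e-3))  # small positive "diff", compared against eps
```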