@@ -24,21 +24,32 @@ def attention_ref(
     dtype_og = q.dtype
     if upcast:
         q, k, v = q.float(), k.float(), v.float()
-    dim = q.shape[-1]
-    scale = (1.0 / dim)**0.5
-    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
-    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
+    b, T, Hq, D = q.shape
+    S = k.shape[1]
+    scale = (1.0 / D)**0.5
+    k = repeat(k, "b s h d -> b s (h g) d", g=Hq // k.shape[2])
+    v = repeat(v, "b s h d -> b s (h g) d", g=Hq // v.shape[2])
     scores = torch.einsum("bthd,bshd->bhts", q, k)
+    left, right = window_size
+    left = S if left is None or left < 0 else int(left)
+    right = S if right is None or right < 0 else int(right)
+    t_idx = torch.arange(T, device=scores.device)[:, None]
+    s_idx = torch.arange(S, device=scores.device)[None, :]
+    visible_ts = (s_idx >= (t_idx - left)) & (s_idx <= (t_idx + right))
+    visible_mask = visible_ts.unsqueeze(0).unsqueeze(0)
     if key_padding_mask is not None:
-        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
+        k_keep = rearrange(key_padding_mask, "b s -> b 1 1 s")
+        visible_mask = visible_mask & k_keep
+    neg_inf = torch.finfo(scores.dtype).min
     scores = scores * scale
+    scores = scores.masked_fill(~visible_mask, neg_inf)
     attention = torch.softmax(scores, dim=-1).to(v.dtype)
-
     if query_padding_mask is not None:
-        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
+        q_keep = rearrange(query_padding_mask, "b t -> b 1 t 1")
+        attention = attention.masked_fill(~q_keep, 0.0)
     output = torch.einsum("bhts,bshd->bthd", attention, v)
     if query_padding_mask is not None:
-        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
+        output = output.masked_fill(rearrange(~query_padding_mask, "b t -> b t 1 1"), 0.0)
     return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
 
 
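Note on the reference change above: instead of masking only padded keys, the updated `attention_ref` builds a `(T, S)` visibility grid from `window_size = (left, right)` (with `None` or a negative value meaning unbounded, as the new code reads it) and combines it with the key padding mask before the softmax. A minimal standalone sketch of that grid, under the same convention:

```python
import torch

def sliding_window_mask(T, S, window_size=(None, 0)):
    """True where query position t may attend to key position s."""
    left, right = window_size
    left = S if left is None or left < 0 else int(left)
    right = S if right is None or right < 0 else int(right)
    t_idx = torch.arange(T)[:, None]
    s_idx = torch.arange(S)[None, :]
    return (s_idx >= t_idx - left) & (s_idx <= t_idx + right)

# window_size=(None, 0) degenerates to a causal mask: key s is visible only if s <= t.
print(sliding_window_mask(4, 4, (None, 0)).int())
```

The switch from in-place `masked_fill_` with `float("-inf")` to `masked_fill` with `torch.finfo(scores.dtype).min` also keeps fully masked rows finite, so the softmax of an entirely masked row stays well defined instead of turning into NaN.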
@@ -91,60 +102,63 @@ def main(
             scores_sum = T.alloc_fragment([block_M], accum_dtype)
             logsum = T.alloc_fragment([block_M], accum_dtype)
 
+            T.annotate_layout({
+                O_shared: tilelang.layout.make_swizzled_layout(O_shared),
+                Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
+            })
+
             batch_idx = bz
             head_idx = by
             kv_head_idx = head_idx // groups
 
             q_start_idx = cu_seqlens_q[batch_idx]
-            k_start_idx = cu_seqlens_k[batch_idx]
-            v_start_idx = cu_seqlens_k[batch_idx]
+            kv_start_idx = cu_seqlens_k[batch_idx]
             q_end_idx = cu_seqlens_q[batch_idx + 1]
             k_end_idx = cu_seqlens_k[batch_idx + 1]
-            v_end_idx = cu_seqlens_k[batch_idx + 1]
 
             q_current_seqlen = q_end_idx - q_start_idx
-            k_current_seqlen = k_end_idx - k_start_idx
-            v_current_seqlen = v_end_idx - v_start_idx
+            kv_current_seqlen = k_end_idx - kv_start_idx
 
             T.copy(
                 Q_unpad[q_start_idx + bx * block_M:q_start_idx + (bx + 1) * block_M, head_idx, :],
                 Q_shared)
-            for i, d in T.Parallel(block_M, dim):
-                if bx * block_M + i >= q_current_seqlen:
-                    Q_shared[i, d] = 0
 
             T.fill(acc_o, 0)
             T.fill(logsum, 0)
             T.fill(scores_max, -T.infinity(accum_dtype))
 
-            loop_range = T.ceildiv(k_current_seqlen, block_N)
+            loop_range = (
+                T.min(
+                    T.ceildiv(q_current_seqlen +
+                              (bx + 1) * block_M, block_N), T.ceildiv(kv_current_seqlen, block_N))
+                if is_causal else T.ceildiv(kv_current_seqlen, block_N))
 
             for k in T.Pipelined(loop_range, num_stages=num_stages):
                 T.copy(
-                    K_unpad[k_start_idx + k * block_N:k_start_idx + (k + 1) * block_N,
+                    K_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N,
                             kv_head_idx, :], K_shared)
-                for i, d in T.Parallel(block_N, dim):
-                    if k * block_N + i >= k_current_seqlen:
-                        K_shared[i, d] = 0
 
                 if is_causal:
                     for i, j in T.Parallel(block_M, block_N):
-                        acc_s[i, j] = T.if_then_else((bx * block_M + i >= k * block_N + j) and
-                                                     (bx * block_M + i >= q_current_seqlen or
-                                                      k * block_N + j >= k_current_seqlen),
-                                                     -T.infinity(acc_s.dtype), 0)
+                        acc_s[i,
+                              j] = T.if_then_else((bx * block_M + i < k * block_N + j) or
+                                                  (bx * block_M + i >= q_current_seqlen or
+                                                   k * block_N + j >= kv_current_seqlen), -1e9, 0)
                 else:
                     for i, j in T.Parallel(block_M, block_N):
                         acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or
-                                                      k * block_N + j >= k_current_seqlen),
-                                                     -T.infinity(acc_s.dtype), 0)
+                                                      k * block_N + j >= kv_current_seqlen), -1e9,
+                                                     0)
 
                 T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
 
                 T.copy(scores_max, scores_max_prev)
                 T.fill(scores_max, -T.infinity(accum_dtype))
                 T.reduce_max(acc_s, scores_max, dim=1, clear=False)
 
+                for i in T.Parallel(block_M):
+                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
+
                 for i in T.Parallel(block_M):
                     scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                 for i, j in T.Parallel(block_M, block_N):
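The `T.max(scores_max[i], scores_max_prev[i])` loop added at the end of this hunk is the usual online-softmax guard: the running row maximum must never decrease, otherwise `scores_scale` can exceed 1 and inflate the already-accumulated output. A small PyTorch sketch of the per-tile update the kernel performs (plain `exp` instead of the kernel's `exp2` with the folded scale; names are illustrative, not the kernel's API):

```python
import torch

def online_softmax_step(acc_o, logsum, m_prev, scores, v_tile):
    # acc_o: [M, D] running (unnormalized) output, logsum: [M], m_prev: [M] running max
    # scores: [M, N] for the current key tile, v_tile: [N, D]
    m_new = torch.maximum(m_prev, scores.amax(dim=-1))  # running max never decreases
    rescale = torch.exp(m_prev - m_new)                 # <= 1, shrinks the old state
    p = torch.exp(scores - m_new[:, None])              # unnormalized probabilities
    acc_o = acc_o * rescale[:, None] + p @ v_tile
    logsum = logsum * rescale + p.sum(dim=-1)
    return acc_o, logsum, m_new  # final output is acc_o / logsum[:, None]
```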
@@ -158,11 +172,8 @@ def main(
                     acc_o[i, j] *= scores_scale[i]
 
                 T.copy(
-                    V_unpad[v_start_idx + k * block_N:v_start_idx + (k + 1) * block_N,
+                    V_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N,
                             kv_head_idx, :], V_shared)
-                for i, d in T.Parallel(block_N, dim):
-                    if k * block_N + i >= v_current_seqlen:
-                        V_shared[i, d] = 0
 
                 T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
 
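One plausible reason the boundary masking above switched from `-T.infinity(acc_s.dtype)` to a finite `-1e9` (the commit itself does not say) is the fully masked row: when every score in a row is `-inf`, the row maximum is `-inf`, `-inf - (-inf)` is NaN, and the normalization breaks. A huge negative finite sentinel keeps the arithmetic defined, and it is presumably also why the explicit zero-fill loops for the K/V tiles could be dropped: out-of-range score positions are pinned far below any real score, so their probabilities underflow to zero regardless of what the trailing rows of the shared tiles contain. The same effect is easy to see in plain PyTorch:

```python
import torch

row_inf = torch.full((4,), float("-inf"))
row_big = torch.full((4,), -1e9)

print(torch.softmax(row_inf, dim=-1))  # tensor([nan, nan, nan, nan])
print(torch.softmax(row_big, dim=-1))  # tensor([0.2500, 0.2500, 0.2500, 0.2500])
```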
@@ -191,8 +202,7 @@ def main(batch: int = 1,
 
     tilelang.testing.set_random_seed(0)
 
-    causal = False
-    if causal:
+    if is_causal:
         total_flops *= 0.5
 
     tilelang.testing.set_random_seed(0)
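With the hard-coded `causal = False` gone, the reported FLOPs now actually halve for causal runs. The formula itself is outside this hunk; a sketch of the usual convention it presumably follows (two matmuls, two FLOPs per multiply-accumulate; the function name here is hypothetical):

```python
def attention_fwd_flops(batch, q_seqlen, k_seqlen, heads, dim, is_causal):
    # Q @ K^T and P @ V, each 2 FLOPs per multiply-accumulate.
    flops = 4 * batch * q_seqlen * k_seqlen * heads * dim
    return flops * 0.5 if is_causal else flops  # causal touches only the lower triangle
```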
@@ -201,9 +211,9 @@ def main(batch: int = 1,
     device = torch.device("cuda")
 
     head_kv = heads // groups
-    q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device, requires_grad=True)
-    k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device, requires_grad=True)
-    v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device, requires_grad=True)
+    q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device)
+    k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device)
+    v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device)
 
     query_padding_mask = generate_random_padding_mask(q_seqlen, batch, device, mode="random")
     key_padding_mask = generate_random_padding_mask(k_seqlen, batch, device, mode="random")
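`generate_random_padding_mask` and the unpad helpers used around this hunk come from the flash-attn test utilities; the varlen kernel consumes the packed tensors plus cumulative sequence lengths. A minimal sketch of the `cu_seqlens` convention those helpers produce (an assumption, shown only to clarify what `cu_seqlens_q` / `cu_seqlens_k` index into):

```python
import torch
import torch.nn.functional as F

def cu_seqlens_from_mask(padding_mask: torch.Tensor) -> torch.Tensor:
    # padding_mask: [batch, seqlen] bool, True = real token.
    seqlens = padding_mask.sum(dim=-1, dtype=torch.int32)  # per-sequence lengths
    return F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))

mask = torch.tensor([[True, True, False], [True, False, False]])
print(cu_seqlens_from_mask(mask))  # tensor([0, 2, 3], dtype=torch.int32)
```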
@@ -236,10 +246,10 @@ def main(batch: int = 1,
         heads,
         dim,
         is_causal,
-        block_M=64,
-        block_N=64,
-        num_stages=1,
-        threads=128)
+        block_M=128,
+        block_N=128,
+        num_stages=2,
+        threads=256)
 
     out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)
     out = output_pad_fn(out_unpad)
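The defaults move from 64x64 tiles, one pipeline stage and 128 threads to 128x128 tiles, two stages and 256 threads, trading occupancy for arithmetic intensity and roughly tripling shared-memory pressure per block. A back-of-the-envelope estimate under assumed parameters (fp16 tiles, dim = 128, K/V multi-buffered by the pipeline; the actual allocation is decided by TileLang's lowering):

```python
def smem_estimate(block_M, block_N, dim, num_stages, bytes_per_elt=2):
    q = block_M * dim                    # Q tile
    kv = num_stages * 2 * block_N * dim  # K and V tiles, multi-buffered
    return (q + kv) * bytes_per_elt

print(smem_estimate(64, 64, 128, 1) / 1024)    # ~48 KiB (old defaults)
print(smem_estimate(128, 128, 128, 2) / 1024)  # ~160 KiB (new defaults)
```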
@@ -255,7 +265,9 @@ def main(batch: int = 1,
     torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)
     print("All checks passed.✅")
     latency = do_bench(
-        lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q))
+        lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q),
+        _n_warmup=5,
+        _n_repeat=5)
     print("Tile-lang: {:.2f} ms".format(latency))
     print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
 