import tilelang
from tilelang.autotuner import *
import tilelang.language as T
+from einops import rearrange, einsum

-num_split = 4
+num_split = 1


def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H):
    scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504  # log2(e)
-    shape_q = [batch, heads, (dim + pe_dim)]
-    shape_k = [batch, seqlen_kv, kv_head_num, (dim + pe_dim)]
-    shape_v = [batch, seqlen_kv, kv_head_num, dim]
-    shape_o = [batch, heads, dim]
-    part_shape = [batch, heads, num_split, dim]
    dtype = "float16"
    accum_dtype = "float"
    kv_group_num = heads // kv_head_num
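    # Note: the 1.44269504 factor above is log2(e); folding it into `scale` lets the kernel
    # use exp2() instead of exp() in the online softmax, since exp(x) == exp2(x * log2(e)).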
@@ -22,19 +18,23 @@ def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_

    @T.macro
    def flash_attn_split(
-            Q: T.Buffer(shape_q, dtype),
-            K: T.Buffer(shape_k, dtype),
-            V: T.Buffer(shape_v, dtype),
+            Q: T.Buffer([batch, heads, dim], dtype),
+            Q_pe: T.Buffer([batch, heads, pe_dim], dtype),
+            KV: T.Buffer([batch, seqlen_kv, kv_head_num, dim], dtype),
+            K_pe: T.Buffer([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
            glse: T.Buffer([batch, heads, num_split], dtype),
-            Output_partial: T.Buffer(part_shape, dtype),
+            Output_partial: T.Buffer([batch, heads, num_split, dim], dtype),
    ):
        with T.Kernel(
-                batch, heads // min(block_H, kv_group_num), num_split, threads=128) as (bx, by, bz):
-            Q_shared = T.alloc_shared([block_H, (dim + pe_dim)], dtype)
-            K_shared = T.alloc_shared([block_N, (dim + pe_dim)], dtype)
-            V_shared = T.alloc_shared([block_N, dim], dtype)
+                batch, heads // min(block_H, kv_group_num), num_split, threads=256) as (bx, by, bz):
+            Q_shared = T.alloc_shared([block_H, dim], dtype)
+            S_shared = T.alloc_shared([block_H, block_N], dtype)
+            Q_pe_shared = T.alloc_shared([block_H, pe_dim], dtype)
+            KV_shared = T.alloc_shared([block_N, dim], dtype)
+            K_pe_shared = T.alloc_shared([block_N, pe_dim], dtype)
            O_shared = T.alloc_shared([block_H, dim], dtype)
            acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
+            acc_s_0 = T.alloc_fragment([block_H, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
            acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_H], accum_dtype)
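            # MLA note: KV_shared holds the shared latent cache tile and is consumed by both the
            # Q·K^T GEMM and the P·V GEMM below; K_pe_shared carries only the positional (RoPE)
            # part, which contributes to the scores but not to the value path.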
@@ -53,20 +53,32 @@ def flash_attn_split(
            })

            T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared)
+            T.copy(Q_pe[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_pe_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
-            for k in T.Pipelined(loop_range, num_stages=1):
+            for k in T.Pipelined(loop_range, num_stages=2):
+                kv_start = (seqlen_kv // num_split) * sid + k * block_N
+                kv_end = (seqlen_kv // num_split) * sid + (k + 1) * block_N
+
                T.copy(
-                    K[bid, (seqlen_kv // num_split) * sid +
-                      k * block_N:(seqlen_kv // num_split) * sid + (k + 1) * block_N,
-                      cur_kv_head, :], K_shared)
-                T.clear(acc_s)
-                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
+                    KV[bid, kv_start:kv_end, cur_kv_head, :],
+                    KV_shared
+                )
+                T.copy(
+                    K_pe[bid, kv_start:kv_end, cur_kv_head, :],
+                    K_pe_shared
+                )
+
+                T.clear(acc_s_0)
+                T.gemm(Q_shared, KV_shared, acc_s_0, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
+                T.gemm(Q_pe_shared, K_pe_shared, acc_s_0, transpose_B=True, policy=T.GemmWarpPolicy.FullCol)
                T.copy(scores_max, scores_max_prev)
                T.fill(scores_max, -T.infinity(accum_dtype))
+                T.copy(acc_s_0, S_shared)
+                T.copy(S_shared, acc_s)
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_H):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
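                # Online-softmax rescaling: scores_scale = exp2((m_prev - m_new) * scale) is the
                # factor used below to rescale the running accumulator when the row maximum changes.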
@@ -78,11 +90,7 @@ def flash_attn_split(
                T.copy(acc_s, acc_s_cast)
                for i, j in T.Parallel(block_H, dim):
                    acc_o[i, j] *= scores_scale[i]
-                T.copy(
-                    V[bid, (seqlen_kv // num_split) * sid +
-                      k * block_N:(seqlen_kv // num_split) * sid + (k + 1) * block_N,
-                      cur_kv_head, :], V_shared)
-                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
+                T.gemm(acc_s_cast, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
            for i, j in T.Parallel(block_H, dim):
                acc_o[i, j] /= logsum[i]
            for i in T.Parallel(block_H):
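            # Per-split epilogue: acc_o is normalized by the running logsum, and the per-split
            # softmax statistics written to glse let the combine step merge the partial outputs.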
@@ -96,8 +104,8 @@ def flash_attn_split(
    @T.macro
    def combine(
            glse: T.Buffer([batch, heads, num_split], dtype),
-            Output_partial: T.Buffer(part_shape, dtype),
-            Output: T.Buffer(shape_o, dtype),
+            Output_partial: T.Buffer([batch, heads, num_split, dim], dtype),
+            Output: T.Buffer([batch, heads, dim], dtype),
    ):
        with T.Kernel(heads, batch, threads=128) as (by, bz):
            po_local = T.alloc_fragment([dim], dtype)
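            # combine() performs the flash-decoding style reduction: each of the num_split partial
            # outputs is reweighted using the per-split statistics stored in glse and summed into
            # the final Output.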
@@ -133,50 +141,63 @@ def combine(

    @T.prim_func
    def main(
-            Q: T.Buffer(shape_q, dtype),
-            K: T.Buffer(shape_k, dtype),
-            V: T.Buffer(shape_v, dtype),
+            Q: T.Buffer([batch, heads, dim], dtype),
+            Q_pe: T.Buffer([batch, heads, pe_dim], dtype),
+            KV: T.Buffer([batch, seqlen_kv, kv_head_num, dim], dtype),
+            K_pe: T.Buffer([batch, seqlen_kv, kv_head_num, pe_dim], dtype),
            glse: T.Buffer([batch, heads, num_split], dtype),
-            Output_partial: T.Buffer(part_shape, dtype),  # [batch, heads, num_split, dim]
-            Output: T.Buffer(shape_o, dtype),
+            Output_partial: T.Buffer([batch, heads, num_split, dim], dtype),
+            Output: T.Buffer([batch, heads, dim], dtype),
    ):
-        flash_attn_split(Q, K, V, glse, Output_partial)
+        flash_attn_split(Q, Q_pe, KV, K_pe, glse, Output_partial)
        combine(glse, Output_partial, Output)

    return main


-def ref_program(query, key, value, glse, Output_partial):
+
+def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
    # """
    # Inputs:
-    #   - query (Tensor): [batch, heads, dim]
-    #   - key (Tensor): [batch, seqlen_kv, kv_head_num, dim]
-    #   - value (Tensor): [batch, seqlen_kv, kv_head_num, dim]
-
+    #   - q (Tensor): [batch, heads, dim]
+    #   - q_pe (Tensor): [batch, heads, pe_dim]
+    #   - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim]
+    #   - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim]
+    #   - glse (Tensor): [batch, heads, num_split]
+    #   - Output_partial (Tensor): [batch, heads, num_split, dim]
    # Outputs:
    #   - output (Tensor): [batch, heads, dim]
    # """
-    from einops import rearrange
-    batch_size, query_heads, dim = query.shape  # [batch_size, query_heads, dim]
-    _, seqlen_kv, kv_heads, _ = key.shape  # [batch_size, seqlen_kv, kv_heads, kv_dim]
-    dim_v = value.shape[-1]
-    assert kv_heads == 1, "kv_heads must be 1"
-
-    query_expanded = rearrange(query, 'b h d -> b h 1 d')  # [batch_size, query_heads, 1, dim]
-    key_expanded = key.expand(-1, -1, query_heads, -1)  # [batch_size, query_heads, seqlen_kv, dim]
-    value_expanded = value.expand(-1, -1, query_heads,
-                                  -1)  # [batch_size, query_heads, seqlen_kv, dim]
-    key_expanded = rearrange(key_expanded,
-                             'b n h d -> b h n d')  # [batch_size, kv_head_num, seqlen_kv, dim]
-    value_expanded = rearrange(value_expanded,
-                               'b n h d -> b h n d')  # [batch_size, query_heads, seqlen_kv, dim]
-
-    scores = torch.matmul(query_expanded,
-                          key_expanded.transpose(-1, -2))  # [batch_size, query_heads, 1, seqlen_kv]
-    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
-    attention_weights = F.softmax(scores, dim=-1)  # [batch_size, query_heads, 1, seqlen_kv]
-    output = torch.matmul(attention_weights, value_expanded)  # [batch_size, query_heads, 1, dim]
-    return output.view(batch_size, query_heads, dim_v)
+    dim = q.shape[-1]
+    pe_dim = q_pe.shape[-1]
+    num_head_groups = q.shape[1] // kv.shape[2]
+    scale = (dim + pe_dim)**0.5
+    q = rearrange(
+        q, 'b (h g) d -> b g h d',
+        g=num_head_groups)  # [batch_size, num_head_groups, groups, dim]
+
+    q_pe = rearrange(
+        q_pe, 'b (h g) d -> b g h d',
+        g=num_head_groups)  # [batch_size, num_head_groups, groups, pe_dim]
+
+    kv = rearrange(kv, 'b n h d -> b h n d')  # [batch_size, groups, seqlen_kv, dim]
+
+    k_pe = rearrange(k_pe, 'b n h d -> b h n d')  # [batch_size, num_head_groups, groups, pe_dim]
+
+    query = torch.concat([q, q_pe], dim=-1)
+    key = torch.concat([kv, k_pe], dim=-1)
+
+    scores = einsum(
+        query, key,
+        'b g h d, b h s d -> b g h s')  # [batch_size, num_head_groups, groups, seqlen_kv]
+
+    attention = F.softmax(
+        scores / scale, dim=-1)  # [batch_size, num_head_groups, groups, seqlen_kv]
+
+    out = einsum(attention, kv,
+                 'b g h s, b h s d -> b g h d')  # [batch_size, num_head_groups, groups, dim]
+    out = rearrange(out, 'b g h d -> b (h g) d')  # [batch_size, heads, dim]
+    return out
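# Reference semantics: query heads are folded as 'b (h g) d -> b g h d' so the single latent KV
# head broadcasts over the query-head groups; scores come from concat([q, q_pe]) against
# concat([kv, k_pe]) with a 1/sqrt(dim + pe_dim) scale, and the value path reuses kv directly.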


def flash_split_ref(Q, K, V):
@@ -251,7 +272,7 @@ def reduce_ref(Q, K, V, glse, Output_partial):


if __name__ == "__main__":
-    BATCH, H_Q, KV_H, KV_CTX, D_HEAD, DPE = 64, 128, 1, 8192, 512, 64
+    BATCH, H_Q, KV_H, KV_CTX, D_HEAD, DPE = 128, 128, 1, 8192, 512, 64
    qk_flops = 2 * BATCH * H_Q * KV_CTX * (D_HEAD + DPE)
    pv_flops = 2 * BATCH * H_Q * KV_CTX * D_HEAD
    total_flops = qk_flops + pv_flops
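    # FLOP model: 2 flops per multiply-accumulate, so Q·K^T contributes 2*B*H*S*(D_HEAD + DPE)
    # and P·V contributes 2*B*H*S*D_HEAD; softmax and the combine step are not counted.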
@@ -260,8 +281,9 @@ def reduce_ref(Q, K, V, glse, Output_partial):

    program = flashattn(BATCH, H_Q, KV_H, KV_CTX, D_HEAD, DPE, BLOCK_N, BLOCK_H)
    mod, params = tilelang.lower(program)
-    mod = tilelang.Profiler(mod, params, [5], tilelang.TensorSupplyType.Normal)
+    mod = tilelang.Profiler(mod, params, [6], tilelang.TensorSupplyType.Normal)
    mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
-    latency = mod.do_bench(mod.func, warmup=500)
+    print("All close")
+    latency = mod.do_bench(mod.func, n_warmup=10, n_repeat=10, profiler="torch")
    print("Tile-lang: {:.2f} ms".format(latency))
    print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
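
# Illustrative sketch (editor addition, not part of the commit): the split-KV + log-sum-exp merge
# identity that flash_attn_split/combine rely on, shown for a single head and a single query in
# plain PyTorch. All names here are local to this sketch.
import torch


def split_lse_merge_demo(q, k, v, num_split=4):
    # q: [d], k, v: [s, d]
    s = (q @ k.T) / (q.shape[-1]**0.5)                  # attention scores, [s]
    full = torch.softmax(s, dim=-1) @ v                 # reference output, [d]
    s_chunks, v_chunks = s.chunk(num_split), v.chunk(num_split)
    lse = torch.stack([torch.logsumexp(c, dim=-1) for c in s_chunks])  # per-split LSE
    part = torch.stack([torch.softmax(c, dim=-1) @ vc
                        for c, vc in zip(s_chunks, v_chunks)])         # per-split partial outputs
    w = torch.softmax(lse, dim=0)                       # merge weights derived from the LSEs
    merged = (w[:, None] * part).sum(0)                 # weighted sum recovers full attention
    assert torch.allclose(merged, full, atol=1e-5)
    return merged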