import torch
import torch.nn.functional as F
import tilelang
from tilelang.autotuner import *
import tilelang.language as T

num_split = 4


def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H):
    scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504  # log2(e)
    shape_q = [batch, heads, (dim + pe_dim)]
    shape_k = [batch, seqlen_kv, kv_head_num, (dim + pe_dim)]
    shape_v = [batch, seqlen_kv, kv_head_num, dim]
    shape_o = [batch, heads, dim]
    part_shape = [batch, heads, num_split, dim]
    dtype = "float16"
    accum_dtype = "float"
    kv_group_num = heads // kv_head_num
    VALID_BLOCK_H = min(block_H, kv_group_num)
    assert kv_head_num == 1, "kv_head_num must be 1"

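    # Split-KV flash-attention pass: each (batch, head-block, split) program scans its
    # slice of the KV cache with an online-softmax update, writing a partial output to
    # Output_partial and the per-row log-sum-exp (base 2) to glse for the combine step.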
    @T.macro
    def flash_attn_split(
            Q: T.Buffer(shape_q, dtype),
            K: T.Buffer(shape_k, dtype),
            V: T.Buffer(shape_v, dtype),
            glse: T.Buffer([batch, heads, num_split], dtype),
            Output_partial: T.Buffer(part_shape, dtype),
    ):
        with T.Kernel(
                batch, heads // min(block_H, kv_group_num), num_split, threads=128) as (bx, by, bz):
            Q_shared = T.alloc_shared([block_H, (dim + pe_dim)], dtype)
            K_shared = T.alloc_shared([block_N, (dim + pe_dim)], dtype)
            V_shared = T.alloc_shared([block_N, dim], dtype)
            O_shared = T.alloc_shared([block_H, dim], dtype)
            acc_s = T.alloc_fragment([block_H, block_N], accum_dtype)
            acc_s_cast = T.alloc_fragment([block_H, block_N], dtype)
            acc_o = T.alloc_fragment([block_H, dim], accum_dtype)
            scores_max = T.alloc_fragment([block_H], accum_dtype)
            scores_max_prev = T.alloc_fragment([block_H], accum_dtype)
            scores_scale = T.alloc_fragment([block_H], accum_dtype)
            scores_sum = T.alloc_fragment([block_H], accum_dtype)
            logsum = T.alloc_fragment([block_H], accum_dtype)

            bid = bx
            hid = by
            sid = bz
            cur_kv_head = hid // (kv_group_num // block_H)

            T.annotate_layout({
                O_shared: tilelang.layout.make_swizzled_layout(O_shared),
            })

            T.copy(Q[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, :], Q_shared)
            T.fill(acc_o, 0)
            T.fill(logsum, 0)
            T.fill(scores_max, -T.infinity(accum_dtype))

            loop_range = T.ceildiv((seqlen_kv // num_split), block_N)
            for k in T.Pipelined(loop_range, num_stages=1):
                T.copy(
                    K[bid, (seqlen_kv // num_split) * sid +
                      k * block_N:(seqlen_kv // num_split) * sid + (k + 1) * block_N,
                      cur_kv_head, :], K_shared)
                T.clear(acc_s)
                T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
                T.copy(scores_max, scores_max_prev)
                T.fill(scores_max, -T.infinity(accum_dtype))
                T.reduce_max(acc_s, scores_max, dim=1, clear=False)
                for i in T.Parallel(block_H):
                    scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                for i, j in T.Parallel(block_H, block_N):
                    acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
                T.reduce_sum(acc_s, scores_sum, dim=1)
                for i in T.Parallel(block_H):
                    logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
                T.copy(acc_s, acc_s_cast)
                for i, j in T.Parallel(block_H, dim):
                    acc_o[i, j] *= scores_scale[i]
                T.copy(
                    V[bid, (seqlen_kv // num_split) * sid +
                      k * block_N:(seqlen_kv // num_split) * sid + (k + 1) * block_N,
                      cur_kv_head, :], V_shared)
                T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
            for i, j in T.Parallel(block_H, dim):
                acc_o[i, j] /= logsum[i]
            for i in T.Parallel(block_H):
                logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale

            T.copy(logsum, glse[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H, sid])
            T.copy(acc_o, O_shared)
            T.copy(O_shared, Output_partial[bid, hid * VALID_BLOCK_H:(hid + 1) * VALID_BLOCK_H,
                                            sid, :])

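    # Combine pass: merges the num_split partial outputs by rescaling each one with
    # exp2(lse_split - lse_total), where lse_total is a max-stabilized log-sum-exp
    # over the split dimension.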
    @T.macro
    def combine(
            glse: T.Buffer([batch, heads, num_split], dtype),
            Output_partial: T.Buffer(part_shape, dtype),
            Output: T.Buffer(shape_o, dtype),
    ):
        with T.Kernel(heads, batch, threads=128) as (by, bz):
            po_local = T.alloc_fragment([dim], dtype)
            o_accum_local = T.alloc_fragment([dim], accum_dtype)
            lse_local = T.alloc_fragment([num_split, 1], dtype)
            lse_local_split = T.alloc_local([1], accum_dtype)
            lse_logsum_local = T.alloc_local([1], accum_dtype)
            lse_max_local = T.alloc_fragment([1], accum_dtype)
            scale_local = T.alloc_local([1], accum_dtype)

            T.annotate_layout({
                lse_logsum_local: T.Fragment(lse_logsum_local.shape, forward_thread_fn=lambda i: i),
            })

            T.clear(lse_logsum_local)
            T.clear(o_accum_local)
            for k in T.Parallel(num_split):
                lse_local[k, 0] = glse[bz, by, k]
            T.reduce_max(lse_local, lse_max_local, dim=0, clear=True)
            for k in T.Pipelined(num_split, num_stages=1):
                lse_local_split[0] = glse[bz, by, k]
                lse_logsum_local[0] += T.exp2(lse_local_split[0] - lse_max_local[0])
            lse_logsum_local[0] = T.log2(lse_logsum_local[0]) + lse_max_local[0]
            for k in T.serial(num_split):
                for i in T.Parallel(dim):
                    po_local[i] = Output_partial[bz, by, k, i]
                lse_local_split[0] = glse[bz, by, k]
                scale_local[0] = T.exp2(lse_local_split[0] - lse_logsum_local[0])
                for i in T.Parallel(dim):
                    o_accum_local[i] += po_local[i] * scale_local[0]
            for i in T.Parallel(dim):
                Output[bz, by, i] = o_accum_local[i]

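    # Full kernel: split attention over the KV cache, then reduce the partials.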
    @T.prim_func
    def main(
            Q: T.Buffer(shape_q, dtype),
            K: T.Buffer(shape_k, dtype),
            V: T.Buffer(shape_v, dtype),
            glse: T.Buffer([batch, heads, num_split], dtype),
            Output_partial: T.Buffer(part_shape, dtype),  # [batch, heads, num_split, dim]
            Output: T.Buffer(shape_o, dtype),
    ):
        flash_attn_split(Q, K, V, glse, Output_partial)
        combine(glse, Output_partial, Output)

    return main


def ref_program(query, key, value, glse, Output_partial):
    """
    Inputs:
    - query (Tensor): [batch, heads, dim + pe_dim]
    - key (Tensor): [batch, seqlen_kv, kv_head_num, dim + pe_dim]
    - value (Tensor): [batch, seqlen_kv, kv_head_num, dim]
    - glse, Output_partial: unused placeholders so the signature matches the kernel

    Outputs:
    - output (Tensor): [batch, heads, dim]
    """
    from einops import rearrange
    batch_size, query_heads, dim = query.shape  # here dim = head_dim + pe_dim
    _, seqlen_kv, kv_heads, _ = key.shape
    dim_v = value.shape[-1]
    assert kv_heads == 1, "kv_heads must be 1"

    query_expanded = rearrange(query, 'b h d -> b h 1 d')  # [batch_size, query_heads, 1, dim]
    key_expanded = key.expand(-1, -1, query_heads, -1)  # [batch_size, seqlen_kv, query_heads, dim]
    value_expanded = value.expand(-1, -1, query_heads, -1)  # [batch_size, seqlen_kv, query_heads, dim_v]
    key_expanded = rearrange(key_expanded, 'b n h d -> b h n d')  # [batch_size, query_heads, seqlen_kv, dim]
    value_expanded = rearrange(value_expanded, 'b n h d -> b h n d')  # [batch_size, query_heads, seqlen_kv, dim_v]

    scores = torch.matmul(query_expanded,
                          key_expanded.transpose(-1, -2))  # [batch_size, query_heads, 1, seqlen_kv]
    scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
    attention_weights = F.softmax(scores, dim=-1)  # [batch_size, query_heads, 1, seqlen_kv]
    output = torch.matmul(attention_weights, value_expanded)  # [batch_size, query_heads, 1, dim_v]
    return output.view(batch_size, query_heads, dim_v)


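# PyTorch reference for the split pass: returns the per-split log-sum-exp values and
# partial outputs in the same layout as glse and Output_partial.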
def flash_split_ref(Q, K, V):
    dim = 512
    pe_dim = 64
    batch = Q.size(0)
    nheads = Q.size(1)
    assert Q.size(2) == dim + pe_dim, "dim must be 576=512+64"
    block_N = 32
    seqlen_kv = K.size(1)

    scale = (1.0 / (dim + pe_dim))**0.5 * 1.44269504  # log2(e)
    acc_s = torch.empty((batch, nheads, block_N), device="cuda", dtype=torch.float)
    acc_s_cast = torch.empty((batch, nheads, block_N), device="cuda", dtype=torch.float16)
    acc_o = torch.empty((batch, nheads, dim), device="cuda", dtype=torch.float)
    scores_max = torch.empty((batch, nheads), device="cuda", dtype=torch.float)
    scores_max_prev = torch.empty((batch, nheads), device="cuda", dtype=torch.float)
    scores_scale = torch.empty((batch, nheads), device="cuda", dtype=torch.float)
    scores_sum = torch.empty((batch, nheads), device="cuda", dtype=torch.float)
    logsum = torch.empty((batch, nheads), device="cuda", dtype=torch.float)
    gacc_o = torch.empty((num_split, batch, nheads, dim), device="cuda", dtype=torch.float)
    glogsum = torch.empty((num_split, batch, nheads), device="cuda", dtype=torch.float)

    Q_ = Q * scale
    K_ = K.expand(-1, -1, nheads, -1)
    V_ = V.expand(-1, -1, nheads, -1)

    for ks in range(num_split):
        acc_o.fill_(0)
        logsum.fill_(0)
        scores_max.fill_(float('-inf'))
        scores_max_prev.fill_(float('-inf'))
        for i in range(int((seqlen_kv // num_split) / block_N)):
            acc_s.fill_(0)
            acc_s = torch.einsum('bhd,bkhd->bhk', Q_,
                                 K_[:, (seqlen_kv // num_split) * ks +
                                    i * block_N:(seqlen_kv // num_split) * ks +
                                    (i + 1) * block_N, :, :])  # [batch, nheads, block_N]
            scores_max_prev = scores_max
            scores_max = acc_s.max(dim=-1, keepdim=False).values  # [batch, nheads]
            scores_scale = torch.exp2(scores_max_prev - scores_max)  # [batch, nheads]
            acc_o *= scores_scale[:, :, None]
            acc_s = torch.exp2(acc_s - scores_max[:, :, None])
            acc_s_cast = acc_s.to(torch.float16)  # [batch, nheads, block_N]
            acc_o += torch.einsum(
                'bhk,bkhd->bhd', acc_s_cast,
                V_[:, (seqlen_kv // num_split) * ks + i * block_N:(seqlen_kv // num_split) * ks +
                   (i + 1) * block_N, :, :])
            scores_sum = acc_s.sum(dim=-1, keepdim=False)
            logsum = logsum * scores_scale + scores_sum
        acc_o /= logsum[:, :, None]
        logsum = torch.log2(logsum) + scores_max
        gacc_o[ks, :, :, :] = acc_o
        glogsum[ks, :, :] = logsum

    return glogsum.to(torch.float16).permute(1, 2, 0), gacc_o.to(torch.float16).permute(1, 2, 0, 3)


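# PyTorch reference for the combine pass: rescales each partial output by
# exp2(lse_split - lse_total) and sums over the splits.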
def reduce_ref(Q, K, V, glse, Output_partial):
    o = torch.empty_like(Output_partial[:, :, 0, :]).fill_(0)
    lse_logsum = torch.empty_like(glse[:, :, 0]).fill_(0)
    lse_max = glse.max(dim=2, keepdim=False).values
    for ks in range(num_split):
        lse = glse[:, :, ks]
        lse_logsum += torch.exp2(lse - lse_max)
    lse_logsum = torch.log2(lse_logsum) + lse_max
    for ks in range(num_split):
        lse = glse[:, :, ks]
        scale = torch.exp2(lse - lse_logsum)
        o += Output_partial[:, :, ks, :] * scale[:, :, None]
    return o.to(torch.float16)


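# Benchmark: batch 64, 128 query heads sharing a single KV head, 8192 KV tokens,
# and a head dimension of 512 plus a 64-dim positional component.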
if __name__ == "__main__":
    BATCH, H_Q, KV_H, KV_CTX, D_HEAD, DPE = 64, 128, 1, 8192, 512, 64
    qk_flops = 2 * BATCH * H_Q * KV_CTX * (D_HEAD + DPE)
    pv_flops = 2 * BATCH * H_Q * KV_CTX * D_HEAD
    total_flops = qk_flops + pv_flops
    BLOCK_N = 32  # if D_HEAD <= 128 else 32
    BLOCK_H = 64

    program = flashattn(BATCH, H_Q, KV_H, KV_CTX, D_HEAD, DPE, BLOCK_N, BLOCK_H)
    mod, params = tilelang.lower(program)
    mod = tilelang.Profiler(mod, params, [5], tilelang.TensorSupplyType.Normal)
    mod.assert_allclose(ref_program, rtol=0.01, atol=0.01)
    latency = mod.do_bench(mod.func, warmup=500)
    print("Tile-lang: {:.2f} ms".format(latency))
    print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))