@@ -31,7 +31,6 @@ def flash_attn(
3131 K_pe_shared = T .alloc_shared ([block_N , pe_dim ], dtype )
3232 O_shared = T .alloc_shared ([block_H , dim ], dtype )
3333 acc_s = T .alloc_fragment ([block_H , block_N ], accum_dtype )
34- acc_s_0 = T .alloc_fragment ([block_H , block_N ], accum_dtype )
3534 acc_s_cast = T .alloc_fragment ([block_H , block_N ], dtype )
3635 acc_o = T .alloc_fragment ([block_H , dim ], accum_dtype )
3736 scores_max = T .alloc_fragment ([block_H ], accum_dtype )
@@ -57,28 +56,27 @@ def flash_attn(
5756 for k in T .Pipelined (loop_range , num_stages = 2 ):
5857 T .copy (KV [bx , k * block_N :(k + 1 ) * block_N , cur_kv_head , :], KV_shared )
5958 T .copy (K_pe [bx , k * block_N :(k + 1 ) * block_N , cur_kv_head , :], K_pe_shared )
60- T .clear (acc_s_0 )
59+ T .clear (acc_s )
6160 T .gemm (
62- Q_shared , KV_shared , acc_s_0 , transpose_B = True , policy = T .GemmWarpPolicy .FullCol )
61+ Q_shared , KV_shared , acc_s , transpose_B = True , policy = T .GemmWarpPolicy .FullCol )
6362 T .gemm (
6463 Q_pe_shared ,
6564 K_pe_shared ,
66- acc_s_0 ,
65+ acc_s ,
6766 transpose_B = True ,
6867 policy = T .GemmWarpPolicy .FullCol )
6968 T .copy (scores_max , scores_max_prev )
7069 T .fill (scores_max , - T .infinity (accum_dtype ))
71- T .copy (acc_s_0 , S_shared )
72- T .copy (S_shared , acc_s )
7370 T .reduce_max (acc_s , scores_max , dim = 1 , clear = False )
7471 for i in T .Parallel (block_H ):
7572 scores_scale [i ] = T .exp2 (scores_max_prev [i ] * scale - scores_max [i ] * scale )
7673 for i , j in T .Parallel (block_H , block_N ):
7774 acc_s [i , j ] = T .exp2 (acc_s [i , j ] * scale - scores_max [i ] * scale )
7875 T .reduce_sum (acc_s , scores_sum , dim = 1 )
76+ T .copy (acc_s , S_shared )
77+ T .copy (S_shared , acc_s_cast )
7978 for i in T .Parallel (block_H ):
8079 logsum [i ] = logsum [i ] * scores_scale [i ] + scores_sum [i ]
81- T .copy (acc_s , acc_s_cast )
8280 for i , j in T .Parallel (block_H , dim ):
8381 acc_o [i , j ] *= scores_scale [i ]
8482 T .gemm (acc_s_cast , KV_shared , acc_o , policy = T .GemmWarpPolicy .FullCol )
@@ -105,7 +103,6 @@ def flash_attn_split(
105103 K_pe_shared = T .alloc_shared ([block_N , pe_dim ], dtype )
106104 O_shared = T .alloc_shared ([block_H , dim ], dtype )
107105 acc_s = T .alloc_fragment ([block_H , block_N ], accum_dtype )
108- acc_s_0 = T .alloc_fragment ([block_H , block_N ], accum_dtype )
109106 acc_s_cast = T .alloc_fragment ([block_H , block_N ], dtype )
110107 acc_o = T .alloc_fragment ([block_H , dim ], accum_dtype )
111108 scores_max = T .alloc_fragment ([block_H ], accum_dtype )
@@ -131,31 +128,29 @@ def flash_attn_split(
131128 for k in T .Pipelined (loop_range , num_stages = 2 ):
132129 kv_start = (seqlen_kv // num_split ) * bz + k * block_N
133130 kv_end = (seqlen_kv // num_split ) * bz + (k + 1 ) * block_N
134-
135131 T .copy (KV [bx , kv_start :kv_end , cur_kv_head , :], KV_shared )
136132 T .copy (K_pe [bx , kv_start :kv_end , cur_kv_head , :], K_pe_shared )
137- T .clear (acc_s_0 )
133+ T .clear (acc_s )
138134 T .gemm (
139- Q_shared , KV_shared , acc_s_0 , transpose_B = True , policy = T .GemmWarpPolicy .FullCol )
135+ Q_shared , KV_shared , acc_s , transpose_B = True , policy = T .GemmWarpPolicy .FullCol )
140136 T .gemm (
141137 Q_pe_shared ,
142138 K_pe_shared ,
143- acc_s_0 ,
139+ acc_s ,
144140 transpose_B = True ,
145141 policy = T .GemmWarpPolicy .FullCol )
146142 T .copy (scores_max , scores_max_prev )
147143 T .fill (scores_max , - T .infinity (accum_dtype ))
148- T .copy (acc_s_0 , S_shared )
149- T .copy (S_shared , acc_s )
150144 T .reduce_max (acc_s , scores_max , dim = 1 , clear = False )
151145 for i in T .Parallel (block_H ):
152146 scores_scale [i ] = T .exp2 (scores_max_prev [i ] * scale - scores_max [i ] * scale )
153147 for i , j in T .Parallel (block_H , block_N ):
154148 acc_s [i , j ] = T .exp2 (acc_s [i , j ] * scale - scores_max [i ] * scale )
155149 T .reduce_sum (acc_s , scores_sum , dim = 1 )
150+ T .copy (acc_s , S_shared )
151+ T .copy (S_shared , acc_s_cast )
156152 for i in T .Parallel (block_H ):
157153 logsum [i ] = logsum [i ] * scores_scale [i ] + scores_sum [i ]
158- T .copy (acc_s , acc_s_cast )
159154 for i , j in T .Parallel (block_H , dim ):
160155 acc_o [i , j ] *= scores_scale [i ]
161156 T .gemm (acc_s_cast , KV_shared , acc_o , policy = T .GemmWarpPolicy .FullCol )
@@ -301,4 +296,4 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
301296 print ("All close" )
302297 latency = mod .do_bench (mod .func , n_warmup = 10 , n_repeat = 10 , profiler = "torch" )
303298 print ("Tile-lang: {:.2f} ms" .format (latency ))
304- print ("Tile-lang: {:.2f} TFlops" .format (total_flops / latency * 1e-9 ))
299+ print ("Tile-lang: {:.2f} TFlops" .format (total_flops / latency * 1e-9 ))
0 commit comments