@@ -1,19 +1,20 @@
 import torch
-import tilelang as tl
+import tilelang
 import tilelang.language as T
 from tilelang.profiler import do_bench
-
 import argparse
 from fla.ops.linear_attn import fused_chunk_linear_attn  # We compare with FLA
+from fla.modules.l2norm import l2norm_fwd
+from einops import rearrange
+from typing import Optional, Tuple
 
 
-@tl.jit(
-    out_idx=[4, 5, 6],
+@tilelang.jit(
     pass_configs={
-        "tl.disable_tma_lower": True,
-        "tl.disable_warp_specialized": True
+        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
     })
-def chunk_linear_attn_bwd_kernel(
+def tl_fused_chunk_bwd_kernel(
         B,
         S,
         H,
@@ -30,19 +31,19 @@ def chunk_linear_attn_bwd_kernel(
     chunk_size = 64
     BK = BV = 64  # Setting this to 128 can be faster, but has some numerical differences from FLA
     assert S % chunk_size == 0 and DK % BK == 0 and DV % BV == 0
-    NK = tl.cdiv(DK, BK)
-    NV = tl.cdiv(DV, BV)
-    NT = tl.cdiv(S, chunk_size)
+    NK = tilelang.cdiv(DK, BK)
+    NV = tilelang.cdiv(DV, BV)
+    NT = tilelang.cdiv(S, chunk_size)
 
     @T.prim_func
-    def chunk_linear_attn_bwd(
+    def fused_chunk_linear_attn_bwd(
             Q: T.Tensor([B, S, H, DK], dtype),  # type: ignore
             K: T.Tensor([B, S, H, DK], dtype),  # type: ignore
             V: T.Tensor([B, S, H, DV], dtype),  # type: ignore
             dO: T.Tensor([B, S, H, DV], dtype),  # type: ignore
-            dQ: T.Tensor([NV, B, S, H, DK], dtype),  # type: ignore
-            dK: T.Tensor([NV, B, S, H, DK], dtype),  # type: ignore
-            dV: T.Tensor([NK, B, S, H, DV], dtype),  # type: ignore
+            dQ: T.Tensor([B, S, H, DK], accum_dtype),  # type: ignore
+            dK: T.Tensor([B, S, H, DK], accum_dtype),  # type: ignore
+            dV: T.Tensor([B, S, H, DV], accum_dtype),  # type: ignore
     ):
         with T.Kernel(NV, NK, B * H) as (i_v, i_k, i_bh):
             i_b = i_bh // H
@@ -51,8 +52,11 @@ def chunk_linear_attn_bwd(
             ds = T.alloc_fragment([chunk_size, chunk_size], accum_dtype)
             ds_shared = T.alloc_shared([chunk_size, chunk_size], dtype)
             dq = T.alloc_fragment([chunk_size, BK], accum_dtype)
+            dq_shared = T.alloc_shared([chunk_size, BK], accum_dtype)
             dk = T.alloc_fragment([chunk_size, BK], accum_dtype)
+            dk_shared = T.alloc_shared([chunk_size, BK], accum_dtype)
             dv = T.alloc_fragment([chunk_size, BV], accum_dtype)
+            dv_shared = T.alloc_shared([chunk_size, BV], accum_dtype)
             q = T.alloc_shared([chunk_size, BK], dtype)
             k = T.alloc_shared([chunk_size, BK], dtype)
             v = T.alloc_shared([chunk_size, BV], dtype)
@@ -61,22 +65,19 @@ def chunk_linear_attn_bwd(
             h_shared = T.alloc_shared([BV, BK], dtype)
             dh = T.alloc_fragment([BK, BV], accum_dtype)
             dh_shared = T.alloc_shared([BK, BV], dtype)
-            T.clear(h)
-            T.clear(dh)
 
             T.annotate_layout({
-                ds_shared: tl.layout.make_swizzled_layout(ds_shared),
-                q: tl.layout.make_swizzled_layout(q),
-                k: tl.layout.make_swizzled_layout(k),
-                v: tl.layout.make_swizzled_layout(v),
-                do: tl.layout.make_swizzled_layout(do),
-                h_shared: tl.layout.make_swizzled_layout(h_shared),
-                dh_shared: tl.layout.make_swizzled_layout(dh_shared)
+                dq_shared: tilelang.layout.make_swizzled_layout(dq_shared),
+                dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
+                dv_shared: tilelang.layout.make_swizzled_layout(dv_shared)
             })
             T.use_swizzle(10)
 
+            T.clear(h)
+            T.clear(dh)
+
             # Calculate dQ
-            for i in T.Pipelined(0, NT, num_stages=1):
+            for i in T.Pipelined(0, NT):
                 T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k)
                 T.copy(V[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v)
                 T.copy(dO[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV],
@@ -92,12 +93,13 @@ def chunk_linear_attn_bwd(
                 T.gemm(v, k, h, transpose_A=True)
                 for row, col in T.Parallel(chunk_size, BK):
                     dq[row, col] *= scale
-                T.copy(
-                    dq, dQ[i_v, i_b, i * chunk_size:(i + 1) * chunk_size, i_h,
-                           i_k * BK:(i_k + 1) * BK])
+                T.copy(dq, dq_shared)
+                T.atomic_add(
+                    dQ[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK],
+                    dq_shared)
 
             # Calculate dK, dV (in reverse chunk order)
-            for i in T.Pipelined(1, NT + 1, num_stages=1):
+            for i in T.Pipelined(1, NT + 1):
                 start = NT - i
                 for row, col in T.Parallel(chunk_size, BK):
                     q[row, col] = Q[i_b, start * chunk_size + row, i_h, i_k * BK + col] * scale
@@ -131,53 +133,90 @@ def chunk_linear_attn_bwd(
                 # Update dh
                 T.gemm(q, do, dh, transpose_A=True)
 
-                T.copy(
-                    dk, dK[i_v, i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
-                           i_k * BK:(i_k + 1) * BK])
-                T.copy(
-                    dv, dV[i_k, i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
-                           i_v * BV:(i_v + 1) * BV])
-
-    return chunk_linear_attn_bwd
-
-
-def postprocess(dQ, dK, dV):
-    dQ = dQ[0] if dQ.size(0) == 1 else dQ.sum(0)
-    dK = dK[0] if dK.size(0) == 1 else dK.sum(0)
-    dV = dV[0] if dV.size(0) == 1 else dV.sum(0)
-    return dQ, dK, dV
-
-
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--B', type=int, default=8, help='Batch size')
-    parser.add_argument('--S', type=int, default=4096, help='Seq len')
-    parser.add_argument('--H', type=int, default=32, help='Num heads')
-    parser.add_argument('--D', type=int, default=256, help='Head dim')
-    args = parser.parse_args()
-    B, S, H, D = args.B, args.S, args.H, args.D
-
+                T.copy(dk, dk_shared)
+                T.atomic_add(
+                    dK[i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
+                       i_k * BK:(i_k + 1) * BK], dk_shared)
+                T.copy(dv, dv_shared)
+                T.atomic_add(
+                    dV[i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
+                       i_v * BV:(i_v + 1) * BV], dv_shared)
+
+    return fused_chunk_linear_attn_bwd
+
+
+def tl_fused_chunk_bwd(Q, K, V, dO):
+    B, S, H, D = Q.shape
+    kernel = tl_fused_chunk_bwd_kernel(B, S, H, D, D)
+    dQ = torch.zeros_like(Q, dtype=torch.float32)
+    dK = torch.zeros_like(K, dtype=torch.float32)
+    dV = torch.zeros_like(V, dtype=torch.float32)
+    kernel(Q, K, V, dO, dQ, dK, dV)
+    return dQ.to(torch.float16), dK.to(torch.float16), dV.to(torch.float16)
+
+
+def ref_program(q: torch.Tensor,
+                k: torch.Tensor,
+                v: torch.Tensor,
+                scale: Optional[float] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+    q, k, v = q.float(), k.float(), v.float()
+    if scale is None:
+        scale = q.shape[-1]**-0.5
+    chunk_size = 64
+    q = rearrange(q, 'b (n c) h d -> b h n c d', c=chunk_size) * scale
+    k = rearrange(k, 'b (n c) h d -> b h n c d', c=chunk_size)
+    v = rearrange(v, 'b (n c) h d -> b h n c d', c=chunk_size)
+    kv = k.transpose(-1, -2) @ v
+    kv = kv.cumsum(2)
+    h = kv[:, :, -1, :, :]
+    kv = torch.cat([torch.zeros_like(kv[:, :, :1]), kv[:, :, :-1]], dim=2)
+    inter = q @ kv
+    intra = ((q @ k.transpose(-1, -2)).masked_fill_(
+        torch.triu(torch.ones(chunk_size, chunk_size, dtype=bool, device=q.device), diagonal=1),
+        0)) @ v
+    o = inter + intra
+    return rearrange(o, 'b h n c d -> b (n c) h d'), h
+
+
+def main(B=1, S=1024, H=16, D=128):
     q = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True)
     k = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True)
     v = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16, requires_grad=True)
     do = torch.randn((B, S, H, D), device='cuda', dtype=torch.float16)
 
-    kernel = chunk_linear_attn_bwd_kernel(B, S, H, D, D)
-    dq, dk, dv = postprocess(*kernel(q, k, v, do))
-    o_ref, _ = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)
+    # QK L2 normalization is necessary for linear attention
+    q = l2norm_fwd(q)[0].requires_grad_(True)
+    k = l2norm_fwd(k)[0].requires_grad_(True)
+
+    dq, dk, dv = tl_fused_chunk_bwd(q, k, v, do)
+    q.grad = k.grad = v.grad = None
+    o_ref, _ = ref_program(q, k, v)
     o_ref.backward(do, retain_graph=True)
-    if torch.allclose(dq, q.grad) and torch.allclose(dk, k.grad) and torch.allclose(dv, v.grad):
-        print('Passed all tests!✅')
-    else:
-        print('Failed some tests!❌')
-    t1 = do_bench(lambda: o_ref.backward(do, retain_graph=True), warmup=25, rep=100)
+
+    assert torch.allclose(
+        dq, q.grad, atol=1e-2, rtol=1e-2), f'dq max err: {(dq - q.grad).abs().max()}'
+    assert torch.allclose(
+        dk, k.grad, atol=1e-2, rtol=1e-2), f'dk max err: {(dk - k.grad).abs().max()}'
+    assert torch.allclose(
+        dv, v.grad, atol=1e-2, rtol=1e-2), f'dv max err: {(dv - v.grad).abs().max()}'
+    print('Passed all tests!✅')
+
+    # Benchmark
     q.grad = k.grad = v.grad = None
     o_ref, _ = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)
-    t2 = do_bench(lambda: postprocess(*kernel(q, k, v, do)), warmup=25, rep=100)
+    t1 = do_bench(lambda: o_ref.backward(do, retain_graph=True), backend='cupti')
+    t2 = do_bench(lambda: tl_fused_chunk_bwd(q, k, v, do), backend='cupti')
     print(f'Triton latency: {t1:.3f} ms')
     print(f'TileLang latency: {t2:.3f} ms')
     print(f'Speedup: {t1 / t2:.3f}x')
 
 
 if __name__ == '__main__':
-    main()
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--B', type=int, default=8, help='Batch size')
+    parser.add_argument('--S', type=int, default=1024, help='Seq len')
+    parser.add_argument('--H', type=int, default=32, help='Num heads')
+    parser.add_argument('--D', type=int, default=128, help='Head dim')
+    args = parser.parse_args()
+
+    main(args.B, args.S, args.H, args.D)
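The main functional change in this diff is how the per-split partial gradients are combined: the old kernel wrote partials into leading-axis outputs (`dQ[NV, ...]`, `dK[NV, ...]`, `dV[NK, ...]`) and summed them on the host in `postprocess`, while the new kernel accumulates every split directly into single fp32 buffers with `T.atomic_add`. The short PyTorch sketch below is not part of the commit; the sizes and variable names are illustrative only. It shows that the two reduction schemes agree up to floating-point summation order.

```python
import torch

# Hypothetical sizes; NV plays the role of the number of value-dim splits.
NV, B, S, H, DK = 2, 1, 8, 1, 4
partials = torch.randn(NV, B, S, H, DK, dtype=torch.float32)

# Old scheme: stack per-split partial dQ blocks along a leading NV axis,
# then reduce them on the host in postprocess().
dq_postprocess = partials.sum(0)

# New scheme: every split adds its contribution into one fp32 buffer,
# mimicking what T.atomic_add does on the device.
dq_atomic = torch.zeros(B, S, H, DK, dtype=torch.float32)
for i_v in range(NV):
    dq_atomic += partials[i_v]

assert torch.allclose(dq_postprocess, dq_atomic)
```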