@@ -1,5 +1,5 @@
 import torch
-import tilelang as tl
+import tilelang
 import tilelang.language as T
 from tilelang.profiler import do_bench
 import argparse
@@ -9,10 +9,11 @@
 from typing import Optional, Tuple


-@tl.jit(pass_configs={
-    tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
-    tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-})
+@tilelang.jit(
+    pass_configs={
+        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+    })
 def tl_fused_chunk_bwd_kernel(
         B,
         S,
@@ -30,12 +31,12 @@ def tl_fused_chunk_bwd_kernel(
     chunk_size = 64
     BK = BV = 64  # Set to 128 can be faster, but has some numerical differences with FLA
     assert S % chunk_size == 0 and DK % BK == 0 and DV % BV == 0
-    NK = tl.cdiv(DK, BK)
-    NV = tl.cdiv(DV, BV)
-    NT = tl.cdiv(S, chunk_size)
+    NK = tilelang.cdiv(DK, BK)
+    NV = tilelang.cdiv(DV, BV)
+    NT = tilelang.cdiv(S, chunk_size)

     @T.prim_func
-    def chunk_linear_attn_bwd(
+    def fused_chunk_linear_attn_bwd(
             Q: T.Tensor([B, S, H, DK], dtype),  # type: ignore
             K: T.Tensor([B, S, H, DK], dtype),  # type: ignore
             V: T.Tensor([B, S, H, DV], dtype),  # type: ignore
@@ -64,18 +65,19 @@ def chunk_linear_attn_bwd(
             h_shared = T.alloc_shared([BV, BK], dtype)
             dh = T.alloc_fragment([BK, BV], accum_dtype)
             dh_shared = T.alloc_shared([BK, BV], dtype)
-            T.clear(h)
-            T.clear(dh)

             T.annotate_layout({
-                dq_shared: tl.layout.make_swizzled_layout(dq_shared),
-                dk_shared: tl.layout.make_swizzled_layout(dk_shared),
-                dv_shared: tl.layout.make_swizzled_layout(dv_shared)
+                dq_shared: tilelang.layout.make_swizzled_layout(dq_shared),
+                dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
+                dv_shared: tilelang.layout.make_swizzled_layout(dv_shared)
             })
             T.use_swizzle(10)

+            T.clear(h)
+            T.clear(dh)
+
             # Calculate dQ
-            for i in T.Pipelined(0, NT, num_stages=1):
+            for i in T.Pipelined(0, NT):
                 T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k)
                 T.copy(V[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v)
                 T.copy(dO[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV],
@@ -97,7 +99,7 @@ def chunk_linear_attn_bwd(
                        dq_shared)

             # Calculate dK, dV (reversely)
-            for i in T.Pipelined(1, NT + 1, num_stages=1):
+            for i in T.Pipelined(1, NT + 1):
                 start = NT - i
                 for row, col in T.Parallel(chunk_size, BK):
                     q[row, col] = Q[i_b, start * chunk_size + row, i_h, i_k * BK + col] * scale
@@ -139,9 +141,8 @@ def chunk_linear_attn_bwd(
                 T.atomic_add(
                     dV[i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
                        i_v * BV:(i_v + 1) * BV], dv_shared)
-                # TODO: consider using vectorized atomic add or tma reduce for sm90

-    return chunk_linear_attn_bwd
+    return fused_chunk_linear_attn_bwd


 def tl_fused_chunk_bwd(Q, K, V, dO):
@@ -188,6 +189,7 @@ def main(B=1, S=1024, H=16, D=128):
     k = l2norm_fwd(k)[0].requires_grad_(True)

     dq, dk, dv = tl_fused_chunk_bwd(q, k, v, do)
+    q.grad = k.grad = v.grad = None
     o_ref, _ = ref_program(q, k, v)
     o_ref.backward(do, retain_graph=True)

@@ -202,9 +204,8 @@ def main(B=1, S=1024, H=16, D=128):
     # Benchmark
     q.grad = k.grad = v.grad = None
     o_ref, _ = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)
-    t1 = do_bench(
-        lambda: o_ref.backward(do, retain_graph=True), warmup=25, rep=100, backend='cupti')
-    t2 = do_bench(lambda: tl_fused_chunk_bwd(q, k, v, do), warmup=25, rep=100, backend='cupti')
+    t1 = do_bench(lambda: o_ref.backward(do, retain_graph=True), backend='cupti')
+    t2 = do_bench(lambda: tl_fused_chunk_bwd(q, k, v, do), backend='cupti')
     print(f'Triton latency: {t1:.3f} ms')
     print(f'TileLang latency: {t2:.3f} ms')
     print(f'Speedup: {t1 / t2:.3f}x')
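
Context for the added `q.grad = k.grad = v.grad = None` line before the reference backward pass: PyTorch accumulates into `.grad` across `backward()` calls, so clearing the gradients keeps the correctness check and the benchmark from comparing against stale, accumulated values. A minimal standalone sketch of that behavior (not part of this change, for illustration only):

    import torch

    # Gradients accumulate across backward() calls unless cleared.
    x = torch.ones(3, requires_grad=True)
    (x * 2).sum().backward()
    print(x.grad)   # tensor([2., 2., 2.])
    (x * 2).sum().backward()
    print(x.grad)   # tensor([4., 4., 4.])  -- accumulated, not overwritten

    x.grad = None   # reset, as the updated example does before the reference pass
    (x * 2).sum().backward()
    print(x.grad)   # tensor([2., 2., 2.])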