
Commit 5ded38c

chore
1 parent b6fe638 commit 5ded38c

File tree: 2 files changed (+36, -35 lines)


examples/linear_attention/example_linear_attn_bwd.py

Lines changed: 22 additions & 21 deletions
@@ -1,5 +1,5 @@
 import torch
-import tilelang as tl
+import tilelang
 import tilelang.language as T
 from tilelang.profiler import do_bench
 import argparse
@@ -9,10 +9,11 @@
 from typing import Optional, Tuple
 
 
-@tl.jit(pass_configs={
-    tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
-    tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-})
+@tilelang.jit(
+    pass_configs={
+        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+    })
 def tl_fused_chunk_bwd_kernel(
         B,
         S,
@@ -30,12 +31,12 @@ def tl_fused_chunk_bwd_kernel(
     chunk_size = 64
     BK = BV = 64  # Set to 128 can be faster, but has some numerical differences with FLA
     assert S % chunk_size == 0 and DK % BK == 0 and DV % BV == 0
-    NK = tl.cdiv(DK, BK)
-    NV = tl.cdiv(DV, BV)
-    NT = tl.cdiv(S, chunk_size)
+    NK = tilelang.cdiv(DK, BK)
+    NV = tilelang.cdiv(DV, BV)
+    NT = tilelang.cdiv(S, chunk_size)
 
     @T.prim_func
-    def chunk_linear_attn_bwd(
+    def fused_chunk_linear_attn_bwd(
             Q: T.Tensor([B, S, H, DK], dtype),  # type: ignore
             K: T.Tensor([B, S, H, DK], dtype),  # type: ignore
             V: T.Tensor([B, S, H, DV], dtype),  # type: ignore
@@ -64,18 +65,19 @@ def chunk_linear_attn_bwd(
         h_shared = T.alloc_shared([BV, BK], dtype)
         dh = T.alloc_fragment([BK, BV], accum_dtype)
         dh_shared = T.alloc_shared([BK, BV], dtype)
-        T.clear(h)
-        T.clear(dh)
 
         T.annotate_layout({
-            dq_shared: tl.layout.make_swizzled_layout(dq_shared),
-            dk_shared: tl.layout.make_swizzled_layout(dk_shared),
-            dv_shared: tl.layout.make_swizzled_layout(dv_shared)
+            dq_shared: tilelang.layout.make_swizzled_layout(dq_shared),
+            dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
+            dv_shared: tilelang.layout.make_swizzled_layout(dv_shared)
         })
         T.use_swizzle(10)
 
+        T.clear(h)
+        T.clear(dh)
+
         # Calculate dQ
-        for i in T.Pipelined(0, NT, num_stages=1):
+        for i in T.Pipelined(0, NT):
             T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k)
             T.copy(V[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV], v)
             T.copy(dO[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_v * BV:(i_v + 1) * BV],
@@ -97,7 +99,7 @@ def chunk_linear_attn_bwd(
                    dq_shared)
 
         # Calculate dK, dV (reversely)
-        for i in T.Pipelined(1, NT + 1, num_stages=1):
+        for i in T.Pipelined(1, NT + 1):
             start = NT - i
             for row, col in T.Parallel(chunk_size, BK):
                 q[row, col] = Q[i_b, start * chunk_size + row, i_h, i_k * BK + col] * scale
@@ -139,9 +141,8 @@ def chunk_linear_attn_bwd(
             T.atomic_add(
                 dV[i_b, start * chunk_size:(start + 1) * chunk_size, i_h,
                    i_v * BV:(i_v + 1) * BV], dv_shared)
-            #TODO: consider using vectorized atomic add or tma reduce for sm90
 
-    return chunk_linear_attn_bwd
+    return fused_chunk_linear_attn_bwd
 
 
 def tl_fused_chunk_bwd(Q, K, V, dO):
@@ -188,6 +189,7 @@ def main(B=1, S=1024, H=16, D=128):
     k = l2norm_fwd(k)[0].requires_grad_(True)
 
     dq, dk, dv = tl_fused_chunk_bwd(q, k, v, do)
+    q.grad = k.grad = v.grad = None
     o_ref, _ = ref_program(q, k, v)
     o_ref.backward(do, retain_graph=True)
 
@@ -202,9 +204,8 @@ def main(B=1, S=1024, H=16, D=128):
     # Benchmark
     q.grad = k.grad = v.grad = None
     o_ref, _ = fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False)
-    t1 = do_bench(
-        lambda: o_ref.backward(do, retain_graph=True), warmup=25, rep=100, backend='cupti')
-    t2 = do_bench(lambda: tl_fused_chunk_bwd(q, k, v, do), warmup=25, rep=100, backend='cupti')
+    t1 = do_bench(lambda: o_ref.backward(do, retain_graph=True), backend='cupti')
+    t2 = do_bench(lambda: tl_fused_chunk_bwd(q, k, v, do), backend='cupti')
     print(f'Triton latency: {t1:.3f} ms')
     print(f'TileLang latency: {t2:.3f} ms')
     print(f'Speedup: {t1/t2:.3f}x')
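
Note on the one behavioural addition in this file: the `q.grad = k.grad = v.grad = None` reset inserted before the reference backward pass. PyTorch accumulates into `.grad` rather than overwriting it, so any gradients left over from earlier work would skew the comparison against the TileLang results. A minimal, self-contained PyTorch sketch of that accumulation behaviour (stand-in tensors, not the example's q/k/v):

import torch

x = torch.ones(3, requires_grad=True)

(x * 2).sum().backward()
print(x.grad)   # tensor([2., 2., 2.])

(x * 2).sum().backward()
print(x.grad)   # tensor([4., 4., 4.])  -- accumulated, not replaced

x.grad = None   # same reset the example now performs before the reference backward
(x * 2).sum().backward()
print(x.grad)   # tensor([2., 2., 2.])  -- clean gradients for the comparison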

examples/linear_attention/example_linear_attn_fwd.py

Lines changed: 14 additions & 14 deletions
@@ -1,5 +1,5 @@
 import torch
-import tilelang as tl
+import tilelang
 import tilelang.language as T
 from tilelang.profiler import do_bench
 import argparse
@@ -9,11 +9,11 @@
 from typing import Optional, Tuple
 
 
-@tl.jit(
+@tilelang.jit(
     out_idx=[4],
     pass_configs={
-        tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
-        tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
+        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
+        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
     })
 def tl_fused_chunk_fwd_kernel(
         B,
@@ -32,12 +32,12 @@ def tl_fused_chunk_fwd_kernel(
     chunk_size = 64
     BK = BV = 64  # Set to 128 can be faster, but has some numerical differences with FLA
     assert S % chunk_size == 0 and DK % BK == 0 and DV % BV == 0
-    NK = tl.cdiv(DK, BK)
-    NV = tl.cdiv(DV, BV)
-    NT = tl.cdiv(S, chunk_size)
+    NK = tilelang.cdiv(DK, BK)
+    NV = tilelang.cdiv(DV, BV)
+    NT = tilelang.cdiv(S, chunk_size)
 
     @T.prim_func
-    def chunk_linear_attn_fwd(
+    def fused_chunk_linear_attn_fwd(
             Q: T.Tensor([B, S, H, DK], dtype),  # type: ignore
             K: T.Tensor([B, S, H, DK], dtype),  # type: ignore
             V: T.Tensor([B, S, H, DV], dtype),  # type: ignore
@@ -56,11 +56,13 @@ def chunk_linear_attn_fwd(
         s_shared = T.alloc_shared([chunk_size, chunk_size], dtype)
         o = T.alloc_fragment([chunk_size, BV], accum_dtype)
         o_shared = T.alloc_shared([chunk_size, BV], accum_dtype)
-        T.clear(h)
 
+        T.annotate_layout({o_shared: tilelang.layout.make_swizzled_layout(o_shared)})
         T.use_swizzle(10)
 
-        for i in T.Pipelined(0, NT, num_stages=1):
+        T.clear(h)
+
+        for i in T.Pipelined(0, NT):
             for row, col in T.Parallel(chunk_size, BK):
                 q[row, col] = Q[i_b, i * chunk_size + row, i_h, i_k * BK + col] * scale
             T.copy(K[i_b, i * chunk_size:(i + 1) * chunk_size, i_h, i_k * BK:(i_k + 1) * BK], k)
@@ -83,7 +85,7 @@ def chunk_linear_attn_fwd(
         # Output final state
         T.copy(h, final_state[i_b, i_h, i_k * BK:(i_k + 1) * BK, i_v * BV:(i_v + 1) * BV])
 
-    return chunk_linear_attn_fwd
+    return fused_chunk_linear_attn_fwd
 
 
 def tl_fused_chunk_fwd(q, k, v):
@@ -135,10 +137,8 @@ def main(B=1, S=512, H=16, D=128):
 
     t1 = do_bench(
         lambda: fused_chunk_linear_attn(q, k, v, output_final_state=True, normalize=False),
-        warmup=25,
-        rep=100,
         backend='cupti')
-    t2 = do_bench(lambda: tl_fused_chunk_fwd(q, k, v), warmup=25, rep=100, backend='cupti')
+    t2 = do_bench(lambda: tl_fused_chunk_fwd(q, k, v), backend='cupti')
     print(f'Triton latency: {t1:.3f} ms')
     print(f'TileLang latency: {t2:.3f} ms')
     print(f'Speedup: {t1/t2:.3f}x')
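
Both examples also trim the benchmark calls: the explicit `warmup=25, rep=100` arguments are dropped in favour of `do_bench`'s defaults, with only `backend='cupti'` pinned. A minimal sketch of the trimmed call shape, timing a stand-in matmul rather than the example's kernels (assumes a CUDA device and the CUPTI backend being available):

import torch
from tilelang.profiler import do_bench

a = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)
b = torch.randn(1024, 1024, device='cuda', dtype=torch.float16)

# warmup/rep are left at do_bench's defaults; only the measurement backend is pinned.
t = do_bench(lambda: a @ b, backend='cupti')
print(f'latency: {t:.3f} ms')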
