Skip to content

Commit 1b61cff

Browse files
authored
[Dev] Add mha backward example (#77)
* [CI][Test] Add test cases for tilelang transform MultiVersionBuffer and WarpSpecialized * Relax the mismatch ratio restrictions in the flash_linear_attention and mha tests * [Dev] Add mha backward example
1 parent b427ec4 commit 1b61cff

File tree

2 files changed

+349
-0
lines changed

2 files changed

+349
-0
lines changed
Lines changed: 334 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,334 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
import torch
5+
import torch.nn.functional as F
6+
import tilelang
7+
from tilelang.profiler import cached
8+
from tilelang.autotuner import *
9+
import tilelang.language as T
10+
import argparse
11+
12+
13+
def flashattn_fwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
14+
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
15+
shape = [batch, seq_len, heads, dim]
16+
dtype = "float16"
17+
accum_dtype = "float"
18+
19+
@T.prim_func
20+
def flash_fwd(
21+
Q: T.Buffer(shape, dtype), # type: ignore
22+
K: T.Buffer(shape, dtype), # type: ignore
23+
V: T.Buffer(shape, dtype), # type: ignore
24+
Output: T.Buffer(shape, dtype), # type: ignore
25+
lse: T.Buffer([batch, heads, seq_len], accum_dtype), # type: ignore
26+
):
27+
with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz):
28+
Q_shared = T.alloc_shared([block_M, dim], dtype)
29+
# Q_local = T.alloc_fragment([block_M, dim], dtype)
30+
K_shared = T.alloc_shared([block_N, dim], dtype)
31+
V_shared = T.alloc_shared([block_N, dim], dtype)
32+
acc_s = T.alloc_fragment([block_M, block_N], accum_dtype)
33+
acc_s_cast = T.alloc_fragment([block_M, block_N], dtype)
34+
acc_o = T.alloc_fragment([block_M, dim], accum_dtype)
35+
scores_max = T.alloc_fragment([block_M], accum_dtype)
36+
scores_max_prev = T.alloc_fragment([block_M], accum_dtype)
37+
scores_scale = T.alloc_fragment([block_M], accum_dtype)
38+
scores_sum = T.alloc_fragment([block_M], accum_dtype)
39+
logsum = T.alloc_fragment([block_M], accum_dtype)
40+
41+
T.annotate_layout({Q_shared: tilelang.layout.make_swizzled_layout(Q_shared)})
42+
T.copy(Q[bz, bx * block_M:(bx + 1) * block_M, by, :], Q_shared)
43+
T.fill(acc_o, 0)
44+
T.fill(logsum, 0)
45+
T.fill(scores_max, -T.infinity(accum_dtype))
46+
# T.copy(Q_shared, Q_local)
47+
# for i, j in T.Parallel(block_M, dim):
48+
# Q_local[i, j] *= scale
49+
loop_range = (
50+
T.ceildiv(
51+
(bx + 1) * block_M, block_N) if is_casual else T.ceildiv(seq_len, block_N))
52+
for k in T.Pipelined(loop_range, num_stages=1):
53+
T.copy(K[bz, k * block_N:(k + 1) * block_N, by, :], K_shared)
54+
if is_casual:
55+
for i, j in T.Parallel(block_M, block_N):
56+
acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0,
57+
-T.infinity(acc_s.dtype))
58+
else:
59+
T.clear(acc_s)
60+
T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
61+
T.copy(V[bz, k * block_N:(k + 1) * block_N, by, :], V_shared)
62+
T.copy(scores_max, scores_max_prev)
63+
T.reduce_max(acc_s, scores_max, dim=1, clear=False)
64+
for i in T.Parallel(block_M):
65+
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
66+
for i, j in T.Parallel(block_M, dim):
67+
acc_o[i, j] *= scores_scale[i]
68+
for i, j in T.Parallel(block_M, block_N):
69+
acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale)
70+
T.copy(acc_s, acc_s_cast)
71+
T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
72+
T.reduce_sum(acc_s, scores_sum, dim=1)
73+
for i in T.Parallel(block_M):
74+
logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i]
75+
for i, j in T.Parallel(block_M, dim):
76+
acc_o[i, j] /= logsum[i]
77+
T.copy(acc_o, Output[bz, bx * block_M:(bx + 1) * block_M, by, :])
78+
for i in T.Parallel(block_M):
79+
logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale
80+
T.copy(logsum, lse[bz, by, bx * block_M:(bx + 1) * block_M])
81+
82+
return flash_fwd
83+
84+
85+
def flashattn_bwd_preprocess(batch, heads, seq_len, dim):
86+
dtype = "float16"
87+
accum_dtype = "float"
88+
shape = [batch, seq_len, heads, dim]
89+
blk = 32
90+
91+
@T.prim_func
92+
def flash_bwd_prep(
93+
O: T.Buffer(shape, dtype), # type: ignore
94+
dO: T.Buffer(shape, dtype), # type: ignore
95+
Delta: T.Buffer([batch, heads, seq_len], accum_dtype), # type: ignore
96+
):
97+
with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz):
98+
o = T.alloc_fragment([blk, blk], dtype)
99+
do = T.alloc_fragment([blk, blk], dtype)
100+
acc = T.alloc_fragment([blk, blk], accum_dtype)
101+
delta = T.alloc_fragment([blk], accum_dtype)
102+
T.clear(acc)
103+
for k in range(T.ceildiv(dim, blk)):
104+
T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o)
105+
T.copy(dO[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], do)
106+
for i, j in T.Parallel(blk, blk):
107+
acc[i, j] += o[i, j] * do[i, j]
108+
T.reduce_sum(acc, delta, 1)
109+
T.copy(delta, Delta[bz, bx, by * blk:(by + 1) * blk])
110+
111+
return flash_bwd_prep
112+
113+
114+
def make_dq_layout(dQ):
115+
# atomicAdd can not be vectorized, so we need to reorder dq to match the 8x8 gemm fragment
116+
return T.Layout(dQ.shape,
117+
lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2])
118+
119+
120+
def flashattn_bwd_postprocess(batch, heads, seq_len, dim):
121+
dtype = "float16"
122+
accum_dtype = "float"
123+
shape = [batch, seq_len, heads, dim]
124+
blk = 64
125+
126+
@T.prim_func
127+
def flash_bwd_post(
128+
dQ: T.Buffer(shape, accum_dtype), # type: ignore
129+
dQ_out: T.Buffer(shape, dtype), # type: ignore
130+
):
131+
with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz):
132+
T.annotate_layout({dQ: make_dq_layout(dQ)})
133+
T.copy(
134+
dQ[bz, bx * blk:(bx + 1) * blk, by, :],
135+
dQ_out[bz, bx * blk:(bx + 1) * blk, by, :],
136+
)
137+
138+
return flash_bwd_post
139+
140+
141+
def flashattn_bwd(batch, heads, seq_len, dim, is_casual, block_M, block_N):
142+
sm_scale = (1.0 / dim)**0.5
143+
scale = (1.0 / dim)**0.5 * 1.44269504 # log2(e)
144+
shape = [batch, seq_len, heads, dim]
145+
dtype = "float16"
146+
accum_dtype = "float"
147+
148+
@T.prim_func
149+
def flash_bwd(
150+
Q: T.Buffer(shape, dtype), # type: ignore
151+
K: T.Buffer(shape, dtype), # type: ignore
152+
V: T.Buffer(shape, dtype), # type: ignore
153+
dO: T.Buffer(shape, dtype), # type: ignore
154+
lse: T.Buffer([batch, heads, seq_len], accum_dtype), # type: ignore
155+
Delta: T.Buffer([batch, heads, seq_len], accum_dtype), # type: ignore
156+
dQ: T.Buffer(shape, accum_dtype), # type: ignore
157+
dK: T.Buffer(shape, dtype), # type: ignore
158+
dV: T.Buffer(shape, dtype), # type: ignore
159+
):
160+
with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=256) as (bx, by, bz):
161+
K_shared = T.alloc_shared([block_M, dim], dtype)
162+
dsT_shared = T.alloc_shared([block_M, block_N], dtype)
163+
# should not store K to local if dim is large
164+
# K_local = T.alloc_fragment([block_M, dim], dtype)
165+
# K_local_T = T.alloc_fragment([block_M, dim], dtype)
166+
# V_local = T.alloc_fragment([block_M, dim], dtype)
167+
q = T.alloc_shared([block_N, dim], dtype)
168+
V_shared = T.alloc_shared([block_M, dim], dtype)
169+
qkT = T.alloc_fragment([block_M, block_N], accum_dtype)
170+
dsT = T.alloc_fragment([block_M, block_N], accum_dtype)
171+
qkT_cast = T.alloc_fragment([block_M, block_N], dtype)
172+
dsT_cast = T.alloc_fragment([block_M, block_N], dtype)
173+
lse_shared = T.alloc_shared([block_N], accum_dtype)
174+
delta = T.alloc_shared([block_N], accum_dtype)
175+
do = T.alloc_shared([block_N, dim], dtype)
176+
dv = T.alloc_fragment([block_M, dim], accum_dtype)
177+
dk = T.alloc_fragment([block_M, dim], accum_dtype)
178+
dq = T.alloc_fragment([block_N, dim], accum_dtype)
179+
dv_shared = T.alloc_shared([block_N, dim], dtype)
180+
dk_shared = T.alloc_shared([block_N, dim], dtype)
181+
182+
T.annotate_layout({
183+
dQ: make_dq_layout(dQ),
184+
K_shared: tilelang.layout.make_swizzled_layout(K_shared),
185+
dv_shared: tilelang.layout.make_swizzled_layout(dv_shared),
186+
dk_shared: tilelang.layout.make_swizzled_layout(dk_shared),
187+
})
188+
189+
T.copy(K[bz, by * block_M:(by + 1) * block_M, bx, :], K_shared)
190+
T.copy(V[bz, by * block_M:(by + 1) * block_M, bx, :], V_shared)
191+
T.clear(dv)
192+
T.clear(dk)
193+
loop_st = T.floordiv(by * block_M, block_N) if is_casual else 0
194+
loop_ed = T.ceildiv(seq_len, block_N)
195+
for k in T.Pipelined(loop_st, loop_ed, num_stages=2):
196+
T.copy(Q[bz, k * block_N:(k + 1) * block_N, bx, :], q)
197+
T.clear(qkT)
198+
T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
199+
T.copy(lse[bz, bx, k * block_N:(k + 1) * block_N], lse_shared)
200+
for i, j in T.Parallel(block_M, block_N):
201+
qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j])
202+
if is_casual:
203+
for i, j in T.Parallel(block_M, block_N):
204+
qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j],
205+
0)
206+
T.copy(dO[bz, k * block_N:(k + 1) * block_N, bx, :], do)
207+
T.clear(dsT)
208+
T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
209+
T.copy(qkT, qkT_cast)
210+
T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow)
211+
212+
T.copy(Delta[bz, bx, k * block_N:(k + 1) * block_N], delta)
213+
214+
for i, j in T.Parallel(block_M, block_N):
215+
dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale
216+
T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow)
217+
218+
T.copy(dsT_cast, dsT_shared)
219+
T.clear(dq)
220+
T.gemm(dsT_shared, K_shared, dq, transpose_A=True)
221+
for i, j in T.Parallel(block_N, dim):
222+
if k * block_N + i < seq_len:
223+
T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j])
224+
T.copy(dv, dv_shared)
225+
T.copy(dk, dk_shared)
226+
T.copy(dv_shared, dV[bz, by * block_M:(by + 1) * block_M, bx, :])
227+
T.copy(dk_shared, dK[bz, by * block_M:(by + 1) * block_M, bx, :])
228+
229+
return flash_bwd
230+
231+
232+
class _attention(torch.autograd.Function):
233+
234+
@staticmethod
235+
def forward(ctx, q, k, v, causal):
236+
BATCH, N_CTX, H, D_HEAD = q.shape
237+
block_M = 64
238+
block_N = 64 if D_HEAD <= 128 else 32
239+
mod = cached(flashattn_fwd, [3, 4], BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N)
240+
o, lse = mod(q, k, v)
241+
ctx.save_for_backward(q, k, v, o, lse)
242+
ctx.causal = causal
243+
return o
244+
245+
@staticmethod
246+
def backward(ctx, do):
247+
q, k, v, o, lse = ctx.saved_tensors
248+
249+
def maybe_contiguous(x):
250+
if x.stride(-1) != 1:
251+
return x.contiguous()
252+
return x
253+
254+
do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)]
255+
block_M = 128
256+
block_N = 128 if D_HEAD <= 64 else 32
257+
mod_prep = cached(flashattn_bwd_preprocess, [2], BATCH, H, N_CTX, D_HEAD)
258+
mod_post = cached(flashattn_bwd_postprocess, [1], BATCH, H, N_CTX, D_HEAD)
259+
delta = mod_prep(o, do)
260+
mod = cached(flashattn_bwd, [6, 7, 8], BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M,
261+
block_N)
262+
dq, dk, dv = mod(q, k, v, do, lse, delta)
263+
dq = mod_post(dq)
264+
return dq, dk, dv, None
265+
266+
267+
attention = _attention.apply
268+
269+
270+
def ref_program(Q, K, V, is_causal):
271+
dim = Q.size(-1)
272+
scores = torch.einsum('bqhd,bkhd->bhqk', Q, K)
273+
scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype))
274+
if is_causal:
275+
seq_len = Q.size(1)
276+
mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device))
277+
mask = mask.unsqueeze(0).unsqueeze(0)
278+
scores = scores.masked_fill(mask == 0, float('-inf'))
279+
attention_weights = F.softmax(scores, dim=-1)
280+
output = torch.einsum('bhqk,bkhd->bqhd', attention_weights, V)
281+
return output
282+
283+
284+
if __name__ == "__main__":
285+
parser = argparse.ArgumentParser()
286+
parser.add_argument('--batch', type=int, default=8, help='Batch size')
287+
parser.add_argument('--h', type=int, default=32, help='Number of heads')
288+
parser.add_argument('--n_ctx', type=int, default=1024, help='Context size')
289+
parser.add_argument('--d_head', type=int, default=64, help='Head dimension')
290+
parser.add_argument('--casual', type=bool, default=False, help='Casual flag')
291+
args = parser.parse_args()
292+
BATCH, H, N_CTX, D_HEAD = args.batch, args.h, args.n_ctx, args.d_head
293+
casual = args.casual
294+
flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD
295+
total_flops = 5 * flops_per_matmul
296+
if casual:
297+
total_flops *= 0.5
298+
Q = (
299+
torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half,
300+
device="cuda").normal_().requires_grad_())
301+
K = torch.empty_like(Q).normal_().requires_grad_()
302+
V = torch.empty_like(Q).normal_().requires_grad_()
303+
dO = torch.randn_like(Q)
304+
O = attention(Q, K, V, casual)
305+
O.backward(dO, retain_graph=True)
306+
dQ, Q.grad = Q.grad.clone(), None
307+
dK, K.grad = K.grad.clone(), None
308+
dV, V.grad = V.grad.clone(), None
309+
310+
O_ref = ref_program(Q, K, V, casual)
311+
O_ref.backward(dO, retain_graph=True)
312+
dQ_ref, Q.grad = Q.grad.clone(), None
313+
dK_ref, K.grad = K.grad.clone(), None
314+
dV_ref, V.grad = V.grad.clone(), None
315+
316+
assert torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2)
317+
assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2)
318+
assert torch.allclose(dK, dK_ref, rtol=1e-2, atol=1e-2)
319+
assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2)
320+
321+
def run():
322+
O_ref.backward(dO, retain_graph=True)
323+
324+
def run1():
325+
O.backward(dO, retain_graph=True)
326+
327+
from tilelang.profiler import do_bench
328+
329+
latency = do_bench(run, warmup=500)
330+
print("torch: {:.2f} ms".format(latency))
331+
print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9))
332+
latency = do_bench(run1, warmup=500)
333+
print("tilelang: {:.2f} ms".format(latency))
334+
print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9))

tilelang/profiler/__init__.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import tvm
1111
from tvm.relay import TensorType
1212

13+
from tilelang.engine import lower
1314
from tilelang.jit.adapter import TorchDLPackKernelAdapter
1415
from tilelang.utils.tensor import (
1516
get_tensor_supply,
@@ -244,3 +245,17 @@ def do_bench(
244245
ret = ret[0]
245246
return ret
246247
return getattr(torch, return_mode)(times).item()
248+
249+
250+
_cached = {}
251+
252+
253+
def cached(func, result_idx: List[int], *args):
254+
global _cached
255+
key = (func, tuple(result_idx), *args)
256+
if key not in _cached:
257+
program = func(*args)
258+
mod, params = lower(program)
259+
mod = TorchDLPackKernelAdapter(mod, params, result_idx)
260+
_cached[key] = mod
261+
return _cached[key]

0 commit comments

Comments
 (0)