Commit a9d823b

[Example] Update GQA varlen fwd (#1173)
* [Example] Update GQA varlen fwd
* fix
1 parent 298ab48 commit a9d823b

1 file changed
examples/flash_attention/example_gqa_fwd_varlen.py

Lines changed: 53 additions & 41 deletions
@@ -24,21 +24,32 @@ def attention_ref(
     dtype_og = q.dtype
     if upcast:
         q, k, v = q.float(), k.float(), v.float()
-    dim = q.shape[-1]
-    scale = (1.0 / dim)**0.5
-    k = repeat(k, "b s h d -> b s (h g) d", g=q.shape[2] // k.shape[2])
-    v = repeat(v, "b s h d -> b s (h g) d", g=q.shape[2] // v.shape[2])
+    b, T, Hq, D = q.shape
+    S = k.shape[1]
+    scale = (1.0 / D)**0.5
+    k = repeat(k, "b s h d -> b s (h g) d", g=Hq // k.shape[2])
+    v = repeat(v, "b s h d -> b s (h g) d", g=Hq // v.shape[2])
     scores = torch.einsum("bthd,bshd->bhts", q, k)
+    left, right = window_size
+    left = S if left is None or left < 0 else int(left)
+    right = S if right is None or right < 0 else int(right)
+    t_idx = torch.arange(T, device=scores.device)[:, None]
+    s_idx = torch.arange(S, device=scores.device)[None, :]
+    visible_ts = (s_idx >= (t_idx - left)) & (s_idx <= (t_idx + right))
+    visible_mask = visible_ts.unsqueeze(0).unsqueeze(0)
     if key_padding_mask is not None:
-        scores.masked_fill_(rearrange(~key_padding_mask, "b s -> b 1 1 s"), float("-inf"))
+        k_keep = rearrange(key_padding_mask, "b s -> b 1 1 s")
+        visible_mask = visible_mask & k_keep
+    neg_inf = torch.finfo(scores.dtype).min
     scores = scores * scale
+    scores = scores.masked_fill(~visible_mask, neg_inf)
     attention = torch.softmax(scores, dim=-1).to(v.dtype)
-
     if query_padding_mask is not None:
-        attention = attention.masked_fill(rearrange(~query_padding_mask, "b s -> b 1 s 1"), 0.0)
+        q_keep = rearrange(query_padding_mask, "b t -> b 1 t 1")
+        attention = attention.masked_fill(~q_keep, 0.0)
     output = torch.einsum("bhts,bshd->bthd", attention, v)
     if query_padding_mask is not None:
-        output.masked_fill_(rearrange(~query_padding_mask, "b s -> b s 1 1"), 0.0)
+        output = output.masked_fill(rearrange(~query_padding_mask, "b t -> b t 1 1"), 0.0)
     return output.to(dtype=dtype_og), attention.to(dtype=dtype_og)
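The hunk above folds window_size and key_padding_mask into one boolean visibility mask with a single masked_fill, replacing the two in-place -inf fills. A minimal standalone sketch (illustrative sizes, not from the diff) showing that the (left, right) window reduces to the usual causal mask when window_size = (None, 0):

    import torch

    T_len, S_len = 4, 4                                  # query / key lengths
    left, right = None, 0                                # window_size = (None, 0)
    left = S_len if left is None or left < 0 else int(left)
    right = S_len if right is None or right < 0 else int(right)

    t_idx = torch.arange(T_len)[:, None]
    s_idx = torch.arange(S_len)[None, :]
    visible_ts = (s_idx >= (t_idx - left)) & (s_idx <= (t_idx + right))

    # right = 0: query t sees keys s <= t only, i.e. the lower-triangular causal mask
    assert torch.equal(visible_ts, torch.ones(T_len, S_len).tril().bool())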

@@ -91,60 +102,63 @@ def main(
             scores_sum = T.alloc_fragment([block_M], accum_dtype)
             logsum = T.alloc_fragment([block_M], accum_dtype)
 
+            T.annotate_layout({
+                O_shared: tilelang.layout.make_swizzled_layout(O_shared),
+                Q_shared: tilelang.layout.make_swizzled_layout(Q_shared),
+            })
+
             batch_idx = bz
             head_idx = by
             kv_head_idx = head_idx // groups
 
             q_start_idx = cu_seqlens_q[batch_idx]
-            k_start_idx = cu_seqlens_k[batch_idx]
-            v_start_idx = cu_seqlens_k[batch_idx]
+            kv_start_idx = cu_seqlens_k[batch_idx]
             q_end_idx = cu_seqlens_q[batch_idx + 1]
             k_end_idx = cu_seqlens_k[batch_idx + 1]
-            v_end_idx = cu_seqlens_k[batch_idx + 1]
 
             q_current_seqlen = q_end_idx - q_start_idx
-            k_current_seqlen = k_end_idx - k_start_idx
-            v_current_seqlen = v_end_idx - v_start_idx
+            kv_current_seqlen = k_end_idx - kv_start_idx
 
             T.copy(
                 Q_unpad[q_start_idx + bx * block_M:q_start_idx + (bx + 1) * block_M, head_idx, :],
                 Q_shared)
-            for i, d in T.Parallel(block_M, dim):
-                if bx * block_M + i >= q_current_seqlen:
-                    Q_shared[i, d] = 0
 
             T.fill(acc_o, 0)
             T.fill(logsum, 0)
             T.fill(scores_max, -T.infinity(accum_dtype))
 
-            loop_range = T.ceildiv(k_current_seqlen, block_N)
+            loop_range = (
+                T.min(
+                    T.ceildiv(q_current_seqlen +
+                              (bx + 1) * block_M, block_N), T.ceildiv(kv_current_seqlen, block_N))
+                if is_causal else T.ceildiv(kv_current_seqlen, block_N))
 
             for k in T.Pipelined(loop_range, num_stages=num_stages):
                 T.copy(
-                    K_unpad[k_start_idx + k * block_N:k_start_idx + (k + 1) * block_N,
+                    K_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N,
                             kv_head_idx, :], K_shared)
-                for i, d in T.Parallel(block_N, dim):
-                    if k * block_N + i >= k_current_seqlen:
-                        K_shared[i, d] = 0
 
                 if is_causal:
                     for i, j in T.Parallel(block_M, block_N):
-                        acc_s[i, j] = T.if_then_else((bx * block_M + i >= k * block_N + j) and
-                                                     (bx * block_M + i >= q_current_seqlen or
-                                                      k * block_N + j >= k_current_seqlen),
-                                                     -T.infinity(acc_s.dtype), 0)
+                        acc_s[i,
+                              j] = T.if_then_else((bx * block_M + i < k * block_N + j) or
+                                                  (bx * block_M + i >= q_current_seqlen or
+                                                   k * block_N + j >= kv_current_seqlen), -1e9, 0)
                 else:
                     for i, j in T.Parallel(block_M, block_N):
                         acc_s[i, j] = T.if_then_else((bx * block_M + i >= q_current_seqlen or
-                                                      k * block_N + j >= k_current_seqlen),
-                                                     -T.infinity(acc_s.dtype), 0)
+                                                      k * block_N + j >= kv_current_seqlen), -1e9,
+                                                     0)
 
                 T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
 
                 T.copy(scores_max, scores_max_prev)
                 T.fill(scores_max, -T.infinity(accum_dtype))
                 T.reduce_max(acc_s, scores_max, dim=1, clear=False)
 
+                for i in T.Parallel(block_M):
+                    scores_max[i] = T.max(scores_max[i], scores_max_prev[i])
+
                 for i in T.Parallel(block_M):
                     scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale)
                 for i, j in T.Parallel(block_M, block_N):
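This hunk collapses the separate K/V offsets into a single kv_start_idx, since K and V share cu_seqlens_k. A standalone sketch (illustrative shapes and names, not from the diff) of how cu_seqlens indexes the flattened varlen layout that q_start_idx / kv_start_idx slice into:

    import torch

    seqlens = torch.tensor([3, 5, 2])                    # per-sequence lengths
    cu_seqlens = torch.zeros(len(seqlens) + 1, dtype=torch.int32)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)        # [0, 3, 8, 10]

    total_tokens, heads, dim = int(cu_seqlens[-1]), 2, 4
    q_unpad = torch.randn(total_tokens, heads, dim)      # flattened [tokens, H, D]

    batch_idx = 1
    start = int(cu_seqlens[batch_idx])                   # cf. q_start_idx / kv_start_idx
    end = int(cu_seqlens[batch_idx + 1])
    q_b = q_unpad[start:end]                             # tokens of sequence 1 only
    assert q_b.shape[0] == int(seqlens[batch_idx])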
@@ -158,11 +172,8 @@ def main(
                     acc_o[i, j] *= scores_scale[i]
 
                 T.copy(
-                    V_unpad[v_start_idx + k * block_N:v_start_idx + (k + 1) * block_N,
+                    V_unpad[kv_start_idx + k * block_N:kv_start_idx + (k + 1) * block_N,
                             kv_head_idx, :], V_shared)
-                for i, d in T.Parallel(block_N, dim):
-                    if k * block_N + i >= v_current_seqlen:
-                        V_shared[i, d] = 0
 
                 T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
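Together these hunks implement the streaming (online) softmax update: each block rescales the partial output acc_o by scores_scale before accumulating acc_s_cast @ V, and the newly added clamp keeps the running scores_max monotone. A single-row sketch of the same update in plain PyTorch (using exp rather than the kernel's exp2 with a folded log2(e) scale; block sizes and names are illustrative):

    import torch

    torch.manual_seed(0)
    scores_blocks = [torch.randn(8) for _ in range(4)]   # per-block scores, one query row
    v_blocks = [torch.randn(8, 3) for _ in range(4)]     # matching V tiles

    m = torch.tensor(float("-inf"))                      # running row max
    denom = torch.tensor(0.0)                            # running softmax denominator
    acc = torch.zeros(3)                                 # running (unnormalized) output row

    for s, v in zip(scores_blocks, v_blocks):
        m_new = torch.maximum(m, s.max())                # cf. the scores_max clamp
        rescale = torch.exp(m - m_new)                   # shrink old partial sums
        p = torch.exp(s - m_new)
        denom = denom * rescale + p.sum()
        acc = acc * rescale + p @ v                      # cf. acc_o rescale + gemm
        m = m_new

    out = acc / denom
    ref = torch.softmax(torch.cat(scores_blocks), dim=0) @ torch.cat(v_blocks)
    assert torch.allclose(out, ref, atol=1e-5)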

@@ -191,8 +202,7 @@ def main(batch: int = 1,
 
     tilelang.testing.set_random_seed(0)
 
-    causal = False
-    if causal:
+    if is_causal:
         total_flops *= 0.5
 
     tilelang.testing.set_random_seed(0)
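The is_causal branch halves total_flops because a causal mask keeps only T*(T+1)/2 of the T*T score entries, which approaches one half for large T. A quick check (illustrative):

    T = 4096
    kept = T * (T + 1) // 2                    # entries on or below the diagonal
    assert abs(kept / (T * T) - 0.5) < 1e-3    # hence total_flops *= 0.5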
@@ -201,9 +211,9 @@ def main(batch: int = 1,
     device = torch.device("cuda")
 
     head_kv = heads // groups
-    q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device, requires_grad=True)
-    k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device, requires_grad=True)
-    v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device, requires_grad=True)
+    q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device)
+    k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device)
+    v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device)
 
     query_padding_mask = generate_random_padding_mask(q_seqlen, batch, device, mode="random")
     key_padding_mask = generate_random_padding_mask(k_seqlen, batch, device, mode="random")
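The host-side head_kv = heads // groups mirrors the kernel's kv_head_idx = head_idx // groups: each group of `groups` consecutive query heads reads the same KV head. A small sketch with illustrative numbers:

    heads, groups = 8, 4
    head_kv = heads // groups                            # 2 KV heads
    mapping = {h: h // groups for h in range(heads)}     # query head -> KV head
    assert mapping == {0: 0, 1: 0, 2: 0, 3: 0, 4: 1, 5: 1, 6: 1, 7: 1}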
@@ -236,10 +246,10 @@ def main(batch: int = 1,
         heads,
         dim,
         is_causal,
-        block_M=64,
-        block_N=64,
-        num_stages=1,
-        threads=128)
+        block_M=128,
+        block_N=128,
+        num_stages=2,
+        threads=256)
 
     out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q)
     out = output_pad_fn(out_unpad)
@@ -255,7 +265,9 @@ def main(batch: int = 1,
     torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)
     print("All checks passed.✅")
     latency = do_bench(
-        lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q))
+        lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q),
+        _n_warmup=5,
+        _n_repeat=5)
     print("Tile-lang: {:.2f} ms".format(latency))
     print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
