
Commit 1e955be

update test
Signed-off-by: Ming Yang <minos.future@gmail.com>
1 parent ce37406 commit 1e955be

File tree

4 files changed: +30 -10 lines changed

hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp

Lines changed: 6 additions & 1 deletion
@@ -1284,8 +1284,13 @@ struct CollectiveMainloopFwdSm90 {
         auto mask_fn = [&](auto& tSrS, int n_block) { mask.template apply<false /*Seqlenk_mask*/, Is_causal, Is_local>(tSrS, m_block, n_block); };
         int const m_idx_min = !PackGQA ? m_block * kBlockM : params.qhead_per_khead_divmod.divide(m_block * kBlockM);
         // If local, blocking (window_size_right + window_size_left)
+        // when cp is not enabled, tot_seqlen_k is equal to seqlen_k, and cp_world_size is 1.
+        // cp_world_size is guaranteed to be greater than 0
         int const n_block_min_causal_local_mask =
-            std::max(n_block_min, (m_idx_min + seqlen_k - seqlen_q + params.window_size_right) / kBlockN);
+            std::max(n_block_min,
+                     (m_idx_min + seqlen_info.tot_seqlen_k - seqlen_q + params.window_size_right) /
+                         seqlen_info.cp_world_size /
+                         kBlockN);
         #pragma unroll 1
         for (; n_block >= n_block_min_causal_local_mask; --n_block) {
             fwd_step(n_block, mask_fn, cute::true_type{} /*check_inf*/);
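
For intuition, the sketch below (Python, not the kernel code; the helper name and the kBlockM/kBlockN values are illustrative) reproduces the block-index arithmetic this hunk changes: the causal/local boundary is now computed from the cp-global key length seqlen_info.tot_seqlen_k and mapped back onto the per-rank key axis by dividing by seqlen_info.cp_world_size before converting to a kBlockN-sized block index.

def n_block_min_causal_local_mask(m_block, n_block_min, seqlen_q, tot_seqlen_k,
                                  window_size_right, cp_world_size,
                                  kBlockM=128, kBlockN=128):
    # First query row covered by this m_block (non-PackGQA case).
    m_idx_min = m_block * kBlockM
    # Diagonal offset in *global* key coordinates, scaled down to the local key
    # axis (each cp rank holds roughly tot_seqlen_k / cp_world_size keys), then
    # converted to a kBlockN-sized block index.
    boundary = (m_idx_min + tot_seqlen_k - seqlen_q + window_size_right) // cp_world_size // kBlockN
    return max(n_block_min, boundary)

# With cp disabled (cp_world_size == 1 and tot_seqlen_k == seqlen_k) this reduces
# to the old expression (m_idx_min + seqlen_k - seqlen_q + window_size_right) // kBlockN.
print(n_block_min_causal_local_mask(m_block=2, n_block_min=0, seqlen_q=1024,
                                    tot_seqlen_k=4096, window_size_right=0,
                                    cp_world_size=4))  # -> 6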

hopper/seqlen.h

Lines changed: 1 addition & 1 deletion
@@ -100,7 +100,7 @@ struct SeqlenInfoQKNewK {
         , seqlen_rotary(!AppendKV || !seqlens_rotary ? seqlen_k_og + leftpad_k : seqlens_rotary[bidb])
         , cp_world_size(cp_world_size)
         , cp_rank(cp_rank)
-        , tot_seqlen_k(cp_tot_seqused_k == nullptr
+        , tot_seqlen_k(cp_tot_seqused_k == nullptr and cp_world_size <= 1
                            ? seqlen_k
                            : cp_tot_seqused_k[bidb])
     {
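
A small host-side sketch of the fallback this one-line change introduces (the Python helper is hypothetical; only the selection logic mirrors the initializer above): the local seqlen_k is used only when no cp_tot_seqused_k buffer is passed and cp is disabled; otherwise the per-batch global length is read from cp_tot_seqused_k. With cp enabled, the buffer is expected to be provided.

def resolve_tot_seqlen_k(seqlen_k, cp_world_size, cp_tot_seqused_k, bidb):
    # cp off and no explicit buffer: the global key length equals the local one.
    if cp_tot_seqused_k is None and cp_world_size <= 1:
        return seqlen_k
    # cp on (or buffer supplied): read the global key length for this batch index.
    return int(cp_tot_seqused_k[bidb])

assert resolve_tot_seqlen_k(4224, 1, None, 0) == 4224          # cp disabled
assert resolve_tot_seqlen_k(1056, 4, [4225, 4226], 1) == 4226  # cp across 4 ranks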

hopper/test_flash_attn.py

Lines changed: 17 additions & 6 deletions
@@ -121,13 +121,20 @@
         (4224, 4224),
     ],
 )
+# @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
 @pytest.mark.parametrize(
-    "cp_world_size", [4, 2, 1], # 1 means disabling cp
+    "cp_world_size,cp_rank,cp_tot_seqlen_k_offset",
+    [
+        (8,0,1),
+        (8,7,0),
+        (4,3,2),
+        (2,0,0),
+        (1,0,0), # 1 means disabling cp
+    ],
 )
-#@pytest.mark.parametrize('seqlen_q,seqlen_k', [(1, 1)])
 def test_flash_attn_output(
     seqlen_q, seqlen_k, d, causal, local, softcap, V_colmajor, deterministic, has_qv_, mha_type, dtype, test_sink,
-    cp_world_size,
+    cp_world_size, cp_rank, cp_tot_seqlen_k_offset
 ):
     if V_colmajor and (seqlen_k % 16 != 0 or dtype != torch.float8_e4m3fn):
         pytest.skip("V_colmajor requires seqlen_k to be a multiple of 16 and dtype to be float8_e4m3fn")
@@ -157,7 +164,8 @@ def test_flash_attn_output(
     s_aux = torch.randn(nheads, device=device, dtype=torch.bfloat16) * 4 if test_sink else None
     # s_aux = torch.ones(nheads, device=device, dtype=torch.bfloat16) * 4 if test_sink else None
     # print("s_aux ", s_aux)
-    cp_rank = 0
+    cp_tot_seqlen_k = seqlen_k * cp_world_size + cp_tot_seqlen_k_offset
+    cp_tot_seqlen_k = torch.full((batch_size,), cp_tot_seqlen_k, device=device, dtype=torch.int32)
     if test_sink:
         dv_vals = [d]
     for dv in dv_vals:
@@ -175,7 +183,7 @@ def test_flash_attn_output(
         else:
             qv_ref = None
         # Put window_size after QKV randn so that window_size changes from test to test
-        window_size = (-1, -1) if not local else torch.randint(0, seqlen_k * cp_world_size, (2,))
+        window_size = (-1, -1) if not local else torch.randint(0, cp_tot_seqlen_k[0], (2,))
         # window_size = (-1, -1) if not local else (16, 0)
         if dtype == torch.float8_e4m3fn:
             q_descale, k_descale, v_descale = [torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32) * 2 for _ in range(3)]
@@ -199,6 +207,7 @@ def test_flash_attn_output(
             s_aux=s_aux,
             cp_world_size=cp_world_size,
             cp_rank=cp_rank,
+            cp_tot_seqlen_k=cp_tot_seqlen_k,
         )
         out_pt, attn_pt = attention_ref(
             q_ref,
@@ -217,6 +226,7 @@ def test_flash_attn_output(
             s_aux=s_aux,
             cp_world_size=cp_world_size,
             cp_rank=cp_rank,
+            cp_tot_seqlen_k=cp_tot_seqlen_k,
         )

         # qk = torch.einsum('bshd,bthd->bhst', q_ref, k_ref).float()
@@ -251,6 +261,7 @@ def test_flash_attn_output(
                 s_aux=s_aux,
                 cp_world_size=cp_world_size,
                 cp_rank=cp_rank,
+                cp_tot_seqused_k=cp_tot_seqlen_k,
             )
             print("Pack GQA =", pack_gqa)
             print("Num splits =", num_splits)
@@ -378,7 +389,7 @@ def test_flash_attn_output(
     ],
 )
 def test_flash_attn_varlen_output(
-    seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv_, mha_type, dtype, test_sink
+    seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv_, mha_type, dtype, test_sink,
 ):
     if has_qv_ and (d != 64 or dtype == torch.float8_e4m3fn):
         pytest.skip("Has Qv requires hdim 64 and dtype to be float16 or bfloat16 (not float8_e4m3fn)")

hopper/test_util.py

Lines changed: 6 additions & 2 deletions
@@ -165,13 +165,15 @@ def construct_local_mask(
     device=None,
     cp_world_size=1,
     cp_rank=0,
+    cp_tot_seqlen_k=None,
 ):
     if cp_world_size > 1:
         return construct_cp_mask(
             seqlen_q,
             seqlen_k,
             cp_world_size=cp_world_size,
             cp_rank=cp_rank,
+            cp_tot_seqlen_k=cp_tot_seqlen_k,
             window_size=window_size,
             sink_token_length=sink_token_length,
             query_padding_mask=query_padding_mask,
@@ -209,6 +211,7 @@ def construct_cp_mask(
     seqlen_k,
     cp_world_size=1,
     cp_rank=0,
+    cp_tot_seqlen_k=None,
     window_size=(-1, -1), # -1 means infinite window size
     sink_token_length=0,
     query_padding_mask=None,
@@ -250,7 +253,7 @@ def construct_cp_mask(

     # Calculate effective sequence lengths
     sk = (
-        torch.tensor(seqlen_k * cp_world_size, device=device, dtype=torch.long) # Global seqlen_k for DCP
+        cp_tot_seqlen_k[0]
         if key_padding_mask is None
         else rearrange(key_padding_mask.sum(-1), "b -> b 1 1 1") * cp_world_size
     )
@@ -300,7 +303,6 @@ def construct_cp_mask(
         ),
     )

-    print(f"cp {mask=}")
     return mask


@@ -326,6 +328,7 @@ def attention_ref(
     s_aux=None,
     cp_world_size=1,
     cp_rank=0,
+    cp_tot_seqlen_k=None,
 ):
     """
     Arguments:
@@ -396,6 +399,7 @@ def attention_ref(
             device=q.device,
             cp_world_size=cp_world_size,
             cp_rank=cp_rank,
+            cp_tot_seqlen_k=cp_tot_seqlen_k,
         )
         scores.masked_fill_(local_mask, float("-inf"))
         if attn_bias is not None:
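
A minimal sketch of how the reference mask's effective key length is now resolved (effective_sk is a hypothetical stand-in for the sk expression inside construct_cp_mask, with einops replaced by a plain view): the length comes from the supplied cp_tot_seqlen_k tensor instead of being recomputed as seqlen_k * cp_world_size, so the uneven global lengths exercised by the cp_tot_seqlen_k_offset test cases are honored.

import torch

def effective_sk(cp_tot_seqlen_k, key_padding_mask=None, cp_world_size=1):
    if key_padding_mask is None:
        # Unpadded path: take the provided global key length directly.
        return cp_tot_seqlen_k[0]
    # Padded path: per-batch valid key counts scaled up to the global length,
    # shaped to broadcast against (batch, heads, seqlen_q, seqlen_k) scores.
    return key_padding_mask.sum(-1).view(-1, 1, 1, 1) * cp_world_size

cp_tot_seqlen_k = torch.tensor([4098, 4098], dtype=torch.int32)
print(effective_sk(cp_tot_seqlen_k))  # tensor(4098, dtype=torch.int32)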
