Skip to content

Commit 7430deb

Browse files
committed
cleanup and add comments
1 parent 52f835e commit 7430deb

File tree

3 files changed

+4
-2
lines changed

3 files changed

+4
-2
lines changed

hopper/seqlen.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ struct SeqlenInfoQK {
5656
? seqlen_k_static
5757
: (seqused_k ? seqused_k[bidb] : (cu_seqlens_k ? cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb] : seqlen_k_static)))
5858
, cp_world_size(cp_world_size)
59-
, tot_seqlen_k(cp_tot_seqused_k == nullptr
59+
, tot_seqlen_k(cp_tot_seqused_k == nullptr and cp_world_size <= 1
6060
? seqlen_k
6161
: cp_tot_seqused_k[bidb])
6262
{

hopper/test_flash_attn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ def test_flash_attn_output(
389389
],
390390
)
391391
def test_flash_attn_varlen_output(
392-
seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv_, mha_type, dtype, test_sink,
392+
seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv_, mha_type, dtype, test_sink
393393
):
394394
if has_qv_ and (d != 64 or dtype == torch.float8_e4m3fn):
395395
pytest.skip("Has Qv requires hdim 64 and dtype to be float16 or bfloat16 (not float8_e4m3fn)")

hopper/test_util.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ def construct_cp_mask(
231231
seqlen_k: Length of key sequence (local to this rank)
232232
cp_world_size: Number of context parallel ranks
233233
cp_rank: Current rank ID (0 to cp_world_size-1)
234+
cp_tot_seqlen_k: Total length of key sequence across all context parallel ranks
234235
window_size: (left_window, right_window), -1 = infinite
235236
sink_token_length: Number of "sink" tokens that can always be attended to
236237
query_padding_mask: Which query positions are valid
@@ -350,6 +351,7 @@ def attention_ref(
350351
s_aux: (nheads)
351352
cp_world_size: Number of context parallel ranks
352353
cp_rank: Current rank ID (0 to cp_world_size-1)
354+
cp_tot_seqlen_k: (batch_size) total seqlen of k/v across all context parallel ranks
353355
Output:
354356
output: (batch_size, seqlen_q, nheads, head_dim_v)
355357
attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout

0 commit comments

Comments
 (0)