Skip to content

Commit 2cd9f3f

Browse files
committed
cleanup and add comments
Signed-off-by: Ming Yang <minos.future@gmail.com>
1 parent 1e955be commit 2cd9f3f

File tree

4 files changed

+5
-3
lines changed

4 files changed

+5
-3
lines changed

hopper/flash_api.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ inline bool get_pack_gqa(Flash_fwd_params const& params) {
435435
// Always enable PackGQA for special case of hdim = 64, qheads/kvheads = 8, local attention
436436
// TODO: investigate more cases where PackGQA improves perf due to better tile quantization
437437
bool const packgqa_override = params.arch >= 90 && (params.h / params.h_k) == 8 &&
438-
params.is_local &&
438+
params.is_local &&
439439
params.d == 64 && (params.dv == params.d);
440440
if (packgqa_override) { return true; }
441441
#ifdef FLASHATTENTION_DISABLE_PACKGQA

hopper/seqlen.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ struct SeqlenInfoQK {
5656
? seqlen_k_static
5757
: (seqused_k ? seqused_k[bidb] : (cu_seqlens_k ? cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb] : seqlen_k_static)))
5858
, cp_world_size(cp_world_size)
59-
, tot_seqlen_k(cp_tot_seqused_k == nullptr
59+
, tot_seqlen_k(cp_tot_seqused_k == nullptr and cp_world_size <= 1
6060
? seqlen_k
6161
: cp_tot_seqused_k[bidb])
6262
{

hopper/test_flash_attn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ def test_flash_attn_output(
389389
],
390390
)
391391
def test_flash_attn_varlen_output(
392-
seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv_, mha_type, dtype, test_sink,
392+
seqlen_q, seqlen_k, d, add_unused_qkv, causal, local, softcap, deterministic, has_qv_, mha_type, dtype, test_sink
393393
):
394394
if has_qv_ and (d != 64 or dtype == torch.float8_e4m3fn):
395395
pytest.skip("Has Qv requires hdim 64 and dtype to be float16 or bfloat16 (not float8_e4m3fn)")

hopper/test_util.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,7 @@ def construct_cp_mask(
231231
seqlen_k: Length of key sequence (local to this rank)
232232
cp_world_size: Number of context parallel ranks
233233
cp_rank: Current rank ID (0 to cp_world_size-1)
234+
cp_tot_seqlen_k: Total lengths of key sequence in cp world
234235
window_size: (left_window, right_window), -1 = infinite
235236
sink_token_length: Number of "sink" tokens that can always be attended to
236237
query_padding_mask: Which query positions are valid
@@ -350,6 +351,7 @@ def attention_ref(
350351
s_aux: (nheads)
351352
cp_world_size: Number of context parallel ranks
352353
cp_rank: Current rank ID (0 to cp_world_size-1)
354+
cp_tot_seqlen_k: (batch_size) total seqlen of k/v in cp world
353355
Output:
354356
output: (batch_size, seqlen_q, nheads, head_dim_v)
355357
attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout

0 commit comments

Comments (0)