
Commit ce37406

remove format; simplify tot_seqlen_k handling

Signed-off-by: Ming Yang <minos.future@gmail.com>
1 parent efc45c0 commit ce37406

File tree

7 files changed (+36, -39 lines)

hopper/block.h

Lines changed: 7 additions & 11 deletions
@@ -38,17 +38,13 @@ struct BlockMN {
        // TODO: check off-by-1 error
        if (PackGQA) { m_idx_max = qhead_per_khead_divmod.divide(m_idx_max - 1) + 1 ; }
        // If local, blocking (m_idx_max - m_idx_min + window_size_right + window_size_left)
-       if (seqlen_info.cp_world_size > 1) {
-           n_block_max = std::min(n_block_max,
-                                  cute::ceil_div(
-                                      cute::ceil_div(m_idx_max + seqlen_info.cp_tot_seqlen_k - seqlen_q + window_size_right - seqlen_info.cp_rank,
-                                                     seqlen_info.cp_world_size),
-                                      kBlockN));
-       } else {
-           n_block_max = std::min(n_block_max,
-                                  cute::ceil_div(m_idx_max + seqlen_k - seqlen_q + window_size_right,
-                                                 kBlockN));
-       }
+       // when cp is not enabled, tot_seqlen_k is equal to seqlen_k, and cp_world_size is 1.
+       // cp_world_size is guaranteed to be greater than 0
+       n_block_max = std::min(n_block_max,
+                              cute::ceil_div(
+                                  cute::ceil_div(m_idx_max + seqlen_info.tot_seqlen_k - seqlen_q + window_size_right - seqlen_info.cp_rank,
+                                                 seqlen_info.cp_world_size),
+                                  kBlockN));
    }
    // Now, only adjust n_block_min if split
    int n_block_min = 0;
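For reference, a minimal standalone sketch (not part of the commit) of why the unified bound is safe: with CP disabled, tot_seqlen_k equals seqlen_k, cp_world_size is 1, and cp_rank is 0, so the nested ceil_div collapses to the old non-CP branch. ceil_div below is a plain stand-in for cute::ceil_div, and all sizes are hypothetical.

#include <cassert>

// Plain stand-in for cute::ceil_div (positive operands assumed).
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
    // Hypothetical sizes, chosen only for illustration.
    int const kBlockN = 128, seqlen_q = 1024, seqlen_k = 4096;
    int const m_idx_max = 512, window_size_right = 64;

    // Old non-CP branch.
    int const old_path = ceil_div(m_idx_max + seqlen_k - seqlen_q + window_size_right, kBlockN);

    // Unified formula with CP disabled: tot_seqlen_k == seqlen_k,
    // cp_world_size == 1, cp_rank == 0.
    int const unified = ceil_div(ceil_div(m_idx_max + seqlen_k - seqlen_q + window_size_right - 0, 1), kBlockN);

    assert(old_path == unified);  // both paths give the same block bound
    return 0;
}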

hopper/flash_api.cpp

Lines changed: 5 additions & 2 deletions
The first two hunks are formatting-only (trailing whitespace removed); the third adds argument validation for context parallelism (CP).

@@ -434,7 +434,7 @@ inline bool get_pack_gqa(Flash_fwd_params const& params) {
    if (params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1) { return true; }
    // Always enable PackGQA for special case of hdim = 64, qheads/kvheads = 8, local attention
    // TODO: investigate more cases where PackGQA improves perf due to better tile quantization
-   bool const packgqa_override = params.arch >= 90 && (params.h / params.h_k) == 8 &&
+   bool const packgqa_override = params.arch >= 90 && (params.h / params.h_k) == 8 &&
        params.is_local &&
        params.d == 64 && (params.dv == params.d);
    if (packgqa_override) { return true; }

@@ -787,7 +787,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
    }
#ifdef FLASHATTENTION_DISABLE_HDIMDIFF64
    TORCH_CHECK(head_size > 64, "This flash attention build does not support hdim != hdim_v when hdim <= 64");
-   #endif
+   #endif
#ifdef FLASHATTENTION_DISABLE_HDIMDIFF192
    TORCH_CHECK(head_size <= 64, "This flash attention build does not support hdim != hdim_v when hdim in (128, 192]");
#endif

@@ -1161,6 +1161,9 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
    params.cp_rank = cp_rank;
    params.cp_tot_seqused_k = cp_tot_seqused_k_.has_value() ?
        static_cast<int *>(cp_tot_seqused_k_.value().data_ptr()) : nullptr;
+   TORCH_CHECK(cp_world_size > 0, "cp_world_size must be positive, required by downstream unified code path. Use 1 if CP is not enabled.");
+   TORCH_CHECK(cp_world_size != 1 || cp_rank == 0, "When context parallelism is disabled, cp_rank must be zero");
+   TORCH_CHECK(cp_world_size == 1 || cp_tot_seqused_k_.has_value(), "cp_tot_seqused_k_ must be provided when context parallelism is enabled.");

#ifdef FLASHATTENTION_DISABLE_LOCAL
    TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
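The three checks encode simple invariants on the CP arguments. A host-side sketch (mine, not from the commit) of the same invariants as a plain function, with has_tot_seqused_k standing in for cp_tot_seqused_k_.has_value():

#include <stdexcept>

void validate_cp_args(int cp_world_size, int cp_rank, bool has_tot_seqused_k) {
    // Downstream kernels divide by cp_world_size unconditionally,
    // so 0 is never valid; callers pass 1 when CP is disabled.
    if (cp_world_size <= 0) throw std::invalid_argument("cp_world_size must be positive");
    // With CP disabled (world size 1), the only valid rank is 0.
    if (cp_world_size == 1 && cp_rank != 0) throw std::invalid_argument("cp_rank must be 0 without CP");
    // With CP enabled, the per-batch total K length must be supplied.
    if (cp_world_size > 1 && !has_tot_seqused_k) throw std::invalid_argument("cp_tot_seqused_k required with CP");
}

int main() {
    validate_cp_args(1, 0, false);  // CP disabled: fine without totals
    validate_cp_args(4, 2, true);   // CP enabled with totals: fine
    // validate_cp_args(0, 0, false);  // would throw: world size must be >= 1
    return 0;
}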

hopper/flash_attn_interface.py

Lines changed: 1 addition & 1 deletion
This hunk is formatting-only (trailing whitespace removed).

@@ -241,7 +241,7 @@ def backward(ctx, dout, *args):
            ctx.causal,
            ctx.window_size,
            ctx.softcap,
-           ctx.deterministic,
+           ctx.deterministic,
        )
        dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
        return dqkv, None, None, None, None, None, None, None, None, None, None

hopper/flash_fwd_launch_template.h

Lines changed: 0 additions & 1 deletion
@@ -89,7 +89,6 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
        cute::conditional_return<!V_colmajor>(
            make_stride(params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0),
            make_stride(_1{}, params.v_dim_stride, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0));
-
    typename CollectiveMainloop::Arguments mainloop_args {
        static_cast<Element const*>(params.q_ptr),
        {seqlen_q, params.d, params.h, batch_q},  // shape_Q

hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp

Lines changed: 10 additions & 11 deletions
The hunks at lines 550 and 1432 are formatting-only (trailing whitespace removed).

@@ -547,7 +547,7 @@ struct CollectiveMainloopFwdSm90 {
                return nullptr;
            }
        }();
-
+
        auto const shape_Qv_packed = cute::conditional_return<!PackGQA>(
            shape_Qv,
            make_shape(make_shape(qhead_per_khead, get<0>(shape_Qv)), get<1>(shape_Qv), get<2>(args.shape_K), get<3>(shape_Qv))

@@ -1007,7 +1007,6 @@ struct CollectiveMainloopFwdSm90 {
        static constexpr int kBlockN = get<1>(TileShape_MNK{});

        // can't use auto [m_block, ...] = block_coord since structured binding cannot be captured in lambda
-       // block index
        int const m_block = get<0>(block_coord);
        int const bidh = get<1>(block_coord);
        int const bidb = get<2>(block_coord);

@@ -1103,7 +1102,7 @@ struct CollectiveMainloopFwdSm90 {
        flash::Mask<kBlockM, kBlockN, PackGQA, TiledMmaQK> mask(
            thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 - n_offset /*sink_token_length*/,
            params.qhead_per_khead_divmod,
-           params.cp_world_size, params.cp_rank, seqlen_info.cp_tot_seqlen_k
+           params.cp_world_size, params.cp_rank, seqlen_info.tot_seqlen_k
        );

        float softcap_val = params.softcap_val;

@@ -1211,7 +1210,6 @@ struct CollectiveMainloopFwdSm90 {
        }

        if constexpr (IntraWGOverlap) {
-
            Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{}));
            consumer_wait(pipeline_k, smem_pipe_read);
            flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);

@@ -1283,8 +1281,7 @@ struct CollectiveMainloopFwdSm90 {
        };

        if constexpr (Is_causal || Is_local) { // Separate iterations with causal or local masking
-           auto mask_fn = [&](auto& tSrS, int n_block) {
-               mask.template apply<false /*Seqlenk_mask*/, Is_causal, Is_local>(tSrS, m_block, n_block); };
+           auto mask_fn = [&](auto& tSrS, int n_block) { mask.template apply<false /*Seqlenk_mask*/, Is_causal, Is_local>(tSrS, m_block, n_block); };
            int const m_idx_min = !PackGQA ? m_block * kBlockM : params.qhead_per_khead_divmod.divide(m_block * kBlockM);
            // If local, blocking (window_size_right + window_size_left)
            int const n_block_min_causal_local_mask =

@@ -1297,13 +1294,15 @@ struct CollectiveMainloopFwdSm90 {

            int const m_idx_max = !PackGQA ? (m_block + 1) * kBlockM : params.qhead_per_khead_divmod.divide((m_block + 1) * kBlockM - 1) + 1;
            // If local, blocking (m_idx_max - m_idx_min)
+           // when cp is not enabled, tot_seqlen_k is equal to seqlen_k, and cp_world_size is 1.
+           // cp_world_size is guaranteed to be greater than 0
            int const n_block_min_before_local_mask = !Is_local
                ? n_block_min
                : std::max(n_block_min,
-                          cute::ceil_div(m_idx_max +
-                                         params.cp_world_size * seqlen_k -
-                                         seqlen_q - params.window_size_left,
-                                         params.cp_world_size * kBlockN));
+                          cute::ceil_div(
+                              cute::ceil_div(m_idx_max + seqlen_info.tot_seqlen_k - seqlen_q - params.window_size_left - seqlen_info.cp_rank,
+                                             seqlen_info.cp_world_size),
+                              kBlockN));
            auto no_mask_fn = [](auto& tSrS, int n_block) { };
            #pragma unroll 1
            for (; n_block >= n_block_min_before_local_mask; --n_block) {

@@ -1429,7 +1428,7 @@ struct CollectiveMainloopFwdSm90 {
        // Tensor scores_scale = softmax.finalize(v_descale);
        Tensor scores_scale = make_tensor_like(softmax.row_max);
        finalize_dispatch(scores_scale, v_descale);
-
+
        if constexpr (LargeHeadDimV) {
            cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<uint32_t>(FwdNamedBarriers::PEmpty) /*id*/);
            store_scales(scores_scale, smem_pipe_read.index());
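The local-attention lower bound uses the same unified form as block.h. A standalone sketch (not from the commit) showing that the new nested ceil_div matches the old params.cp_world_size * seqlen_k form when shards are equal and cp_rank is 0, via the identity ceil(ceil(x/a)/b) == ceil(x/(a*b)) for positive operands; all sizes are made up:

#include <cassert>

constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

int main() {
    // Hypothetical sizes for illustration only.
    int const kBlockN = 128, W = 4;               // W = cp_world_size
    int const seqlen_k_local = 1024;              // per-rank K length
    int const tot_seqlen_k = W * seqlen_k_local;  // equal shards
    int const m_idx_max = 512, seqlen_q = 2048, window_size_left = 256;

    int const x = m_idx_max + tot_seqlen_k - seqlen_q - window_size_left;
    // Old form: one division by W * kBlockN, assuming tot == W * seqlen_k_local.
    int const old_form = ceil_div(x, W * kBlockN);
    // New form with cp_rank == 0: nested divisions compose to the same value.
    int const new_form = ceil_div(ceil_div(x - /*cp_rank=*/0, W), kBlockN);
    assert(old_form == new_form);

    // For cp_rank > 0 the new form also shifts the bound by the rank's offset
    // in the round-robin K layout, which the old form could not express.
    return 0;
}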

hopper/mask.h

Lines changed: 5 additions & 5 deletions
@@ -23,13 +23,13 @@ struct Mask {
    int const seqlen_q, seqlen_k;
    int const window_size_left, window_size_right, sink_token_length;
    cutlass::FastDivmod const qhead_per_khead_divmod;
-   int const cp_world_size, cp_rank, cp_tot_seqlen_k;
+   int const cp_world_size, cp_rank, tot_seqlen_k;

    CUTLASS_DEVICE
    Mask(const int thread_idx, const int seqlen_q, const int seqlen_k,
         const int window_size_left, const int window_size_right, const int sink_token_length,
         cutlass::FastDivmod const &qhead_per_khead_divmod,
-        const int cp_world_size = 1, const int cp_rank = 0, const int cp_tot_seqlen_k = 0)
+        const int cp_world_size = 1, const int cp_rank = 0, const int tot_seqlen_k = 0)
        : thread_idx(thread_idx)
        , seqlen_q(seqlen_q)
        , seqlen_k(seqlen_k)

@@ -39,7 +39,7 @@ struct Mask {
        , qhead_per_khead_divmod(qhead_per_khead_divmod)
        , cp_world_size(cp_world_size)
        , cp_rank(cp_rank)
-       , cp_tot_seqlen_k(cp_tot_seqlen_k)
+       , tot_seqlen_k(tot_seqlen_k)
    {
    };

@@ -103,8 +103,8 @@ struct Mask {
                if (cp_world_size > 1) {
                    int local_k_idx = int(get<Col>(t0ScS_rowcol(_0{}, n))) + get<Col>(tScS_rowcol(_0{}, _0{})) + n_block * kBlockN;
                    int abs_k_idx = local_k_idx * cp_world_size + cp_rank;
-                   int k_limit = row_idx + cp_tot_seqlen_k - seqlen_q;
-                   if (abs_k_idx > k_limit || (Seqlenk_mask && abs_k_idx >= cp_tot_seqlen_k)) {
+                   int k_limit = row_idx + tot_seqlen_k - seqlen_q;
+                   if (abs_k_idx > k_limit || (Seqlenk_mask && abs_k_idx >= tot_seqlen_k)) {
                        tSrS_rowcol(m, n) = -INFINITY;
                    }
                } else {
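For intuition, a host-side sketch (not from the commit) of the round-robin CP masking rule above: rank r of W ranks holds absolute K indices local * W + r, and both the causal limit and the seqlen-k bound are checked in absolute coordinates. The masked_out helper and all sizes are hypothetical:

#include <cstdio>

bool masked_out(int row_idx, int local_k_idx, int cp_world_size, int cp_rank,
                int seqlen_q, int tot_seqlen_k, bool seqlenk_mask) {
    // Round-robin shard: local column -> absolute K index.
    int const abs_k_idx = local_k_idx * cp_world_size + cp_rank;
    // Causal limit in absolute K coordinates (bottom-right aligned).
    int const k_limit = row_idx + tot_seqlen_k - seqlen_q;
    return abs_k_idx > k_limit || (seqlenk_mask && abs_k_idx >= tot_seqlen_k);
}

int main() {
    // Hypothetical: 2 ranks, 8 total K tokens, 4 queries. Rank 1 owns
    // absolute K indices 1, 3, 5, 7; for query row 0 the limit is 8 - 4 = 4,
    // so local columns 2 and 3 (absolute 5 and 7) are masked.
    for (int local = 0; local < 4; ++local) {
        printf("row 0, local k %d -> masked: %d\n",
               local, masked_out(/*row_idx=*/0, local, /*world=*/2, /*rank=*/1,
                                 /*seqlen_q=*/4, /*tot_seqlen_k=*/8, true));
    }
    return 0;
}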

hopper/seqlen.h

Lines changed: 8 additions & 8 deletions
@@ -34,7 +34,7 @@ struct SeqlenInfoQK {
    int const offset_q, offset_k, offset_q_padded;
    int const seqlen_q, seqlen_k;
    int const cp_world_size;
-   int const cp_tot_seqlen_k;
+   int const tot_seqlen_k;

    CUTLASS_DEVICE
    SeqlenInfoQK(int const bidb, int const seqlen_q_static, int const seqlen_k_static,

@@ -56,9 +56,9 @@ struct SeqlenInfoQK {
          ? seqlen_k_static
          : (seqused_k ? seqused_k[bidb] : (cu_seqlens_k ? cu_seqlens_k[bidb + 1] - cu_seqlens_k[bidb] : seqlen_k_static)))
        , cp_world_size(cp_world_size)
-       , cp_tot_seqlen_k(cp_tot_seqused_k == nullptr
-             ? 0
-             : cp_tot_seqused_k[bidb])
+       , tot_seqlen_k(cp_tot_seqused_k == nullptr
+             ? seqlen_k
+             : cp_tot_seqused_k[bidb])
    {
    }

@@ -74,7 +74,7 @@ struct SeqlenInfoQKNewK {
    int const seqlen_q, seqlen_k_og, seqlen_k_new, seqlen_k, seqlen_rotary;
    int const cp_world_size;
    int const cp_rank;
-   int const cp_tot_seqlen_k;
+   int const tot_seqlen_k;

    CUTLASS_DEVICE
    SeqlenInfoQKNewK(int const bidb, int const seqlen_q_static, int const seqlen_k_static, int const shape_K_new_0,

@@ -100,9 +100,9 @@ struct SeqlenInfoQKNewK {
        , seqlen_rotary(!AppendKV || !seqlens_rotary ? seqlen_k_og + leftpad_k : seqlens_rotary[bidb])
        , cp_world_size(cp_world_size)
        , cp_rank(cp_rank)
-       , cp_tot_seqlen_k(cp_tot_seqused_k == nullptr
-             ? 0
-             : cp_tot_seqused_k[bidb])
+       , tot_seqlen_k(cp_tot_seqused_k == nullptr
+             ? seqlen_k
+             : cp_tot_seqused_k[bidb])
    {
    }
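The behavioral change here is the fallback value. A one-function sketch (not from the commit) of the simplified initializer: with no CP totals pointer, tot_seqlen_k now falls back to the local seqlen_k instead of 0, which is what lets block.h and the mainloop use a single formula for both the CP and non-CP paths.

// Mirrors the member-initializer logic of tot_seqlen_k in both structs.
int tot_seqlen_k_for(int const* cp_tot_seqused_k, int bidb, int seqlen_k) {
    return cp_tot_seqused_k == nullptr ? seqlen_k : cp_tot_seqused_k[bidb];
}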
