
Commit efc45c0

add cp_tot_seqused_k to calc mask and block boundary
Signed-off-by: Ming Yang <minos.future@gmail.com>
1 parent e3f796f commit efc45c0

12 files changed, +74 -21 lines

hopper/block.h

Lines changed: 11 additions & 5 deletions
@@ -38,11 +38,17 @@ struct BlockMN {
         // TODO: check off-by-1 error
         if (PackGQA) { m_idx_max = qhead_per_khead_divmod.divide(m_idx_max - 1) + 1 ; }
         // If local, blocking (m_idx_max - m_idx_min + window_size_right + window_size_left)
-        n_block_max = std::min(n_block_max,
-                               cute::ceil_div(m_idx_max +
-                                              seqlen_info.cp_world_size * seqlen_k -
-                                              seqlen_q + window_size_right,
-                                              seqlen_info.cp_world_size * kBlockN));
+        if (seqlen_info.cp_world_size > 1) {
+            n_block_max = std::min(n_block_max,
+                                   cute::ceil_div(
+                                       cute::ceil_div(m_idx_max + seqlen_info.cp_tot_seqlen_k - seqlen_q + window_size_right - seqlen_info.cp_rank,
+                                                      seqlen_info.cp_world_size),
+                                       kBlockN));
+        } else {
+            n_block_max = std::min(n_block_max,
+                                   cute::ceil_div(m_idx_max + seqlen_k - seqlen_q + window_size_right,
+                                                  kBlockN));
+        }
     }
     // Now, only adjust n_block_min if split
     int n_block_min = 0;
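Note: the new CP branch bounds how many K blocks this rank must visit. Reading the mask change in hopper/mask.h below, K/V appear to be sharded round-robin across CP ranks (global K index = local index * cp_world_size + cp_rank), so the largest reachable global index is shifted by -cp_rank and divided by cp_world_size to get a local index before rounding up to kBlockN tiles. A minimal Python sketch of the same arithmetic (illustrative only, not part of the commit):

def ceil_div(a, b):
    # Integer ceiling division, matching cute::ceil_div for positive b.
    return (a + b - 1) // b

def cp_n_block_max(n_block_max, m_idx_max, seqlen_q, seqlen_k, window_size_right,
                   k_block_n, cp_world_size, cp_rank, cp_tot_seqlen_k):
    if cp_world_size > 1:
        # Largest global K index any row of this M tile may attend to, shifted by
        # -cp_rank so dividing by cp_world_size maps it to a local index on this
        # rank under the round-robin layout: global = local * cp_world_size + cp_rank.
        global_bound = m_idx_max + cp_tot_seqlen_k - seqlen_q + window_size_right - cp_rank
        local_bound = ceil_div(global_bound, cp_world_size)
        return min(n_block_max, ceil_div(local_bound, k_block_n))
    # No context parallelism: the usual causal/local bound on this rank's K.
    return min(n_block_max,
               ceil_div(m_idx_max + seqlen_k - seqlen_q + window_size_right, k_block_n))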

hopper/flash.h

Lines changed: 1 addition & 0 deletions
@@ -165,6 +165,7 @@ struct Flash_fwd_params : public Qkv_params {
     // CP (Context Parallelism) parameters
     int cp_world_size;
     int cp_rank;
+    int *__restrict__ cp_tot_seqused_k;
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////

hopper/flash_api.cpp

Lines changed: 10 additions & 1 deletion
@@ -703,7 +703,8 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
         int const sm_margin,
         std::optional<const at::Tensor> &s_aux_, // (h)
         int const cp_world_size, // context parallelism (cp) world size
-        int const cp_rank // cp rank
+        int const cp_rank, // cp rank
+        std::optional<const at::Tensor> &cp_tot_seqused_k_ // b. total seqused_k in cp world
         ) {

    auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -841,6 +842,12 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
        CHECK_DEVICE(seqused_k); CHECK_CONTIGUOUS(seqused_k);
        CHECK_SHAPE(seqused_k, batch_size);
    }
+   if (cp_tot_seqused_k_.has_value()) {
+       auto cp_tot_seqused_k = cp_tot_seqused_k_.value();
+       TORCH_CHECK(cp_tot_seqused_k.dtype() == torch::kInt32, "cp_tot_seqused_k must have dtype int32");
+       CHECK_DEVICE(cp_tot_seqused_k); CHECK_CONTIGUOUS(cp_tot_seqused_k);
+       CHECK_SHAPE(cp_tot_seqused_k, batch_size);
+   }

    if (leftpad_k_.has_value()) {
        auto leftpad_k = leftpad_k_.value();
@@ -1152,6 +1159,8 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq

    params.cp_world_size = cp_world_size;
    params.cp_rank = cp_rank;
+   params.cp_tot_seqused_k = cp_tot_seqused_k_.has_value() ?
+       static_cast<int *>(cp_tot_seqused_k_.value().data_ptr()) : nullptr;

 #ifdef FLASHATTENTION_DISABLE_LOCAL
    TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
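The new argument is validated like seqused_k: int32, on the same device, contiguous, shape (batch_size,). If, as the "total seqused_k in cp world" comment suggests, it holds the total used K length across all CP ranks per batch element, a caller might assemble it by summing the per-rank seqused_k over the CP process group. A hypothetical host-side sketch (build_cp_tot_seqused_k and cp_group are not part of this commit, and the sum-over-ranks semantics is an assumption):

import torch
import torch.distributed as dist

def build_cp_tot_seqused_k(seqused_k: torch.Tensor, cp_group) -> torch.Tensor:
    # Hypothetical helper: sum this rank's used K lengths over the CP group so the
    # result satisfies the checks above (int32, CUDA, contiguous, shape (batch_size,)).
    tot = seqused_k.to(torch.int32).contiguous().clone()
    dist.all_reduce(tot, op=dist.ReduceOp.SUM, group=cp_group)
    return tot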

hopper/flash_api_torch_lib.cpp

Lines changed: 4 additions & 2 deletions
@@ -54,7 +54,8 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
         int const sm_margin,
         std::optional<const at::Tensor> &s_aux_,
         int const cp_world_size,
-        int const cp_rank
+        int const cp_rank,
+        std::optional<const at::Tensor> &cp_tot_seqused_k
         );

 // Only applicable to the case where seqused_k (i.e. cache_seqlens) is available
@@ -124,7 +125,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
            "   int sm_margin,"
            "   Tensor? s_aux,"
            "   int cp_world_size,"
-           "   int cp_rank) -> Tensor[]");
+           "   int cp_rank,"
+           "   Tensor? cp_tot_seqused_k) -> Tensor[]");
    ops.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));

    ops.def("get_scheduler_metadata("

hopper/flash_attn_interface.py

Lines changed: 14 additions & 2 deletions
@@ -51,7 +51,8 @@ def _flash_attn_forward(
     sm_margin=0,
     s_aux=None,
     cp_world_size=1,
-    cp_rank=0):
+    cp_rank=0,
+    cp_tot_seqused_k=None):
     q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
     v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
     cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
@@ -99,7 +100,8 @@ def _flash_attn_forward(
         sm_margin,
         s_aux,
         cp_world_size,
-        cp_rank
+        cp_rank,
+        cp_tot_seqused_k,
     )
     return out, softmax_lse, *rest

@@ -266,6 +268,7 @@ def forward(
         s_aux=None,
         cp_world_size=1,
         cp_rank=0,
+        cp_tot_seqused_k=None,
     ):
         if softmax_scale is None:
             softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
@@ -293,6 +296,7 @@ def forward(
             s_aux=s_aux,
             cp_world_size=cp_world_size,
             cp_rank=cp_rank,
+            cp_tot_seqused_k=cp_tot_seqused_k,
         )
         # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
         ctx.save_for_backward(q, k, v, out, softmax_lse)
@@ -361,6 +365,7 @@ def forward(
         s_aux=None,
         cp_world_size=1,
         cp_rank=0,
+        cp_tot_seqused_k=None,
     ):
         if softmax_scale is None:
             softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
@@ -392,6 +397,7 @@ def forward(
             s_aux=s_aux,
             cp_world_size=cp_world_size,
             cp_rank=cp_rank,
+            cp_tot_seqused_k=cp_tot_seqused_k,
         )
         # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
         ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
@@ -511,6 +517,7 @@ def flash_attn_func(
     s_aux=None,
     cp_world_size=1,
     cp_rank=0,
+    cp_tot_seqused_k=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -574,6 +581,7 @@ def flash_attn_func(
         s_aux,
         cp_world_size,
         cp_rank,
+        cp_tot_seqused_k,
     )


@@ -600,6 +608,7 @@ def flash_attn_varlen_func(
     s_aux=None,
     cp_world_size=1,
     cp_rank=0,
+    cp_tot_seqused_k=None,
 ):
     return FlashAttnVarlenFunc.apply(
         q,
@@ -624,6 +633,7 @@ def flash_attn_varlen_func(
         s_aux,
         cp_world_size,
         cp_rank,
+        cp_tot_seqused_k,
     )


@@ -664,6 +674,7 @@ def flash_attn_with_kvcache(
     s_aux=None,
     cp_world_size=1,
     cp_rank=0,
+    cp_tot_seqused_k=None,
 ):
     """
     If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
@@ -793,6 +804,7 @@ def flash_attn_with_kvcache(
         s_aux=s_aux,
         cp_world_size=cp_world_size,
         cp_rank=cp_rank,
+        cp_tot_seqused_k=cp_tot_seqused_k,
     )
     # return (out, softmax_lse) if return_softmax_lse else out
     return (out, softmax_lse, *rest) if return_softmax_lse else out
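For context, a hypothetical call through the updated Python interface. The shapes, dtypes, import path, and the assumption that each rank holds an interleaved 1/cp_world_size shard of K/V are illustrative, not taken from the commit:

import torch
from flash_attn_interface import flash_attn_func  # import path depends on how the hopper extension is packaged

b, s_q, h, d = 2, 1024, 16, 128
cp_world_size, cp_rank = 4, 1
total_k = 4096
s_k_local = total_k // cp_world_size  # this rank's K/V shard length (assumed evenly divisible)

q = torch.randn(b, s_q, h, d, device="cuda", dtype=torch.bfloat16)
k = torch.randn(b, s_k_local, h, d, device="cuda", dtype=torch.bfloat16)
v = torch.randn(b, s_k_local, h, d, device="cuda", dtype=torch.bfloat16)
# Total K length across the whole CP group, one entry per batch element.
cp_tot_seqused_k = torch.full((b,), total_k, device="cuda", dtype=torch.int32)

outputs = flash_attn_func(
    q, k, v,
    causal=True,
    cp_world_size=cp_world_size,
    cp_rank=cp_rank,
    cp_tot_seqused_k=cp_tot_seqused_k,
)  # returns the attention output (plus softmax LSE, depending on the interface version)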

hopper/flash_fwd_kernel_sm90.h

Lines changed: 6 additions & 2 deletions
@@ -348,7 +348,9 @@ class FlashAttnFwdSm90 {
             params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
             params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
             params.mainloop.seqlens_rotary,
-            params.mainloop.cp_world_size
+            params.mainloop.cp_world_size,
+            params.mainloop.cp_rank,
+            params.mainloop.cp_tot_seqused_k
         };
         if constexpr (AppendKV) {
             bool tile_new_valid = mainloop.load_kv_new(
@@ -397,7 +399,9 @@ class FlashAttnFwdSm90 {
             get<0>(params.mainloop.shape_K_new),
             params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
             params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
-            params.mainloop.seqlens_rotary, params.mainloop.cp_world_size
+            params.mainloop.seqlens_rotary, params.mainloop.cp_world_size,
+            params.mainloop.cp_rank,
+            params.mainloop.cp_tot_seqused_k
         };
         if constexpr (AppendKV) {
             bool tile_new_valid = mainloop.store_kv_new(

hopper/flash_fwd_launch_template.h

Lines changed: 1 addition & 1 deletion
@@ -131,7 +131,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
         params.seqused_q, params.seqused_k,
         params.leftpad_k, params.seqlens_rotary,
         static_cast<ElementS const*>(params.s_aux_ptr),
-        params.cp_world_size, params.cp_rank,
+        params.cp_world_size, params.cp_rank, params.cp_tot_seqused_k
     };
     typename CollectiveEpilogue::Arguments epilogue_args {
         static_cast<ElementOut*>(params.o_ptr),

hopper/mainloop_fwd_sm80.hpp

Lines changed: 1 addition & 0 deletions
@@ -217,6 +217,7 @@ struct CollectiveMainloopFwdSm80 {
         ElementSAux const* const ptr_S_aux = nullptr;
         int cp_world_size;
         int cp_rank;
+        int const* const cp_tot_seqused_k = nullptr;
     };

     // Device side kernel params

hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp

Lines changed: 4 additions & 2 deletions
@@ -415,6 +415,7 @@ struct CollectiveMainloopFwdSm90 {
         // Context parallelism (CP) parameters
         int const cp_world_size = 1;
         int const cp_rank = 0;
+        int const* const cp_tot_seqused_k = nullptr;
     };

     // Device side kernel params
@@ -474,6 +475,7 @@ struct CollectiveMainloopFwdSm90 {
         ElementSAux const* const ptr_S_aux = nullptr;
         int cp_world_size = 1;
         int cp_rank = 0;
+        int const* const cp_tot_seqused_k = nullptr;
     };

     static Params
@@ -590,7 +592,7 @@ struct CollectiveMainloopFwdSm90 {
                 args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new,
                 args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary,
                 args.ptr_S_aux,
-                args.cp_world_size, args.cp_rank};
+                args.cp_world_size, args.cp_rank, args.cp_tot_seqused_k};
     }

     /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
@@ -1101,7 +1103,7 @@ struct CollectiveMainloopFwdSm90 {
         flash::Mask<kBlockM, kBlockN, PackGQA, TiledMmaQK> mask(
             thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 - n_offset /*sink_token_length*/,
             params.qhead_per_khead_divmod,
-            params.cp_world_size, params.cp_rank
+            params.cp_world_size, params.cp_rank, seqlen_info.cp_tot_seqlen_k
         );

         float softcap_val = params.softcap_val;

hopper/mask.h

Lines changed: 5 additions & 4 deletions
@@ -23,13 +23,13 @@ struct Mask {
     int const seqlen_q, seqlen_k;
     int const window_size_left, window_size_right, sink_token_length;
     cutlass::FastDivmod const qhead_per_khead_divmod;
-    int const cp_world_size, cp_rank;
+    int const cp_world_size, cp_rank, cp_tot_seqlen_k;

     CUTLASS_DEVICE
     Mask(const int thread_idx, const int seqlen_q, const int seqlen_k,
          const int window_size_left, const int window_size_right, const int sink_token_length,
          cutlass::FastDivmod const &qhead_per_khead_divmod,
-         const int cp_world_size = 1, const int cp_rank = 0)
+         const int cp_world_size = 1, const int cp_rank = 0, const int cp_tot_seqlen_k = 0)
         : thread_idx(thread_idx)
         , seqlen_q(seqlen_q)
         , seqlen_k(seqlen_k)
@@ -39,6 +39,7 @@ struct Mask {
         , qhead_per_khead_divmod(qhead_per_khead_divmod)
         , cp_world_size(cp_world_size)
         , cp_rank(cp_rank)
+        , cp_tot_seqlen_k(cp_tot_seqlen_k)
     {
     };

@@ -102,8 +103,8 @@ struct Mask {
                 if (cp_world_size > 1) {
                     int local_k_idx = int(get<Col>(t0ScS_rowcol(_0{}, n))) + get<Col>(tScS_rowcol(_0{}, _0{})) + n_block * kBlockN;
                     int abs_k_idx = local_k_idx * cp_world_size + cp_rank;
-                    int k_limit = row_idx + cp_world_size * seqlen_k - seqlen_q;
-                    if (abs_k_idx > k_limit || (Seqlenk_mask && abs_k_idx > cp_world_size * seqlen_k)) {
+                    int k_limit = row_idx + cp_tot_seqlen_k - seqlen_q;
+                    if (abs_k_idx > k_limit || (Seqlenk_mask && abs_k_idx >= cp_tot_seqlen_k)) {
                         tSrS_rowcol(m, n) = -INFINITY;
                     }
                 } else {
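The masking change can be read as: both the causal limit and the out-of-range check are now expressed against the explicit total cp_tot_seqlen_k instead of cp_world_size * seqlen_k (which is only exact when every rank holds exactly seqlen_k tokens), and the boundary test tightens from > to >= since valid global K indices run from 0 to cp_tot_seqlen_k - 1. A minimal Python sketch of the per-element predicate (illustrative only, not part of the commit):

def cp_masked_out(local_k_idx, row_idx, seqlen_q, cp_world_size, cp_rank,
                  cp_tot_seqlen_k, seqlenk_mask=True):
    # Global K position of this rank's local column under round-robin sharding.
    abs_k_idx = local_k_idx * cp_world_size + cp_rank
    # Causal limit in global coordinates: the query row may attend up to
    # row_idx + cp_tot_seqlen_k - seqlen_q.
    k_limit = row_idx + cp_tot_seqlen_k - seqlen_q
    return abs_k_idx > k_limit or (seqlenk_mask and abs_k_idx >= cp_tot_seqlen_k)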
