
Commit b8792e9

Support CP with query len larger than 1
1 parent 5714f9d commit b8792e9

14 files changed: +241 −30 lines changed

hopper/block.h

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ struct BlockMN {
         if (PackGQA) { m_idx_max = qhead_per_khead_divmod.divide(m_idx_max - 1) + 1 ; }
         // If local, blocking (m_idx_max - m_idx_min + window_size_right + window_size_left)
         n_block_max = std::min(n_block_max,
-                               cute::ceil_div(m_idx_max + seqlen_k - seqlen_q + window_size_right, kBlockN));
+                               cute::ceil_div(m_idx_max + seqlen_info.cp_world_size * seqlen_k - seqlen_q + window_size_right, kBlockN));
     }
     // Now, only adjust n_block_min if split
     int n_block_min = 0;
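The functional change here is the single bound above: with context parallelism (CP), seqlen_k is presumably this rank's local key length, so the causal/local upper bound on key blocks is taken against the global key length seqlen_info.cp_world_size * seqlen_k; the per-element masking that also uses cp_rank happens later in flash::Mask. A minimal standalone sketch of that bound (Python, names mirroring the kernel; illustrative only, not code from this commit):

    def ceil_div(a, b):
        return -(-a // b)

    def n_block_max_bound(n_block_max, m_idx_max, seqlen_q, seqlen_k_local,
                          window_size_right, kBlockN, cp_world_size=1):
        # With CP, seqlen_k_local is the per-rank key shard; the bound is taken
        # against the global key length so no valid key block is excluded.
        seqlen_k_global = cp_world_size * seqlen_k_local
        return min(n_block_max,
                   ceil_div(m_idx_max + seqlen_k_global - seqlen_q + window_size_right, kBlockN))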

hopper/flash.h

Lines changed: 4 additions & 0 deletions
@@ -161,6 +161,10 @@ struct Flash_fwd_params : public Qkv_params {

     // The S extra matrix, (num_heads)
     void *__restrict__ s_aux_ptr;
+
+    // CP (Context Parallelism) parameters
+    int cp_world_size;
+    int cp_rank;
 };

 ////////////////////////////////////////////////////////////////////////////////////////////////////

hopper/flash_api.cpp

Lines changed: 10 additions & 5 deletions
@@ -434,8 +434,8 @@ inline bool get_pack_gqa(Flash_fwd_params const& params) {
     if (params.arch < 90 || (params.page_table && !params.pagedkv_tma) || params.num_splits > 1) { return true; }
     // Always enable PackGQA for special case of hdim = 64, qheads/kvheads = 8, local attention
     // TODO: investigate more cases where PackGQA improves perf due to better tile quantization
-    bool const packgqa_override = params.arch >= 90 && (params.h / params.h_k) == 8 &&
-                                  params.is_local &&
+    bool const packgqa_override = params.arch >= 90 && (params.h / params.h_k) == 8 &&
+                                  params.is_local &&
                                   params.d == 64 && (params.dv == params.d);
     if (packgqa_override) { return true; }
 #ifdef FLASHATTENTION_DISABLE_PACKGQA
@@ -701,7 +701,9 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
         int num_splits,
         std::optional<bool> pack_gqa_,
         int const sm_margin,
-        std::optional<const at::Tensor> &s_aux_ // (h)
+        std::optional<const at::Tensor> &s_aux_, // (h)
+        int const cp_world_size, // context parallelism (cp) world size
+        int const cp_rank // cp rank
         ) {

     auto dprops = at::cuda::getCurrentDeviceProperties();
@@ -784,7 +786,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
     }
 #ifdef FLASHATTENTION_DISABLE_HDIMDIFF64
     TORCH_CHECK(head_size > 64, "This flash attention build does not support hdim != hdim_v when hdim <= 64");
-#endif
+#endif
 #ifdef FLASHATTENTION_DISABLE_HDIMDIFF192
     TORCH_CHECK(head_size <= 64, "This flash attention build does not support hdim != hdim_v when hdim in (128, 192]");
 #endif
@@ -1148,6 +1150,9 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
         params.s_aux_ptr = nullptr;
     }

+    params.cp_world_size = cp_world_size;
+    params.cp_rank = cp_rank;
+
 #ifdef FLASHATTENTION_DISABLE_LOCAL
     TORCH_CHECK(!params.is_local, "This flash attention build does not support local attention.");
 #endif
@@ -1664,4 +1669,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
     m.def("get_scheduler_metadata", &mha_fwd_get_scheduler_metadata, "Get scheduler metadata for varlen forward pass");
 }

-#endif
+#endif

hopper/flash_api_torch_lib.cpp

Lines changed: 7 additions & 3 deletions
@@ -52,7 +52,9 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
         int num_splits,
         std::optional<bool> pack_gqa_,
         int const sm_margin,
-        std::optional<const at::Tensor> &s_aux_
+        std::optional<const at::Tensor> &s_aux_,
+        int const cp_world_size,
+        int const cp_rank,
         );

 // Only applicable to the case where seqused_k (i.e. cache_seqlens) is available
@@ -120,7 +122,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
             " int num_splits,"
             " bool? pack_gqa,"
             " int sm_margin,"
-            " Tensor? s_aux) -> Tensor[]");
+            " Tensor? s_aux,"
+            " int cp_world_size,"
+            " int cp_rank) -> Tensor[]");
     ops.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));

     ops.def("get_scheduler_metadata("
@@ -151,4 +155,4 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
             make_pytorch_shim(&mha_fwd_get_scheduler_metadata));
 }

-REGISTER_EXTENSION(TORCH_EXTENSION_NAME);
+REGISTER_EXTENSION(TORCH_EXTENSION_NAME);

hopper/flash_attn_interface.py

Lines changed: 27 additions & 3 deletions
@@ -49,7 +49,9 @@ def _flash_attn_forward(
         num_splits=1,
         pack_gqa=None,
         sm_margin=0,
-        s_aux=None):
+        s_aux=None,
+        cp_world_size=1,
+        cp_rank=0):
     q, k, k_new, v_new = [maybe_contiguous(x) for x in (q, k, k_new, v_new)]
     v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
     cu_seqlens_q, cu_seqlens_k, cu_seqlens_k_new = [
@@ -95,7 +97,9 @@ def _flash_attn_forward(
         num_splits,
         pack_gqa,
         sm_margin,
-        s_aux
+        s_aux,
+        cp_world_size,
+        cp_rank
     )
     return out, softmax_lse, *rest

@@ -235,7 +239,7 @@ def backward(ctx, dout, *args):
             ctx.causal,
             ctx.window_size,
             ctx.softcap,
-            ctx.deterministic,
+            ctx.deterministic,
         )
         dqkv = dqkv[..., : dout.shape[-1]]  # We could have padded the head dimension
         return dqkv, None, None, None, None, None, None, None, None, None, None
@@ -260,6 +264,8 @@ def forward(
         deterministic=False,
         sm_margin=0,
         s_aux=None,
+        cp_world_size=1,
+        cp_rank=0,
     ):
         if softmax_scale is None:
             softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
@@ -285,6 +291,8 @@ def forward(
             pack_gqa=pack_gqa,
             sm_margin=sm_margin,
             s_aux=s_aux,
+            cp_world_size=cp_world_size,
+            cp_rank=cp_rank,
         )
         # ctx.save_for_backward(q, k, v, out_padded, softmax_lse)
         ctx.save_for_backward(q, k, v, out, softmax_lse)
@@ -351,6 +359,8 @@ def forward(
         deterministic=False,
         sm_margin=0,
         s_aux=None,
+        cp_world_size=1,
+        cp_rank=0,
     ):
         if softmax_scale is None:
             softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
@@ -380,6 +390,8 @@ def forward(
             pack_gqa=pack_gqa,
             sm_margin=sm_margin,
             s_aux=s_aux,
+            cp_world_size=cp_world_size,
+            cp_rank=cp_rank,
         )
         # ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
         ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
@@ -497,6 +509,8 @@ def flash_attn_func(
     deterministic=False,
     sm_margin=0,
     s_aux=None,
+    cp_world_size=1,
+    cp_rank=0,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -558,6 +572,8 @@ def flash_attn_func(
         deterministic,
         sm_margin,
         s_aux,
+        cp_world_size,
+        cp_rank,
     )


@@ -582,6 +598,8 @@ def flash_attn_varlen_func(
     deterministic=False,
     sm_margin=0,
     s_aux=None,
+    cp_world_size=1,
+    cp_rank=0,
 ):
     return FlashAttnVarlenFunc.apply(
         q,
@@ -604,6 +622,8 @@ def flash_attn_varlen_func(
         deterministic,
         sm_margin,
         s_aux,
+        cp_world_size,
+        cp_rank,
     )


@@ -642,6 +662,8 @@ def flash_attn_with_kvcache(
     sm_margin=0,  # Can be tuned if some SMs are used for communication
     return_softmax_lse=False,
     s_aux=None,
+    cp_world_size=1,
+    cp_rank=0,
 ):
     """
     If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
@@ -769,6 +791,8 @@ def flash_attn_with_kvcache(
         pack_gqa=pack_gqa,
         sm_margin=sm_margin,
         s_aux=s_aux,
+        cp_world_size=cp_world_size,
+        cp_rank=cp_rank,
     )
     # return (out, softmax_lse) if return_softmax_lse else out
     return (out, softmax_lse, *rest) if return_softmax_lse else out
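For callers, the visible change is the pair of new keyword arguments, both defaulting to the non-CP values (cp_world_size=1, cp_rank=0). A hypothetical usage sketch follows; it is not part of this diff, and the import path, shapes, and the assumption that k/v are already this rank's local shard are illustrative only:

    import torch
    from flash_attn_interface import flash_attn_func  # hopper interface; import path assumed

    # Assume 4 CP ranks; k and v below are this rank's local K/V shard.
    cp_world_size, cp_rank = 4, 1
    b, s_q, s_k_local, h, d = 2, 128, 512, 16, 128

    q = torch.randn(b, s_q, h, d, device="cuda", dtype=torch.bfloat16)
    k = torch.randn(b, s_k_local, h, d, device="cuda", dtype=torch.bfloat16)
    v = torch.randn(b, s_k_local, h, d, device="cuda", dtype=torch.bfloat16)

    result = flash_attn_func(
        q, k, v,
        causal=True,
        cp_world_size=cp_world_size,  # new in this commit, defaults to 1
        cp_rank=cp_rank,              # new in this commit, defaults to 0
    )
    # Depending on the interface version, `result` is the output tensor or (out, softmax_lse).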

hopper/flash_fwd_kernel_sm90.h

Lines changed: 3 additions & 2 deletions
@@ -347,7 +347,8 @@ class FlashAttnFwdSm90 {
             get<0>(params.mainloop.shape_K_new),
             params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
             params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
-            params.mainloop.seqlens_rotary
+            params.mainloop.seqlens_rotary,
+            params.mainloop.cp_world_size
         };
         if constexpr (AppendKV) {
             bool tile_new_valid = mainloop.load_kv_new(
@@ -396,7 +397,7 @@ class FlashAttnFwdSm90 {
             get<0>(params.mainloop.shape_K_new),
             params.mainloop.cu_seqlens_q, params.mainloop.cu_seqlens_k, params.mainloop.cu_seqlens_k_new,
             params.mainloop.seqused_q, params.mainloop.seqused_k, params.mainloop.leftpad_k,
-            params.mainloop.seqlens_rotary
+            params.mainloop.seqlens_rotary, params.mainloop.cp_world_size
         };
         if constexpr (AppendKV) {
             bool tile_new_valid = mainloop.store_kv_new(

hopper/flash_fwd_launch_template.h

Lines changed: 5 additions & 1 deletion
@@ -89,6 +89,7 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
         cute::conditional_return<!V_colmajor>(
             make_stride(params.v_row_stride, _1{}, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0),
             make_stride(_1{}, params.v_dim_stride, params.v_head_stride, !is_varlen_k ? params.v_batch_stride : 0));
+
     typename CollectiveMainloop::Arguments mainloop_args {
         static_cast<Element const*>(params.q_ptr),
         {seqlen_q, params.d, params.h, batch_q},  // shape_Q
@@ -129,7 +130,8 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
         params.cu_seqlens_q, params.cu_seqlens_k, params.cu_seqlens_knew,
         params.seqused_q, params.seqused_k,
         params.leftpad_k, params.seqlens_rotary,
-        static_cast<ElementS const*>(params.s_aux_ptr)
+        static_cast<ElementS const*>(params.s_aux_ptr),
+        params.cp_world_size, params.cp_rank,
     };
     typename CollectiveEpilogue::Arguments epilogue_args {
         static_cast<ElementOut*>(params.o_ptr),
@@ -156,6 +158,8 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
         params.tile_count_semaphore, params.cu_seqlens_q, params.seqused_q,
         // params.num_m_blocks_ptr,
         params.num_splits_dynamic_ptr,
+        params.cp_world_size,
+        params.cp_rank,
     };

     if (Varlen && params.num_splits_dynamic_ptr && !params.skip_scheduler_metadata_computation) {

hopper/mainloop_fwd_sm80.hpp

Lines changed: 2 additions & 0 deletions
@@ -215,6 +215,8 @@ struct CollectiveMainloopFwdSm80 {
         int const* const leftpad_k = nullptr;
         int const* const seqlens_rotary = nullptr;
         ElementSAux const* const ptr_S_aux = nullptr;
+        int cp_world_size;
+        int cp_rank;
     };

     // Device side kernel params

hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp

Lines changed: 16 additions & 6 deletions
@@ -412,6 +412,9 @@ struct CollectiveMainloopFwdSm90 {
         int const* const leftpad_k = nullptr;
         int const* const seqlens_rotary = nullptr;
         ElementSAux const* const ptr_S_aux = nullptr;
+        // Context parallelism (CP) parameters
+        int const cp_world_size = 1;
+        int const cp_rank = 0;
     };

     // Device side kernel params
@@ -469,6 +472,8 @@ struct CollectiveMainloopFwdSm90 {
         int const* const leftpad_k = nullptr;
         int const* const seqlens_rotary = nullptr;
         ElementSAux const* const ptr_S_aux = nullptr;
+        int cp_world_size = 1;
+        int cp_rank = 0;
     };

     static Params
@@ -540,7 +545,7 @@ struct CollectiveMainloopFwdSm90 {
                 return nullptr;
             }
         }();
-
+
         auto const shape_Qv_packed = cute::conditional_return<!PackGQA>(
             shape_Qv,
             make_shape(make_shape(qhead_per_khead, get<0>(shape_Qv)), get<1>(shape_Qv), get<2>(args.shape_K), get<3>(shape_Qv))
@@ -584,7 +589,8 @@ struct CollectiveMainloopFwdSm90 {
                 args.kv_batch_idx,
                 args.cu_seqlens_q, args.cu_seqlens_k, args.cu_seqlens_k_new,
                 args.seqused_q, args.seqused_k, args.leftpad_k, args.seqlens_rotary,
-                args.ptr_S_aux};
+                args.ptr_S_aux,
+                args.cp_world_size, args.cp_rank};
     }

     /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
@@ -999,6 +1005,7 @@ struct CollectiveMainloopFwdSm90 {
         static constexpr int kBlockN = get<1>(TileShape_MNK{});

         // can't use auto [m_block, ...] = block_coord since structured binding cannot be captured in lambda
+        // block index
         int const m_block = get<0>(block_coord);
         int const bidh = get<1>(block_coord);
         int const bidb = get<2>(block_coord);
@@ -1093,7 +1100,8 @@ struct CollectiveMainloopFwdSm90 {
         // But we subtract n_offset for consistency in mask calculations
         flash::Mask<kBlockM, kBlockN, PackGQA, TiledMmaQK> mask(
             thread_idx, seqlen_q, seqlen_k, params.window_size_left, params.window_size_right, 0 - n_offset /*sink_token_length*/,
-            params.qhead_per_khead_divmod
+            params.qhead_per_khead_divmod,
+            params.cp_world_size, params.cp_rank
         );

         float softcap_val = params.softcap_val;
@@ -1201,6 +1209,7 @@ struct CollectiveMainloopFwdSm90 {
         }

         if constexpr (IntraWGOverlap) {
+
             Tensor tSrS = partition_fragment_C(tiled_mma_qk, select<0, 1>(TileShape_MNK{}));
             consumer_wait(pipeline_k, smem_pipe_read);
             flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma_qk, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
@@ -1272,7 +1281,8 @@ struct CollectiveMainloopFwdSm90 {
         };

         if constexpr (Is_causal || Is_local) { // Separate iterations with causal or local masking
-            auto mask_fn = [&](auto& tSrS, int n_block) { mask.template apply<false /*Seqlenk_mask*/, Is_causal, Is_local>(tSrS, m_block, n_block); };
+            auto mask_fn = [&](auto& tSrS, int n_block) {
+                mask.template apply<false /*Seqlenk_mask*/, Is_causal, Is_local>(tSrS, m_block, n_block); };
             int const m_idx_min = !PackGQA ? m_block * kBlockM : params.qhead_per_khead_divmod.divide(m_block * kBlockM);
             // If local, blocking (window_size_right + window_size_left)
             int const n_block_min_causal_local_mask =
@@ -1288,7 +1298,7 @@ struct CollectiveMainloopFwdSm90 {
             int const n_block_min_before_local_mask = !Is_local
                 ? n_block_min
                 : std::max(n_block_min,
-                           cute::ceil_div(m_idx_max + seqlen_k - seqlen_q - params.window_size_left, kBlockN));
+                           cute::ceil_div(m_idx_max + params.cp_world_size * seqlen_k - seqlen_q - params.window_size_left, kBlockN));
             auto no_mask_fn = [](auto& tSrS, int n_block) { };
             #pragma unroll 1
             for (; n_block >= n_block_min_before_local_mask; --n_block) {
@@ -1414,7 +1424,7 @@ struct CollectiveMainloopFwdSm90 {
             // Tensor scores_scale = softmax.finalize(v_descale);
             Tensor scores_scale = make_tensor_like(softmax.row_max);
             finalize_dispatch(scores_scale, v_descale);
-
+
             if constexpr (LargeHeadDimV) {
                 cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<uint32_t>(FwdNamedBarriers::PEmpty) /*id*/);
                 store_scales(scores_scale, smem_pipe_read.index());
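The mainloop change mirrors block.h: flash::Mask now also receives params.cp_world_size and params.cp_rank, and the local-attention lower bound scales the per-rank seqlen_k by cp_world_size. A minimal standalone sketch of that bound, under the same per-rank-shard assumption as above (Python, illustrative only):

    def ceil_div(a, b):
        return -(-a // b)

    def n_block_min_before_local_mask(n_block_min, m_idx_max, seqlen_q, seqlen_k_local,
                                      window_size_left, kBlockN, is_local, cp_world_size=1):
        if not is_local:
            return n_block_min
        # Same global-key-length correction as in block.h, applied to the left window edge.
        return max(n_block_min,
                   ceil_div(m_idx_max + cp_world_size * seqlen_k_local - seqlen_q - window_size_left,
                            kBlockN))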
