@@ -272,10 +272,11 @@ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
272272 if (params.is_bf16 ) {
273273 #ifndef FLASHATTENTION_DISABLE_HDIM64
274274 if (params.d <= 64 ) {
275- if (params.dv > 64 && Arch == 90 ) {
275+ if (params.dv > 256 && Arch == 90 ) {
276276 return run_mha_fwd_<Arch, cutlass::bfloat16_t , 64 , 512 , Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
277- }
278- else {
277+ } else if (params.dv > 64 && Arch == 90 ) {
278+ return run_mha_fwd_<Arch, cutlass::bfloat16_t , 64 , 256 , Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
279+ } else {
279280 return run_mha_fwd_<Arch, cutlass::bfloat16_t , 64 , 64 , Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
280281 }
281282 }
@@ -302,10 +303,11 @@ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream) {
302303 #ifndef FLASHATTENTION_DISABLE_FP16
303304 #ifndef FLASHATTENTION_DISABLE_HDIM64
304305 if (params.d <= 64 ) {
305- if (params.dv > 64 && Arch == 90 ) {
306+ if (params.dv > 256 && Arch == 90 ) {
306307 return run_mha_fwd_<Arch, cutlass::half_t , 64 , 512 , Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
307- }
308- else {
308+ } else if (params.dv > 64 && Arch == 90 ) {
309+ return run_mha_fwd_<Arch, cutlass::half_t , 64 , 256 , Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
310+ } else {
309311 return run_mha_fwd_<Arch, cutlass::half_t , 64 , 64 , Split, PagedKVNonTMA, Has_softcap, PackGQA>(params, stream);
310312 }
311313 }
@@ -490,6 +492,15 @@ inline int round_up_headdim(int head_size) {
490492 return 256 ;
491493}
492494
// Round the V head dimension up to the nearest kernel-supported size.
// Supported buckets are 64, 96, 128, 192, 256; anything larger falls
// through to the 512 specialization.
inline int round_up_headdimv(int head_size) {
    for (int bucket : {64, 96, 128, 192, 256}) {
        if (head_size <= bucket) { return bucket; }
    }
    return 512;
}
503+
493504// Only applicable to the case where seqused_k (i.e. cache_seqlens) is available
494505at::Tensor
495506mha_fwd_get_scheduler_metadata (
@@ -534,7 +545,7 @@ mha_fwd_get_scheduler_metadata(
534545 params.d = headdim;
535546 params.dv = headdim_v;
536547 params.d_rounded = round_up_headdim (headdim);
537- params.dv_rounded = round_up_headdim (headdim_v);
548+ params.dv_rounded = headdim_v == headdim ? params. d_rounded : round_up_headdimv (headdim_v);
538549 params.seqlen_knew = max_seqlen_k_new;
539550
540551 bool const is_varlen_q = cu_seqlens_q_.has_value ();
@@ -640,6 +651,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
640651 std::optional<const at::Tensor> &leftpad_k_, // b
641652 std::optional<const at::Tensor> &rotary_cos_, // seqlen_ro x (rotary_dim / 2)
642653 std::optional<const at::Tensor> &rotary_sin_, // seqlen_ro x (rotary_dim / 2)
654+ std::optional<const at::Tensor> &seqlens_rotary_, // b
643655 std::optional<at::Tensor> &q_descale_, // (b, h_k), not (b, h)
644656 std::optional<at::Tensor> &k_descale_, // (b, h_k)
645657 std::optional<at::Tensor> &v_descale_, // (b, h_k)
@@ -823,7 +835,7 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
823835
824836 auto round_multiple = [](int x, int m) { return (x + m - 1 ) / m * m; };
825837 int const head_size_rounded = round_up_headdim (head_size);
826- int const head_size_v_rounded = round_up_headdim (head_size_v);
838+ int const head_size_v_rounded = head_size_v == head_size ? head_size_rounded : round_up_headdimv (head_size_v);
827839 int const seqlen_q_rounded = round_multiple (seqlen_q, 128 );
828840 int const seqlen_k_rounded = round_multiple (seqlen_k, 128 );
829841
@@ -1001,6 +1013,13 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
10011013 params.rotary_cos_ptr = rotary_cos.data_ptr ();
10021014 params.rotary_sin_ptr = rotary_sin.data_ptr ();
10031015 params.is_rotary_interleaved = is_rotary_interleaved;
1016+ if (seqlens_rotary_.has_value ()) {
1017+ at::Tensor seqlens_rotary = seqlens_rotary_.value ();
1018+ CHECK_DEVICE (seqlens_rotary); CHECK_CONTIGUOUS (seqlens_rotary);
1019+ TORCH_CHECK (seqlens_rotary.dtype () == torch::kInt32 , " seqlens_rotary must have dtype torch.int32" );
1020+ CHECK_SHAPE (seqlens_rotary, batch_size);
1021+ params.seqlens_rotary = seqlens_rotary.data_ptr <int >();
1022+ }
10041023 } else {
10051024 params.rotary_dim = 0 ;
10061025 }
@@ -1104,7 +1123,11 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq
11041123 // params.b = 1;
11051124 // params.seqlen_q = total_q;
11061125 // }
1126+ // This will zero out the semaphore if needed
11071127 run_mha_fwd_combine (params, stream, true /* enable_pdl*/ );
1128+ } else if (scheduler_needs_semaphore && params.skip_scheduler_metadata_computation ) {
1129+ // need to zero out the semaphore in this case
1130+ tile_count_semaphore.index ({torch::indexing::Slice (0 , 1 )}).zero_ ();
11081131 }
11091132 } else if (total_q > 0 && num_heads_k > 0 ) {
11101133 // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0.
@@ -1492,7 +1515,6 @@ mha_combine(const at::Tensor &out_partial, // num_splits x batch_size x
14921515 const int seqlen = sizes[2 ];
14931516 const int num_heads = sizes[3 ];
14941517 const int head_size_og = sizes[4 ];
1495- TORCH_CHECK (head_size_og <= 512 , " FlashAttention combine only supports head dimension at most 512" );
14961518 TORCH_CHECK (num_splits <= 256 , " FlashAttention combine only supports num_splits at most 256" );
14971519
14981520 CHECK_SHAPE (out_partial, num_splits, batch_size, seqlen, num_heads, head_size_og);
0 commit comments