From 2975c3fca9b55c857c95b01c3947ee352803ca40 Mon Sep 17 00:00:00 2001 From: Alexander Matveev Date: Tue, 16 Sep 2025 06:40:28 -0700 Subject: [PATCH] [Bugfix] Fix hanging issue with cutlass_mla by limiting the kv_split for larger batch size Signed-off-by: Alexander Matveev --- csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp b/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp index 95e32559cd54..fbbc2e588c32 100644 --- a/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp +++ b/csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp @@ -133,6 +133,14 @@ class MLA { // printf(" sm_count = %d\n", sm_count); int max_splits = ceil_div(K, 128); max_splits = min(16, max_splits); + + // TODO: This avoids a hang when the batch size larger than 1 and + // there is more than 4 kv_splits. + // Discuss with NVIDIA how this can be fixed. + if (B > 1) { + max_splits = min(2, max_splits); + } + // printf(" max_splits = %d\n", max_splits); int sms_per_batch = max(1, sm_count / B); // printf(" sms_per_batch = %d\n", sms_per_batch);