Commit 2975c3f

[Bugfix] Fix hanging issue with cutlass_mla by limiting the kv_split for larger batch size
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
1 parent 02d4b85 commit 2975c3f

1 file changed (+8, -0 lines changed)


csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp

Lines changed: 8 additions & 0 deletions
@@ -133,6 +133,14 @@ class MLA {
     // printf(" sm_count = %d\n", sm_count);
     int max_splits = ceil_div(K, 128);
     max_splits = min(16, max_splits);
+
+    // TODO: This avoids a hang when the batch size is larger than 1 and
+    // there are more than 4 kv_splits.
+    // Discuss with NVIDIA how this can be fixed.
+    if (B > 1) {
+      max_splits = min(2, max_splits);
+    }
+
     // printf(" max_splits = %d\n", max_splits);
     int sms_per_batch = max(1, sm_count / B);
     // printf(" sms_per_batch = %d\n", sms_per_batch);
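
For reviewers who want to check the effect of the new cap without building the kernel, the arithmetic in this hunk can be reproduced in isolation. The following is a minimal sketch, not the code in sm100_mla.hpp: the function names `capped_max_splits` and `sms_per_batch` are hypothetical, `ceil_div` is reimplemented locally, and `K` is assumed to be the KV sequence length and `B` the batch size.

```cpp
#include <algorithm>

// Local stand-in for the ceil_div helper used in the hunk (assumed semantics:
// integer division rounded up, for positive operands).
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Hypothetical helper: upper bound on KV splits, mirroring the lines in the diff.
// K is assumed to be the KV sequence length, B the batch size.
constexpr int capped_max_splits(int K, int B) {
    int max_splits = std::min(16, ceil_div(K, 128));
    if (B > 1) {
        // Workaround added by this commit: at most 2 splits when B > 1.
        max_splits = std::min(2, max_splits);
    }
    return max_splits;
}

// Hypothetical helper: SMs available per batch entry, never less than 1.
constexpr int sms_per_batch(int sm_count, int B) {
    return std::max(1, sm_count / B);
}

// Compile-time checks of the arithmetic.
static_assert(capped_max_splits(4096, 1) == 16, "B == 1 keeps the original cap of 16");
static_assert(capped_max_splits(4096, 8) == 2, "B > 1 is limited to 2 splits");
static_assert(capped_max_splits(100, 4) == 1, "short sequences still get a single split");
static_assert(sms_per_batch(4, 8) == 1, "at least one SM per batch entry");
```

With the cap in place, any batch larger than 1 gets at most 2 KV splits regardless of sequence length, while single-batch calls keep the previous limit of 16. How this bound is consumed further down in the function is not shown in the hunk and is not modeled here.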
