Skip to content

Commit cc65153

Browse files
alexm-redhatmgoin
authored andcommitted
[Bugfix][B200] Fix cutlass_mla hang (vllm-project#24966)
Signed-off-by: Alexander Matveev <amatveev@redhat.com> Co-authored-by: Michael Goin <mgoin64@gmail.com>
1 parent 44658c7 commit cc65153

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,14 @@ class MLA {
133133
// printf(" sm_count = %d\n", sm_count);
134134
int max_splits = ceil_div(K, 128);
135135
max_splits = min(16, max_splits);
136+
137+
// TODO: This avoids a hang when the batch size larger than 1 and
138+
// there is more than 4 kv_splits.
139+
// Discuss with NVIDIA how this can be fixed.
140+
if (B > 1) {
141+
max_splits = min(2, max_splits);
142+
}
143+
136144
// printf(" max_splits = %d\n", max_splits);
137145
int sms_per_batch = max(1, sm_count / B);
138146
// printf(" sms_per_batch = %d\n", sms_per_batch);

0 commit comments

Comments
 (0)