Skip to content

Commit 37f1a09

Browse files
committed
[Bugfix] Fix hanging issue with cutlass_mla by limiting the kv_split for larger batch size
Signed-off-by: Alexander Matveev <amatveev@redhat.com>
1 parent 94b03f8 commit 37f1a09

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,13 @@ class MLA {
133133
// printf(" sm_count = %d\n", sm_count);
134134
int max_splits = ceil_div(K, 128);
135135
max_splits = min(16, max_splits);
136+
137+
// TODO: This avoids a hang when the batch size is large and there is more
138+
// than 4 kv_splits. Discuss with NVIDIA how this can be fixed.
139+
if (B >= 8) {
140+
max_splits = min(2, max_splits);
141+
}
142+
136143
// printf(" max_splits = %d\n", max_splits);
137144
int sms_per_batch = max(1, sm_count / B);
138145
// printf(" sms_per_batch = %d\n", sms_per_batch);

0 commit comments

Comments
 (0)