We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 02d4b85 commit 2975c3fCopy full SHA for 2975c3f
csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp
@@ -133,6 +133,14 @@ class MLA {
133
// printf(" sm_count = %d\n", sm_count);
134
int max_splits = ceil_div(K, 128);
135
max_splits = min(16, max_splits);
136
+
137
+ // TODO: This avoids a hang when the batch size larger than 1 and
138
+ // there is more than 4 kv_splits.
139
+ // Discuss with NVIDIA how this can be fixed.
140
+ if (B > 1) {
141
+ max_splits = min(2, max_splits);
142
+ }
143
144
// printf(" max_splits = %d\n", max_splits);
145
int sms_per_batch = max(1, sm_count / B);
146
// printf(" sms_per_batch = %d\n", sms_per_batch);
0 commit comments