We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 94b03f8 commit 37f1a09Copy full SHA for 37f1a09
csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp
@@ -133,6 +133,13 @@ class MLA {
133
// printf(" sm_count = %d\n", sm_count);
134
int max_splits = ceil_div(K, 128);
135
max_splits = min(16, max_splits);
136
+
137
+ // TODO: This avoids a hang when the batch size is large and there is more
138
+ // than 4 kv_splits. Discuss with NVIDIA how this can be fixed.
139
+ if (B >= 8) {
140
+ max_splits = min(2, max_splits);
141
+ }
142
143
// printf(" max_splits = %d\n", max_splits);
144
int sms_per_batch = max(1, sm_count / B);
145
// printf(" sms_per_batch = %d\n", sms_per_batch);
0 commit comments