Skip to content

Commit 9aa4055

Browse files
MatthewBonannixuebwang-amd
authored andcommitted
[Attention] Tune CUTLASS MLA num_splits (vllm-project#26846)
Signed-off-by: Matthew Bonanni <mbonanni@redhat.com> Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent ead97f4 commit 9aa4055

File tree

1 file changed

+21
-16
lines changed

1 file changed

+21
-16
lines changed

csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -125,32 +125,37 @@ class MLA {
125125
}
126126

127127
static void set_split_kv (KernelArguments& args) {
128-
// printf("set_split_kv start");
129128
if (args.split_kv >= 1) return;
130129
auto [H, K, D, B] = args.problem_shape;
131-
// std::cout << H << " " << K << " " << D << " " << B << "\n";
132130
int sm_count = args.hw_info.sm_count;
133-
// printf(" sm_count = %d\n", sm_count);
134-
int max_splits = ceil_div(K, 128);
135-
max_splits = min(16, max_splits);
136-
137-
// TODO: This avoids a hang when the batch size larger than 1 and
138-
// there is more than 1 kv_splits.
139-
// Discuss with NVIDIA how this can be fixed.
140-
if (B > 1) {
141-
max_splits = min(1, max_splits);
131+
float seq_length_k = static_cast<float>(K) / 1024.0f;
132+
int max_splits = 1;
133+
134+
if (B <= 4 && seq_length_k >= 16) {
135+
max_splits = 16;
136+
}
137+
else if (B <= 8 && seq_length_k >= 4) {
138+
max_splits = 8;
139+
}
140+
else if ((B <= 16 && seq_length_k >= 8) ||
141+
(B == 48 && seq_length_k >= 32)) {
142+
max_splits = 4;
143+
}
144+
else if ((B <= 32 && seq_length_k >= 16) ||
145+
(B == 96 && seq_length_k >= 16)) {
146+
max_splits = 2;
142147
}
143-
144-
// printf(" max_splits = %d\n", max_splits);
148+
else {
149+
max_splits = 1;
150+
}
151+
152+
// Wave-aware scheduling: ensure integer number of waves in K dimension
145153
int sms_per_batch = max(1, sm_count / B);
146-
// printf(" sms_per_batch = %d\n", sms_per_batch);
147154
int split_heur = min(max_splits, sms_per_batch);
148155
int waves = ceil_div(B * split_heur, sm_count);
149156
int k_waves = ceil_div(max_splits, split_heur);
150157
int split_wave_aware = ceil_div(max_splits, k_waves);
151158
args.split_kv = split_wave_aware;
152-
// printf(" args.split_kv = %d\n", args.split_kv);
153-
154159
}
155160

156161
/// Determines whether the GEMM can execute the given problem.

0 commit comments

Comments
 (0)