Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,32 +125,33 @@ class MLA {
}

/// Picks a default KV split count and stores it in `args.split_kv`.
///
/// A value already supplied by the caller (`args.split_kv >= 1`) is left
/// untouched. Otherwise the split count is derived in two steps:
///   1. an upper bound (`max_splits`) chosen from the ratio between the
///      sequence length (in units of 1024 tokens) and the batch size;
///   2. a wave-aware adjustment so that the splits divide evenly over the
///      SMs available to each batch entry.
///
/// @param args Kernel arguments. Reads `problem_shape` (unpacked as
///             H, K, D, B) and `hw_info.sm_count`; writes `split_kv`.
static void set_split_kv (KernelArguments& args) {
  // Respect an explicit choice made by the caller.
  if (args.split_kv >= 1) return;

  // Only K (sequence length in tokens) and B (batch size) are used below;
  // H and D are named because the structured binding must cover every element.
  auto [H, K, D, B] = args.problem_shape;
  int sm_count = args.hw_info.sm_count;

  // NOTE(review): an earlier revision clamped max_splits to 1 whenever B > 1
  // to avoid a hang with more than one kv split. That clamp made the
  // heuristic below unreachable for batched inputs and is dropped here --
  // confirm the underlying hang is resolved before relying on this path.

  // Ratio-based heuristic for max_splits: seq_length_k / batch_size.
  // K is the sequence length in tokens; convert to "k" units (1024 tokens).
  float seq_length_k = static_cast<float>(K) / 1024.0f;
  float ratio = seq_length_k / static_cast<float>(B);

  // Longer sequences relative to the batch size get more splits.
  int max_splits;
  if (ratio >= 2.5f) {
    max_splits = 8;
  } else if (ratio >= 1.2f) {
    max_splits = 4;
  } else if (ratio >= 0.5f) {
    max_splits = 2;
  } else {
    max_splits = 1;
  }

  // Wave-aware scheduling: ensure an integer number of waves in the K dimension.
  // Each batch entry gets at least one SM, so split_heur is never zero.
  int sms_per_batch = max(1, sm_count / B);
  int split_heur = min(max_splits, sms_per_batch);

  // Number of passes needed to cover max_splits with split_heur splits per
  // pass, then the per-pass split count that spreads the work evenly.
  int k_waves = ceil_div(max_splits, split_heur);
  args.split_kv = ceil_div(max_splits, k_waves);
}

/// Determines whether the GEMM can execute the given problem.
Expand Down