@@ -125,32 +125,37 @@ class MLA {
125125 }
126126
/// Chooses `args.split_kv` when the caller has not already set it.
///
/// If `args.split_kv >= 1` the function returns without changing anything.
/// Otherwise the post-change code picks a cap (`max_splits`) from the batch
/// size `B` and the K sequence length `K` (both unpacked from
/// `args.problem_shape`), limits it by `args.hw_info.sm_count / B`, evens the
/// result out over the required number of passes, and writes it to
/// `args.split_kv`.
///
/// @param args Kernel arguments. Reads `split_kv`, `problem_shape` and
///             `hw_info.sm_count`; writes `split_kv`.
127127 static void set_split_kv (KernelArguments& args) {
128- // printf("set_split_kv start");
// A value >= 1 is treated as already chosen and is left as-is.
129128 if (args.split_kv >= 1 ) return ;
// Only K and B are read below; H and D are unpacked but unused.
130129 auto [H, K, D, B] = args.problem_shape ;
131- // std::cout << H << " " << K << " " << D << " " << B << "\n";
132130 int sm_count = args.hw_info .sm_count ;
133- // printf(" sm_count = %d\n", sm_count);
134- int max_splits = ceil_div (K, 128 );
135- max_splits = min (16 , max_splits);
136-
137- // TODO: This avoids a hang when the batch size larger than 1 and
138- // there is more than 1 kv_splits.
139- // Discuss with NVIDIA how this can be fixed.
140- if (B > 1 ) {
141- max_splits = min (1 , max_splits);
// K expressed in units of 1024, so the thresholds below read as
// "K >= 16 * 1024", "K >= 4 * 1024", and so on.
131+ float seq_length_k = static_cast <float >(K) / 1024 .0f ;
// Default cap when none of the (B, K) buckets below matches.
132+ int max_splits = 1 ;
133+
// Bucket table; the first matching branch wins, so later branches only see
// the inputs that earlier ones rejected:
//   B <= 4            and K >= 16 * 1024 -> 16
//   B <= 8            and K >=  4 * 1024 ->  8
//   B <= 16 (i.e. 9..16 here)  and K >=  8 * 1024 -> 4
//   B == 48           and K >= 32 * 1024 ->  4
//   B <= 32 (i.e. 17..32 here) and K >= 16 * 1024 -> 2
//   B == 96           and K >= 16 * 1024 ->  2
// NOTE(review): the constants look empirically tuned (B == 48 and B == 96 are
// exact-match special cases) — confirm their origin with the author.
134+ if (B <= 4 && seq_length_k >= 16 ) {
135+ max_splits = 16 ;
136+ }
137+ else if (B <= 8 && seq_length_k >= 4 ) {
138+ max_splits = 8 ;
139+ }
140+ else if ((B <= 16 && seq_length_k >= 8 ) ||
141+ (B == 48 && seq_length_k >= 32 )) {
142+ max_splits = 4 ;
143+ }
144+ else if ((B <= 32 && seq_length_k >= 16 ) ||
145+ (B == 96 && seq_length_k >= 16 )) {
146+ max_splits = 2 ;
142147 }
143-
144- // printf(" max_splits = %d\n", max_splits);
// NOTE(review): this branch repeats the initial value of max_splits (1), so
// it has no effect beyond making the fallback explicit.
148+ else {
149+ max_splits = 1 ;
150+ }
151+
152+ // Wave-aware scheduling: ensure integer number of waves in K dimension
// SM budget per batch entry, never less than 1.
// NOTE(review): `sm_count / B` is not guarded against B == 0 — confirm that
// callers always pass a batch size of at least 1.
145153 int sms_per_batch = max (1 , sm_count / B);
146- // printf(" sms_per_batch = %d\n", sms_per_batch);
// Do not ask for more splits than the per-batch SM budget.
147154 int split_heur = min (max_splits, sms_per_batch);
// NOTE(review): `waves` is computed here but never read afterwards in this
// function.
148155 int waves = ceil_div (B * split_heur, sm_count);
// Assuming `ceil_div` is a round-up integer division (TODO confirm):
// k_waves is the number of passes needed to cover max_splits when taking
// split_heur at a time, and the final value spreads max_splits evenly over
// those passes. Example: max_splits = 16, split_heur = 6 -> k_waves = 3 ->
// split_kv = 6; split_heur = 5 -> k_waves = 4 -> split_kv = 4.
149156 int k_waves = ceil_div (max_splits, split_heur);
150157 int split_wave_aware = ceil_div (max_splits, k_waves);
151158 args.split_kv = split_wave_aware;
152- // printf(" args.split_kv = %d\n", args.split_kv);
153-
154159 }
155160
156161 /// Determines whether the GEMM can execute the given problem.
0 commit comments