@@ -41,15 +41,24 @@ def _lora_shrink_kernel(
     BLOCK_K: tl.constexpr,
     EVEN_K: tl.constexpr,
     SPLIT_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
     SLICE_NUM: tl.constexpr,
 ):
     cta_n_num = tl.cdiv(N, BLOCK_N)
     cta_m_num = tl.cdiv(M, BLOCK_M)
 
     pid_sk_m_n = tl.program_id(axis=0)
     pid_sk = pid_sk_m_n % SPLIT_K
-    pid_m = (pid_sk_m_n // SPLIT_K) % cta_m_num
-    pid_n = pid_sk_m_n // (SPLIT_K * cta_m_num) % cta_n_num
+
+    pid_m_n = pid_sk_m_n // SPLIT_K
+    num_pid_in_group = GROUP_SIZE_M * cta_n_num
+    group_id = pid_m_n // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(cta_m_num - first_pid_m, GROUP_SIZE_M)
+
+    # Column-major ordering within groups for better cache reuse
+    pid_m = first_pid_m + ((pid_m_n % num_pid_in_group) % group_size_m)
+    pid_n = (pid_m_n % num_pid_in_group) // group_size_m
 
     slice_id = tl.program_id(axis=1)
     lora_idx = tl.program_id(axis=2)
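
For reference, the new index math is the standard grouped (swizzled) launch ordering: consecutive program ids sweep up to GROUP_SIZE_M rows of output tiles column by column before moving to the next group of rows. A minimal pure-Python sketch of the same arithmetic, with SPLIT_K folded out; the helper name grouped_pid and the example sizes are made up for illustration:

# A minimal sketch (plain Python) of the grouped ordering in the hunk above.
# SPLIT_K is folded out; the helper name is hypothetical.
def grouped_pid(pid_m_n, cta_m_num, cta_n_num, GROUP_SIZE_M):
    num_pid_in_group = GROUP_SIZE_M * cta_n_num
    group_id = pid_m_n // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    # The last group may be narrower than GROUP_SIZE_M.
    group_size_m = min(cta_m_num - first_pid_m, GROUP_SIZE_M)
    # Column-major walk inside the group: consecutive ids share the same
    # N-tile and cycle through a small window of M-tiles, improving L2 reuse.
    pid_m = first_pid_m + ((pid_m_n % num_pid_in_group) % group_size_m)
    pid_n = (pid_m_n % num_pid_in_group) // group_size_m
    return pid_m, pid_n

# Example: cta_m_num=4, cta_n_num=4, GROUP_SIZE_M=2 maps ids 0..7 to
# (0,0), (1,0), (0,1), (1,1), ..., (1,3) -- rows {0,1} across every column
# before ids 8..15 move on to rows {2,3}.
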
@@ -194,6 +203,7 @@ def _lora_shrink(
     NUM_WARPS = kernel_config["num_warps"]
     NUM_STAGES = kernel_config["num_stages"]
     NUM_CTAS = kernel_config["num_ctas"]
+    GROUP_SIZE_M = kernel_config.get("group_size_m", 8)
     EVEN_K = K % (BLOCK_K * SPLIT_K) == 0  # type: ignore
 
     # TODO (varun): This grid formulation maximizes parallelization at the
@@ -233,6 +243,7 @@ def _lora_shrink(
         BLOCK_K,
         EVEN_K,
         SPLIT_K,
+        GROUP_SIZE_M,
         NUM_SLICES,
         num_warps=NUM_WARPS,
         num_ctas=NUM_CTAS,
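
On the launcher side the only new knob is an optional "group_size_m" entry in kernel_config, read with a default of 8 so previously tuned configs without it keep working. A sketch of how it might appear, limited to the keys visible in this diff and with placeholder values rather than tuned settings:

# Illustrative only: values are placeholders; only keys seen in this diff.
kernel_config = {
    "num_warps": 4,
    "num_stages": 2,
    "num_ctas": 1,
    "group_size_m": 8,  # optional; kernel_config.get("group_size_m", 8)
}
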