diff --git a/vllm/lora/ops/triton_ops/lora_shrink_op.py b/vllm/lora/ops/triton_ops/lora_shrink_op.py
index 8d126197f83e..adc5c9dce5e8 100644
--- a/vllm/lora/ops/triton_ops/lora_shrink_op.py
+++ b/vllm/lora/ops/triton_ops/lora_shrink_op.py
@@ -41,6 +41,7 @@ def _lora_shrink_kernel(
     BLOCK_K: tl.constexpr,
     EVEN_K: tl.constexpr,
     SPLIT_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
     SLICE_NUM: tl.constexpr,
 ):
     cta_n_num = tl.cdiv(N, BLOCK_N)
@@ -48,8 +49,16 @@ def _lora_shrink_kernel(
     pid_sk_m_n = tl.program_id(axis=0)
     pid_sk = pid_sk_m_n % SPLIT_K
-    pid_m = (pid_sk_m_n // SPLIT_K) % cta_m_num
-    pid_n = pid_sk_m_n // (SPLIT_K * cta_m_num) % cta_n_num
+
+    pid_m_n = pid_sk_m_n // SPLIT_K
+    num_pid_in_group = GROUP_SIZE_M * cta_n_num
+    group_id = pid_m_n // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(cta_m_num - first_pid_m, GROUP_SIZE_M)
+
+    # Column-major ordering within groups for better cache reuse
+    pid_m = first_pid_m + ((pid_m_n % num_pid_in_group) % group_size_m)
+    pid_n = (pid_m_n % num_pid_in_group) // group_size_m
 
     slice_id = tl.program_id(axis=1)
     lora_idx = tl.program_id(axis=2)
 
@@ -194,6 +203,7 @@ def _lora_shrink(
     NUM_WARPS = kernel_config["num_warps"]
     NUM_STAGES = kernel_config["num_stages"]
     NUM_CTAS = kernel_config["num_ctas"]
+    GROUP_SIZE_M = kernel_config.get("group_size_m", 8)
     EVEN_K = K % (BLOCK_K * SPLIT_K) == 0  # type: ignore
 
     # TODO (varun): This grid formulation maximizes parallelization at the
@@ -233,6 +243,7 @@ def _lora_shrink(
         BLOCK_K,
         EVEN_K,
         SPLIT_K,
+        GROUP_SIZE_M,
         NUM_SLICES,
         num_warps=NUM_WARPS,
         num_ctas=NUM_CTAS,
diff --git a/vllm/lora/ops/triton_ops/utils.py b/vllm/lora/ops/triton_ops/utils.py
index 9ffb6dc3d85e..368c5037d2e4 100644
--- a/vllm/lora/ops/triton_ops/utils.py
+++ b/vllm/lora/ops/triton_ops/utils.py
@@ -199,6 +199,7 @@ def get_lora_op_configs(
         "split_k": 64 if batch < 128 else 8,
         "num_warps": 4,
         "num_ctas": 1,
+        "group_size_m": 8,
         "num_stages": 2,
         "max_nreg": None,
     }
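
For context: the replacement indexing above is the grouped ("swizzled") program-ID ordering familiar from the Triton matmul tutorial. Instead of sweeping an entire row of N-tiles before moving to the next M-tile, consecutive program IDs walk GROUP_SIZE_M rows of tiles column by column, so CTAs that share input rows or columns run close together in time and get better L2 reuse. Below is a minimal pure-Python sketch of the mapping, not part of the patch; the helper name swizzle and the sizes (cta_m_num = 4, cta_n_num = 4, group size 2) are illustrative assumptions.

# Standalone sketch of the grouped pid mapping used in the kernel above.
# The function name and the example sizes are hypothetical.
def swizzle(pid_m_n, cta_m_num, cta_n_num, group_size_m_cfg):
    num_pid_in_group = group_size_m_cfg * cta_n_num
    group_id = pid_m_n // num_pid_in_group
    first_pid_m = group_id * group_size_m_cfg
    # The last group may be shorter when cta_m_num % GROUP_SIZE_M != 0.
    group_size_m = min(cta_m_num - first_pid_m, group_size_m_cfg)
    # Column-major order within the group, mirroring the kernel.
    pid_m = first_pid_m + ((pid_m_n % num_pid_in_group) % group_size_m)
    pid_n = (pid_m_n % num_pid_in_group) // group_size_m
    return pid_m, pid_n

if __name__ == "__main__":
    # Visit order for a 4x4 tile grid with GROUP_SIZE_M = 2.
    print([swizzle(i, 4, 4, 2) for i in range(16)])

Running this prints (0,0), (1,0), (0,1), (1,1), ... for the first two tile rows, then the same column-by-column walk over rows 2 and 3, matching the "column-major ordering within groups" comment in the kernel. The SPLIT_K factor is divided out of the program ID before the swizzle, so each split-K slice sees the same grouped (pid_m, pid_n) traversal.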