Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions vllm/lora/ops/triton_ops/lora_shrink_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,24 @@ def _lora_shrink_kernel(
BLOCK_K: tl.constexpr,
EVEN_K: tl.constexpr,
SPLIT_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
SLICE_NUM: tl.constexpr,
):
cta_n_num = tl.cdiv(N, BLOCK_N)
cta_m_num = tl.cdiv(M, BLOCK_M)

pid_sk_m_n = tl.program_id(axis=0)
pid_sk = pid_sk_m_n % SPLIT_K
pid_m = (pid_sk_m_n // SPLIT_K) % cta_m_num
pid_n = pid_sk_m_n // (SPLIT_K * cta_m_num) % cta_n_num

pid_m_n = pid_sk_m_n // SPLIT_K
num_pid_in_group = GROUP_SIZE_M * cta_n_num
group_id = pid_m_n // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(cta_m_num - first_pid_m, GROUP_SIZE_M)

# Column-major ordering within groups for better cache reuse
pid_m = first_pid_m + ((pid_m_n % num_pid_in_group) % group_size_m)
pid_n = (pid_m_n % num_pid_in_group) // group_size_m

slice_id = tl.program_id(axis=1)
lora_idx = tl.program_id(axis=2)
Expand Down Expand Up @@ -194,6 +203,7 @@ def _lora_shrink(
NUM_WARPS = kernel_config["num_warps"]
NUM_STAGES = kernel_config["num_stages"]
NUM_CTAS = kernel_config["num_ctas"]
GROUP_SIZE_M = kernel_config.get("group_size_m", 8)
EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 # type: ignore

# TODO (varun): This grid formulation maximizes parallelization at the
Expand Down Expand Up @@ -233,6 +243,7 @@ def _lora_shrink(
BLOCK_K,
EVEN_K,
SPLIT_K,
GROUP_SIZE_M,
NUM_SLICES,
num_warps=NUM_WARPS,
num_ctas=NUM_CTAS,
Expand Down
1 change: 1 addition & 0 deletions vllm/lora/ops/triton_ops/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ def get_lora_op_configs(
"split_k": 64 if batch < 128 else 8,
"num_warps": 4,
"num_ctas": 1,
"group_size_m": 8,
"num_stages": 2,
"max_nreg": None,
}
Expand Down