Skip to content

Commit 6ddae74

Browse files
[LoRA] Lora shrink swizzle (#27694)
Signed-off-by: li2haipeng <44383182+li2haipeng@users.noreply.github.com> Signed-off-by: Haipeng Li <li2haipeng@gmail.com> Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
1 parent b13a447 commit 6ddae74

File tree

2 files changed

+14
-2
lines changed

2 files changed: +14 additions, -2 deletions (lines changed)

vllm/lora/ops/triton_ops/lora_shrink_op.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,24 @@ def _lora_shrink_kernel(
4141
BLOCK_K: tl.constexpr,
4242
EVEN_K: tl.constexpr,
4343
SPLIT_K: tl.constexpr,
44+
GROUP_SIZE_M: tl.constexpr,
4445
SLICE_NUM: tl.constexpr,
4546
):
4647
cta_n_num = tl.cdiv(N, BLOCK_N)
4748
cta_m_num = tl.cdiv(M, BLOCK_M)
4849

4950
pid_sk_m_n = tl.program_id(axis=0)
5051
pid_sk = pid_sk_m_n % SPLIT_K
51-
pid_m = (pid_sk_m_n // SPLIT_K) % cta_m_num
52-
pid_n = pid_sk_m_n // (SPLIT_K * cta_m_num) % cta_n_num
52+
53+
pid_m_n = pid_sk_m_n // SPLIT_K
54+
num_pid_in_group = GROUP_SIZE_M * cta_n_num
55+
group_id = pid_m_n // num_pid_in_group
56+
first_pid_m = group_id * GROUP_SIZE_M
57+
group_size_m = min(cta_m_num - first_pid_m, GROUP_SIZE_M)
58+
59+
# Column-major ordering within groups for better cache reuse
60+
pid_m = first_pid_m + ((pid_m_n % num_pid_in_group) % group_size_m)
61+
pid_n = (pid_m_n % num_pid_in_group) // group_size_m
5362

5463
slice_id = tl.program_id(axis=1)
5564
lora_idx = tl.program_id(axis=2)
@@ -194,6 +203,7 @@ def _lora_shrink(
194203
NUM_WARPS = kernel_config["num_warps"]
195204
NUM_STAGES = kernel_config["num_stages"]
196205
NUM_CTAS = kernel_config["num_ctas"]
206+
GROUP_SIZE_M = kernel_config.get("group_size_m", 8)
197207
EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 # type: ignore
198208

199209
# TODO (varun): This grid formulation maximizes parallelization at the
@@ -233,6 +243,7 @@ def _lora_shrink(
233243
BLOCK_K,
234244
EVEN_K,
235245
SPLIT_K,
246+
GROUP_SIZE_M,
236247
NUM_SLICES,
237248
num_warps=NUM_WARPS,
238249
num_ctas=NUM_CTAS,

vllm/lora/ops/triton_ops/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ def get_lora_op_configs(
199199
"split_k": 64 if batch < 128 else 8,
200200
"num_warps": 4,
201201
"num_ctas": 1,
202+
"group_size_m": 8,
202203
"num_stages": 2,
203204
"max_nreg": None,
204205
}

0 commit comments

Comments (0)