Commit a2d483b

yugong333 and jeejeelee authored and committed
Load tuned fused_moe_lora shrink and expand kernel configs separately (vllm-project#27435)
Signed-off-by: Yu Gong <yu3.gong@gmail.com>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
1 parent ea33c0d commit a2d483b

File tree

9 files changed: +910 / -124 lines

benchmarks/kernels/benchmark_lora.py

Lines changed: 446 additions & 30 deletions
Large diffs are not rendered by default.

tests/lora/test_fused_moe_lora_kernel.py

Lines changed: 11 additions & 0 deletions
@@ -158,6 +158,8 @@ def use_fused_moe_lora_kernel(
         "BLOCK_SIZE_N": 32,
         "BLOCK_SIZE_K": 64,
         "GROUP_SIZE_M": 1,
+        "NUM_WARPS": 4,
+        "NUM_STAGES": 3,
         "SPLIT_K": 1,
     }

@@ -182,6 +184,15 @@ def use_fused_moe_lora_kernel(
         config["BLOCK_SIZE_N"],
         config["BLOCK_SIZE_K"],
         config["GROUP_SIZE_M"],
+        config["NUM_WARPS"],
+        config["NUM_STAGES"],
+        config["SPLIT_K"],
+        config["BLOCK_SIZE_M"],
+        config["BLOCK_SIZE_N"],
+        config["BLOCK_SIZE_K"],
+        config["GROUP_SIZE_M"],
+        config["NUM_WARPS"],
+        config["NUM_STAGES"],
         config["SPLIT_K"],
         mul_routed_weight,
     )

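The kernel entry point now takes the shrink-stage and expand-stage tuning parameters as two separate groups, so the test passes the values from its single config dict twice. A minimal sketch of how that flattened argument order could be built from two dicts (the helper, the leading BLOCK_SIZE_M position, and the value 64 are illustrative assumptions, not taken from the diff):

# Hypothetical helper: flatten a shrink config and an expand config into the
# positional order used in the test above (shrink values first, then expand).
KEYS = ("BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K",
        "GROUP_SIZE_M", "NUM_WARPS", "NUM_STAGES", "SPLIT_K")

def flatten_configs(shrink: dict[str, int], expand: dict[str, int]) -> list[int]:
    return [shrink[k] for k in KEYS] + [expand[k] for k in KEYS]

config = {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 64,
          "GROUP_SIZE_M": 1, "NUM_WARPS": 4, "NUM_STAGES": 3, "SPLIT_K": 1}
print(flatten_configs(config, config))  # the test reuses one dict for both stages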
vllm/lora/layers/fused_moe.py

Lines changed: 83 additions & 20 deletions
@@ -13,6 +13,7 @@
     get_tensor_model_parallel_world_size,
 )
 from vllm.lora.layers.base import BaseLayerWithLoRA
+from vllm.lora.ops.triton_ops.utils import get_lora_op_configs
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.fused_moe.config import (
     _get_config_dtype_str,
@@ -39,6 +40,64 @@ def __init__(self, base_layer: FusedMoE) -> None:
         self.device = base_layer.w2_weight.device
         self._inject_lora_into_fused_moe()

+    def _normalize_keys(self, config: dict[str, int | None]) -> dict[str, int | None]:
+        normalized_config = {}
+        for key, value in config.items():
+            if key.islower():
+                if key.startswith("block_"):
+                    normalized_key = "BLOCK_SIZE_" + key.split("_")[-1].upper()
+                else:
+                    normalized_key = key.upper()
+            else:
+                normalized_key = key
+            normalized_config[normalized_key] = value
+        return normalized_config
+
+    def _get_lora_moe_configs(
+        self,
+        op_prefix: str,
+        lora_a_stacked: torch.Tensor,
+        lora_b_stacked: torch.Tensor,
+        num_slices: int,
+        M: int,
+        layer: FusedMoE,
+        top_k: int,
+        config_dtype: str,
+    ):
+        if envs.VLLM_TUNED_CONFIG_FOLDER:
+            shrink_config = get_lora_op_configs(
+                op_type=f"fused_moe_lora_{op_prefix}_shrink",
+                max_loras=lora_a_stacked.shape[0],
+                batch=M,
+                hidden_size=lora_a_stacked.shape[-1],
+                rank=lora_a_stacked.shape[-2],
+                num_slices=num_slices,
+                moe_intermediate_size=lora_b_stacked.shape[-2],
+            )
+            expand_config = get_lora_op_configs(
+                op_type=f"fused_moe_lora_{op_prefix}_expand",
+                max_loras=lora_a_stacked.shape[0],
+                batch=M,
+                hidden_size=lora_a_stacked.shape[-1],
+                rank=lora_a_stacked.shape[-2],
+                num_slices=num_slices,
+                moe_intermediate_size=lora_b_stacked.shape[-2],
+            )
+        else:  # fall back to the default config
+            get_config_func = functools.partial(
+                try_get_optimal_moe_config,
+                layer.w13_weight.size(),
+                layer.w2_weight.size(),
+                top_k,
+                config_dtype,
+                block_shape=layer.quant_method.moe_quant_config.block_shape,
+            )
+            shrink_config = get_config_func(M)
+            expand_config = get_config_func(M)
+        shrink_config = self._normalize_keys(shrink_config)
+        expand_config = self._normalize_keys(expand_config)
+        return shrink_config, expand_config
+
     def _inject_lora_into_fused_moe(self):
         moe_state_dict = {}
         top_k = self.base_layer.top_k
@@ -90,25 +149,27 @@ def wrapper(*args, **kwargs):
             num_tokens = hidden_states.size(0)
             M = min(num_tokens, CHUNK_SIZE)

-            get_config_func = functools.partial(
-                try_get_optimal_moe_config,
-                layer.w13_weight.size(),
-                layer.w2_weight.size(),
-                top_k,
-                config_dtype,
-                block_shape=layer.quant_method.moe_quant_config.block_shape,
+            shrink_config, expand_config = self._get_lora_moe_configs(
+                op_prefix="w13",
+                lora_a_stacked=self.w1_lora_a_stacked,
+                lora_b_stacked=self.w1_lora_b_stacked,
+                num_slices=2,
+                M=M,
+                layer=layer,
+                top_k=top_k,
+                config_dtype=config_dtype,
             )

+            # get the block size of m from customized config or default config
             max_loras = self.w1_lora_a_stacked.shape[0]
-            config = get_config_func(M)
             (
                 sorted_token_ids_lora,
                 expert_ids_lora,
                 num_tokens_post_padded_lora,
             ) = self.punica_wrapper.moe_lora_align_block_size(
                 curr_topk_ids,
                 num_tokens,
-                config["BLOCK_SIZE_M"],
+                shrink_config["BLOCK_SIZE_M"],
                 self.base_layer.local_num_experts,
                 max_loras,
                 self.adapter_enabled,
@@ -138,7 +199,8 @@ def wrapper(*args, **kwargs):
                 num_tokens_post_padded_lora,
                 max_lora_rank,
                 top_k,
-                config,
+                shrink_config,  ## pass the shrink config
+                expand_config,  ## pass the expand config
                 self.adapter_enabled,
             )

@@ -164,17 +226,17 @@ def wrapper(*args, **kwargs):
             num_tokens = hidden_states.size(0)
             M = min(num_tokens, CHUNK_SIZE)

-            get_config_func = functools.partial(
-                try_get_optimal_moe_config,
-                layer.w13_weight.size(),
-                layer.w2_weight.size(),
-                top_k,
-                config_dtype,
-                block_shape=layer.quant_method.moe_quant_config.block_shape,
+            shrink_config, expand_config = self._get_lora_moe_configs(
+                op_prefix="w2",
+                lora_a_stacked=self.w2_lora_a_stacked,
+                lora_b_stacked=self.w2_lora_b_stacked,
+                num_slices=1,
+                M=M,
+                layer=layer,
+                top_k=top_k,
+                config_dtype=config_dtype,
             )

-            config = get_config_func(M)
-
             sorted_token_ids_lora = moe_state_dict["sorted_token_ids_lora"]
             expert_ids_lora = moe_state_dict["expert_ids_lora"]
             num_tokens_post_padded_lora = moe_state_dict[
@@ -197,7 +259,8 @@ def wrapper(*args, **kwargs):
                 num_tokens_post_padded_lora,
                 max_lora_rank,
                 top_k,
-                config,
+                shrink_config,  ## pass the shrink config
+                expand_config,  ## pass the expand config
                 self.adapter_enabled,
                 True,
             )

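The helper pair added above merges two config sources: per-op tuned files looked up via get_lora_op_configs when VLLM_TUNED_CONFIG_FOLDER is set, and the generic try_get_optimal_moe_config fallback otherwise. Tuned configs may carry lowercase keys (e.g. block_m or block_size_m; the exact spelling is an assumption here), so _normalize_keys maps everything onto the uppercase names the kernel call expects. A standalone sketch of that mapping:

def normalize_keys(config: dict[str, int | None]) -> dict[str, int | None]:
    # Mirrors _normalize_keys above: lowercase keys are rewritten to the
    # uppercase BLOCK_SIZE_*/NUM_* names; already-uppercase keys pass through.
    normalized = {}
    for key, value in config.items():
        if key.islower():
            if key.startswith("block_"):
                key = "BLOCK_SIZE_" + key.split("_")[-1].upper()
            else:
                key = key.upper()
        normalized[key] = value
    return normalized

# Assumed lowercase tuned entry (illustrative values):
print(normalize_keys({"block_m": 64, "block_n": 32, "num_warps": 4}))
# -> {'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'NUM_WARPS': 4}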
vllm/lora/ops/triton_ops/README_TUNING.md

Lines changed: 10 additions & 1 deletion
@@ -44,8 +44,17 @@ For `shrink`, the config file is named as `{gpu_name}_SHRINK.json`, e.g. `NVIDIA

 For `expand`, the config file is named as `{gpu_name}_EXPAND_{add_input}.json`, e.g. `NVIDIA_H200_EXPAND_TRUE.json`.

+For `fused_moe_lora_w13_shrink`, the config file is named as `{gpu_name}_FUSED_MOE_LORA_W13_SHRINK.json`, e.g. `NVIDIA_H200_FUSED_MOE_LORA_W13_SHRINK.json`.
+
+For `fused_moe_lora_w13_expand`, the config file is named as `{gpu_name}_FUSED_MOE_LORA_W13_EXPAND.json`, e.g. `NVIDIA_H200_FUSED_MOE_LORA_W13_EXPAND.json`.
+
+For `fused_moe_lora_w2_shrink`, the config file is named as `{gpu_name}_FUSED_MOE_LORA_W2_SHRINK.json`, e.g. `NVIDIA_H200_FUSED_MOE_LORA_W2_SHRINK.json`.
+
+For `fused_moe_lora_w2_expand`, the config file is named as `{gpu_name}_FUSED_MOE_LORA_W2_EXPAND.json`, e.g. `NVIDIA_H200_FUSED_MOE_LORA_W2_EXPAND.json`.
+
 The `gpu_name` can be automatically detected by calling `torch.cuda.get_device_name()`

 ### Json Structure

-Optimal kernel configuration files are saved as JSON files with the structure `config_data[max_loras][num_slices][m][k][n]`
+Optimal kernel configuration files are saved as JSON files with the structure `config_data[max_loras][num_slices][m][k][n][i]`
+where `i` is an optional dimension in the `fused_moe_lora` configuration, representing the intermediate size of the MoE layer.

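When VLLM_TUNED_CONFIG_FOLDER points at a directory containing these files, each op type gets its own JSON keyed as described above; the dimension labels follow the get_lora_op_configs arguments (max_loras, num_slices, batch, hidden size, rank, and the optional MoE intermediate size). A hedged sketch of indexing one nested entry, where every numeric key and the leaf field names are illustrative placeholders rather than values from a real tuned file:

import json

# Illustrative layout only: config_data[max_loras][num_slices][m][k][n][i]
sample = {
    "1": {                          # max_loras
        "2": {                      # num_slices
            "16": {                 # m (batch size)
                "4096": {           # k (hidden size)
                    "16": {         # n (LoRA rank)
                        "1408": {   # i (MoE intermediate size, optional)
                            "block_m": 64, "block_n": 32, "block_k": 64,
                            "group_size_m": 1, "num_warps": 4,
                            "num_stages": 3, "split_k": 1,
                        }
                    }
                }
            }
        }
    }
}

entry = sample["1"]["2"]["16"]["4096"]["16"]["1408"]
print(json.dumps(entry, indent=2))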
vllm/lora/ops/triton_ops/__init__.py

Lines changed: 8 additions & 1 deletion
@@ -1,7 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from vllm.lora.ops.triton_ops.fused_moe_lora_op import fused_moe_lora
+
+from vllm.lora.ops.triton_ops.fused_moe_lora_op import (
+    fused_moe_lora,
+    fused_moe_lora_expand,
+    fused_moe_lora_shrink,
+)
 from vllm.lora.ops.triton_ops.lora_expand_op import lora_expand
 from vllm.lora.ops.triton_ops.lora_kernel_metadata import LoRAKernelMeta
 from vllm.lora.ops.triton_ops.lora_shrink_op import lora_shrink
@@ -11,4 +16,6 @@
     "lora_shrink",
     "LoRAKernelMeta",
     "fused_moe_lora",
+    "fused_moe_lora_shrink",
+    "fused_moe_lora_expand",
 ]

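With the split entry points exported, callers can import them alongside the combined op; their signatures live in fused_moe_lora_op.py and are not shown in this diff, so only the imports are illustrated:

# Names exported by vllm.lora.ops.triton_ops after this change.
from vllm.lora.ops.triton_ops import (
    fused_moe_lora,
    fused_moe_lora_expand,
    fused_moe_lora_shrink,
)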