
Commit f167469

Merge pull request vllm-project#2 from gnovack/marlin_experts_mxfp4
enable early exit for fused_moe_lora
2 parents: 055486e + 4e80855

8 files changed: +82 −10 lines

csrc/moe/moe_lora_align_sum_kernels.cu

Lines changed: 10 additions & 4 deletions
@@ -33,11 +33,15 @@ __global__ void moe_lora_align_sum_kernel(
     int64_t block_size, int num_experts, int max_loras, size_t numel,
     int max_num_tokens_padded, int max_num_m_blocks,
     int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
-    int topk_num, int32_t* total_tokens_post_pad) {
+    int topk_num, int32_t* total_tokens_post_pad, int32_t* num_tokens_per_lora, int32_t* adapter_enabled) {
   const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
   const size_t start_idx = threadIdx.x * tokens_per_thread;
 
   int lora_id = blockIdx.x;
+  if (adapter_enabled[lora_id] * num_tokens_per_lora[lora_id] == 0) {
+    return;
+  }
+
   extern __shared__ int32_t shared_mem[];
   int32_t* cumsum = shared_mem;
   token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + num_experts + 1);
@@ -124,9 +128,10 @@ void moe_lora_align_block_size(torch::Tensor topk_ids,
                                int64_t max_loras,
                                torch::Tensor sorted_token_ids,
                                torch::Tensor expert_ids,
-                               torch::Tensor num_tokens_post_pad) {
+                               torch::Tensor num_tokens_post_pad,
+                               torch::Tensor num_tokens_per_lora,
+                               torch::Tensor adapter_enabled) {
   const int topk_num = topk_ids.size(1);
-
   int max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1);
   max_num_tokens_padded = round_up(max_num_tokens_padded, block_size);
   int max_num_m_blocks = CEILDIV(max_num_tokens_padded, block_size);
@@ -160,6 +165,7 @@ void moe_lora_align_block_size(torch::Tensor topk_ids,
         max_loras, topk_ids.numel(), max_num_tokens_padded,
         max_num_m_blocks, sorted_token_ids.data_ptr<int32_t>(),
         expert_ids.data_ptr<int32_t>(), topk_num,
-        num_tokens_post_pad.data_ptr<int32_t>());
+        num_tokens_post_pad.data_ptr<int32_t>(), num_tokens_per_lora.data_ptr<int32_t>(),
+        adapter_enabled.data_ptr<int32_t>());
   });
 }
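The guard added at the top of the kernel lets a whole thread block bail out when its LoRA slot has nothing to do. A minimal CPU-side sketch of the same condition, in illustrative Python rather than CUDA:

def lora_block_has_work(lora_id, adapter_enabled, num_tokens_per_lora):
    # Mirrors the early exit in moe_lora_align_sum_kernel: the block for
    # lora_id is skipped when the adapter is disabled (flag 0) or no tokens
    # are routed to it, since either factor makes the product zero.
    return adapter_enabled[lora_id] * num_tokens_per_lora[lora_id] != 0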

csrc/moe/moe_ops.h

Lines changed: 3 additions & 1 deletion
@@ -19,7 +19,9 @@ void moe_lora_align_block_size(torch::Tensor topk_ids,
                                int64_t max_loras,
                                torch::Tensor sorted_token_ids,
                                torch::Tensor expert_ids,
-                               torch::Tensor num_tokens_post_pad);
+                               torch::Tensor num_tokens_post_pad,
+                               torch::Tensor num_tokens_per_lora,
+                               torch::Tensor adapter_enabled);
 #ifndef USE_ROCM
 torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                              torch::Tensor b_qweight, torch::Tensor b_scales,

csrc/moe/torch_bindings.cpp

Lines changed: 3 additions & 1 deletion
@@ -31,7 +31,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
       " int block_size, int max_loras, "
       " Tensor !sorted_token_ids,"
       " Tensor !experts_ids,"
-      " Tensor !num_tokens_post_pad) -> () ");
+      " Tensor !num_tokens_post_pad,"
+      " Tensor !num_tokens_per_lora,"
+      " Tensor !adapter_enabled) -> () ");
   m.impl("moe_lora_align_block_size", torch::kCUDA, &moe_lora_align_block_size);
 
 #ifndef USE_ROCM

vllm/_custom_ops.py

Lines changed: 5 additions & 0 deletions
@@ -1793,6 +1793,9 @@ def moe_align_block_size(
 def moe_lora_align_block_size(
     topk_ids: torch.Tensor,
     token_lora_mapping: torch.Tensor,
+    num_tokens_per_lora: torch.Tensor,  # shape [max-loras + 1]
+    no_lora_flag_cpu: torch.Tensor,  # shape [1]
+    adapter_enabled: torch.Tensor,  # shape [max-loras]
     num_experts: int,
     block_size: int,
     max_loras: int,
@@ -1809,6 +1812,8 @@ def moe_lora_align_block_size(
         sorted_token_ids,
         experts_ids,
         num_tokens_post_pad,
+        num_tokens_per_lora,
+        adapter_enabled,
     )
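For orientation, a hedged sketch of how a caller might allocate the three new metadata tensors; the sizes follow the shape comments in the diff, while the dtypes and the example max_loras value are assumptions, not something this commit specifies:

import torch

max_loras = 4  # illustrative value only
num_tokens_per_lora = torch.zeros(max_loras + 1, dtype=torch.int32)  # tokens routed to each LoRA slot
no_lora_flag_cpu = torch.zeros(1, dtype=torch.bool)  # host-side flag: true means no token needs LoRA
adapter_enabled = torch.zeros(max_loras, dtype=torch.int32)  # 1 where a slot has MoE LoRA weights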

vllm/lora/layers/fused_moe.py

Lines changed: 17 additions & 4 deletions
@@ -74,9 +74,10 @@ def wrapper(*args, **kwargs):
             global_num_experts = layer._lora["global_num_experts"]
             expert_map = layer._lora["expert_map"]
 
-            (token_lora_mapping, _, _, _, _,
-             _) = layer.punica_wrapper.token_mapping_meta.meta_args(
+            (token_lora_mapping, _, num_tokens_per_lora, _, _,
+             no_lora_flag_cpu) = layer.punica_wrapper.token_mapping_meta.meta_args(
                 hidden_states.size(0))
+
             config_dtype = _get_config_dtype_str(use_fp8_w8a8=False,
                                                  use_int8_w8a16=False,
                                                  use_int4_w4a16=False,
@@ -99,7 +100,8 @@ def wrapper(*args, **kwargs):
             config = get_config_func(M)
             (sorted_token_ids_lora, expert_ids_lora,
              num_tokens_post_padded_lora) = (moe_lora_align_block_size(
-                 curr_topk_ids, token_lora_mapping, config['BLOCK_SIZE_M'],
+                 curr_topk_ids, token_lora_mapping, num_tokens_per_lora, no_lora_flag_cpu,
+                 layer.adapter_enabled, config['BLOCK_SIZE_M'],
                  global_num_experts, curr_topk_ids.shape[-1], expert_map))
 
             layer._lora["sorted_token_ids_lora"] = sorted_token_ids_lora
@@ -132,6 +134,7 @@ def wrapper(*args, **kwargs):
                 max_lora_rank,
                 top_k,
                 config,
+                layer.adapter_enabled,
             )
 
             result = func(*args, **kwargs)
@@ -191,7 +194,7 @@ def wrapper(*args, **kwargs):
                 intermediate_cache3, intermediate_cache2,
                 [w2_lora_a_stacked], [w2_lora_b_stacked], topk_weights,
                 sorted_token_ids_lora, expert_ids_lora,
-                num_tokens_post_padded_lora, max_lora_rank, top_k, config,
+                num_tokens_post_padded_lora, max_lora_rank, top_k, config, layer.adapter_enabled,
                 True)
 
             result = func(*args, **kwargs)
@@ -226,6 +229,8 @@ def create_lora_weights(
         model_config: Optional[PretrainedConfig] = None,
     ) -> None:
         """Initializes lora matrices."""
+        self.adapter_enabled = torch.tensor([0] * (max_loras+1), dtype=torch.int, device=self.device)
+
         self.w1_lora_a_stacked = torch.zeros(
             (
                 max_loras,
@@ -288,6 +293,9 @@ def create_lora_weights(
             dtype=lora_config.lora_dtype,
             device=self.device,
         )
+
+        # flags to track which LoRAs have MoE adapters
+        self.base_layer.adapter_enabled = self.adapter_enabled
 
         self.base_layer.w1_lora_a_stacked = self.w1_lora_a_stacked
         self.base_layer.w1_lora_b_stacked = self.w1_lora_b_stacked
@@ -324,6 +332,8 @@ def reset_lora(self, index: int):
         self.w3_lora_b_stacked[index] = 0
         self.w2_lora_a_stacked[index] = 0
         self.w2_lora_b_stacked[index] = 0
+
+        self.adapter_enabled[index] = 0
 
     def set_lora(
         self,
@@ -334,6 +344,9 @@ def set_lora(
         bias: Optional[torch.Tensor] = None,
     ):
         """Overwrites lora tensors at index."""
+
+        self.adapter_enabled[index] = 1
+
         for eid in range(len(lora_a) // 3):
             w1_lora_a = lora_a[eid * 3]
             w2_lora_a = lora_a[eid * 3 + 1]
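The adapter_enabled buffer is a per-slot on/off flag: create_lora_weights allocates it, set_lora raises the flag for a slot, and reset_lora clears it, so the kernels can skip slots without MoE LoRA weights. A standalone toy sketch of that lifecycle (not the vLLM layer itself; names mirror the diff):

import torch

class AdapterFlags:
    """Toy holder for per-LoRA enable flags, shaped like the layer's buffer."""

    def __init__(self, max_loras: int):
        # one int flag per LoRA slot, plus one spare, matching max_loras + 1 above
        self.adapter_enabled = torch.tensor([0] * (max_loras + 1), dtype=torch.int)

    def set_lora(self, index: int) -> None:
        self.adapter_enabled[index] = 1  # slot now carries MoE LoRA weights

    def reset_lora(self, index: int) -> None:
        self.adapter_enabled[index] = 0  # kernels skip this slot again

flags = AdapterFlags(max_loras=4)
flags.set_lora(2)
assert flags.adapter_enabled.tolist() == [0, 0, 1, 0, 0]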

vllm/lora/ops/triton_ops/fused_moe_lora_op.py

Lines changed: 32 additions & 0 deletions
@@ -47,6 +47,8 @@ def _fused_moe_lora_kernel(
     EM,
     num_valid_tokens,
     num_experts,
+    lora_ids,
+    adapter_enabled,
     # The stride variables represent how much to increase the ptr by when
     # moving by 1 element in a particular dimension. E.g. `stride_am` is
     # how much to increase `a_ptr` by to get the element one row down
@@ -78,6 +80,12 @@ def _fused_moe_lora_kernel(
     slice_id = tl.program_id(axis=1)
     lora_idx = tl.program_id(axis=2)
 
+    lora_id = tl.load(lora_ids + lora_idx)
+    moe_enabled = tl.load(adapter_enabled + lora_idx)
+    if lora_id == -1 or moe_enabled == 0:
+        # Early exit for the no-lora case.
+        return
+
     # calculate pid_m,pid_n
     num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
@@ -160,6 +168,13 @@ def _fused_moe_lora(
     num_tokens_post_padded: torch.Tensor,
     max_lora_rank: int,
     top_k_num: int,
+    token_lora_mapping: torch.Tensor,  # shape [num_tokens]
+    token_indices_sorted_by_lora_ids: torch.Tensor,  # shape [num_tokens]
+    num_tokens_per_lora: torch.Tensor,  # shape [max-loras + 1]
+    lora_token_start_loc: torch.Tensor,  # shape [max-loras + 2]
+    lora_ids: torch.Tensor,  # shape [max-loras + 1]
+    no_lora_flag_cpu: torch.Tensor,  # shape [1]
+    adapter_enabled: torch.Tensor,  # shape [max-loras]
     # config:Optional[dict[str, Any]],
     block_size_m:int,
     block_size_n:int,
@@ -183,6 +198,12 @@ def _fused_moe_lora(
         config (_type_): _description_
         intermediate_cache1 (torch.Tensor): _description_
     """
+
+    assert no_lora_flag_cpu.numel() == 1
+    if no_lora_flag_cpu.item():
+        # None of the inputs require LoRA
+        return
+
     assert len(lora_a_stacked) == len(lora_b_stacked)
     device = qcurr_hidden_states.device
     num_slices = len(lora_a_stacked)
@@ -242,6 +263,8 @@ def _fused_moe_lora(
         EM,
         num_tokens,
         num_experts,
+        lora_ids,
+        adapter_enabled,
         qcurr_hidden_states.stride(0),
         qcurr_hidden_states.stride(1),
         w1_lora_a_stacked.stride(0),
@@ -287,6 +310,8 @@ def _fused_moe_lora(
         EM,
         num_tokens,
         num_experts,
+        lora_ids,
+        adapter_enabled,
         a_intermediate_cache1.stride(1),
         a_intermediate_cache1.stride(2),
         w1_lora_b_stacked.stride(0),
@@ -324,6 +349,13 @@ def _fused_moe_lora_fake(
     block_size_n:int,
     block_size_k:int,
     group_size_m:int,
+    token_lora_mapping: torch.Tensor,  # shape [num_tokens]
+    token_indices_sorted_by_lora_ids: torch.Tensor,  # shape [num_tokens]
+    num_tokens_per_lora: torch.Tensor,  # shape [max-loras + 1]
+    lora_token_start_loc: torch.Tensor,  # shape [max-loras + 2]
+    lora_ids: torch.Tensor,  # shape [max-loras + 1]
+    no_lora_flag_cpu: torch.Tensor,  # shape [1]
+    no_moe_lora_flag_cpu: torch.Tensor,
     mul_routed_weight:bool=False,
 ) -> None:
     return
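Each Triton program along grid axis 2 now checks its LoRA slot before computing anything, which is the device-side half of the early exit (the host-side half is the no_lora_flag_cpu return above). A hedged plain-Python emulation of that per-program check, not Triton code:

import torch

def program_should_run(lora_idx: int, lora_ids: torch.Tensor, adapter_enabled: torch.Tensor) -> bool:
    # Mirrors the two tl.load calls added to _fused_moe_lora_kernel: slot value -1
    # means no active LoRA at this position, and a zero flag means the adapter has
    # no MoE LoRA weights; either way the program returns immediately.
    lora_id = int(lora_ids[lora_idx])
    enabled = int(adapter_enabled[lora_idx])
    return lora_id != -1 and enabled != 0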

vllm/lora/punica_wrapper/punica_gpu.py

Lines changed: 6 additions & 0 deletions
@@ -92,6 +92,9 @@ def add_shrink(
             scale (float): Scaling factor for the operation
         """
 
+        # note @gnovack - force input to be contiguous to support eager mode
+        x = x.contiguous()
+
         x = x.view(-1, x.shape[-1])
         lora_shrink(
             x,
@@ -317,6 +320,7 @@ def add_lora_fused_moe(
         max_lora_rank: int,
         top_k_num: int,
         config,
+        adapter_enabled: torch.Tensor,
         mul_routed_weight=False,
     ):
         fused_moe_lora(
@@ -330,6 +334,8 @@ def add_lora_fused_moe(
             num_tokens_post_padded,
             max_lora_rank,
             top_k_num,
+            *self.token_mapping_meta.meta_args(x.size(0)),
+            adapter_enabled,
             config["BLOCK_SIZE_M"],
             config["BLOCK_SIZE_N"],
             config["BLOCK_SIZE_K"],
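The new .contiguous() call protects the .view() on the next line: in eager mode x can arrive as a non-contiguous view (for example a slice or transpose), which .view() cannot flatten and would reject. A small illustration of the failure mode and the fix:

import torch

x = torch.randn(4, 8, 16).transpose(0, 1)  # non-contiguous view of the data
# x.view(-1, x.shape[-1])                  # would raise a RuntimeError here
x = x.contiguous()                         # dense copy restores a row-major layout
flat = x.view(-1, x.shape[-1])             # the flatten used by add_shrink now succeeds
print(flat.shape)                          # torch.Size([32, 16])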

vllm/model_executor/layers/fused_moe/moe_align_block_size.py

Lines changed: 6 additions & 0 deletions
@@ -89,6 +89,9 @@ def moe_align_block_size(
 def moe_lora_align_block_size(
     topk_ids: torch.Tensor,
     token_lora_mapping: torch.Tensor,
+    num_tokens_per_lora: torch.Tensor,  # shape [max-loras + 1]
+    no_lora_flag_cpu: torch.Tensor,  # shape [1]
+    adapter_enabled: torch.Tensor,  # shape [max-loras]
     block_size: int,
     num_experts: int,
     max_loras: int,
@@ -119,6 +122,9 @@ def moe_lora_align_block_size(
     ops.moe_lora_align_block_size(
         topk_ids,
         token_lora_mapping,
+        num_tokens_per_lora,
+        no_lora_flag_cpu,
+        adapter_enabled,
         num_experts,
         block_size,
         max_loras,
