Commit 055486e

Author: Danielle Robinson
Message: add support for moe_sum through fused_marlin_moe
Parent: 77bbb51

File tree: 1 file changed, +8 −2 lines

vllm/model_executor/layers/fused_moe/fused_marlin_moe.py

Lines changed: 8 additions & 2 deletions
@@ -42,7 +42,8 @@ def fused_marlin_moe(
     apply_router_weight_on_input: bool = False,
     global_num_experts: int = -1,
     activation: Optional[str] = "silu",
-    activation_func: Optional[str] = None,
+    activation_func: Optional[str] = None,  # FIXME: type Callable
+    moe_sum: Optional[str] = None,  # FIXME: type Callable
     expert_map: Optional[torch.Tensor] = None,
     global_scale1: Optional[torch.Tensor] = None,
     global_scale2: Optional[torch.Tensor] = None,
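
The FIXME comments note that activation_func and moe_sum are typed as Optional[str] even though they are invoked as functions later in the diff. A hedged sketch of what Callable annotations might look like, inferred only from the moe_sum(intermediate_cache3, output) call site below (the activation_func signature is an assumption, since its call site is not shown in this diff):

from typing import Callable, Optional

import torch

# Inferred from the call `moe_sum(intermediate_cache3, output)` in this diff.
moe_sum: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None

# Assumed shape only; the activation_func call site is outside this diff.
activation_func: Optional[Callable[..., torch.Tensor]] = None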
@@ -240,12 +241,16 @@ def fused_marlin_moe(
         is_k_full=is_k_full,
         use_atomic_add=use_atomic_add,
         use_fp32_reduce=True,
+
         is_zp_float=False,
     ).view(-1, topk, K)
 
     if output is None:
         output = hidden_states if inplace else torch.empty_like(hidden_states)
-    return torch.sum(intermediate_cache3.view(-1, topk, K), dim=1, out=output)
+    if moe_sum is None:
+        return torch.sum(intermediate_cache3.view(-1, topk, K), dim=1, out=output)
+    else:
+        return moe_sum(intermediate_cache3, output)
 
 
 def fused_marlin_moe_fake(
@@ -407,6 +412,7 @@ def apply(
             global_num_experts=global_num_experts,
             activation=activation,
             activation_func=self.activation,
+            moe_sum=self.moe_sum,
             expert_map=expert_map,
             output=output,
             # Workspaces are swapped in workspace_shapes() to account for proper
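
Taken together, the change lets a caller override the final top-k reduction: when moe_sum is provided, fused_marlin_moe calls moe_sum(intermediate_cache3, output) instead of the built-in torch.sum, and the modular apply() path forwards self.moe_sum. A minimal sketch of a compatible callable, assuming the (num_tokens, topk, K) intermediate layout implied by the default branch (the name naive_moe_sum is hypothetical, not part of the commit):

import torch

def naive_moe_sum(intermediate_cache3: torch.Tensor,
                  output: torch.Tensor) -> torch.Tensor:
    # Mirror the default branch: reduce the per-expert partial results over
    # the topk dimension into the caller-provided output buffer.
    num_tokens, hidden = output.shape
    return torch.sum(intermediate_cache3.view(num_tokens, -1, hidden),
                     dim=1, out=output)

Any callable with this signature could presumably be passed instead, for example a fused custom op that performs the reduction in a single kernel.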
