diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 1c5c04d9db8e..0e35d8a80988 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -66,6 +66,8 @@ def apply(
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
     ) -> torch.Tensor:
         raise NotImplementedError

diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index 473816fcc3ec..cb1d5400f3a0 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -469,6 +469,7 @@ def apply(
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
         activation: str = "silu",
     ) -> torch.Tensor:
         assert activation == "silu", "Only SiLU activation is supported."
@@ -476,6 +477,10 @@ def apply(
             raise NotImplementedError(
                 "Expert Parallelism is not supported for "
                 "fused Marlin MoE method.")
+        if apply_router_weight_on_input:
+            raise NotImplementedError(
+                "Apply router weight on input is not supported for "
+                "fused Marlin MoE method.")

         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py
index d18ca55afebd..be19b80975ec 100644
--- a/vllm/model_executor/layers/quantization/experts_int8.py
+++ b/vllm/model_executor/layers/quantization/experts_int8.py
@@ -113,6 +113,7 @@ def apply(
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
         activation: str = "silu",
     ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe import fused_experts
@@ -129,18 +130,20 @@ def apply(
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias)

-        return fused_experts(x,
-                             layer.w13_weight,
-                             layer.w2_weight,
-                             topk_weights=topk_weights,
-                             topk_ids=topk_ids,
-                             inplace=True,
-                             activation=activation,
-                             use_int8_w8a16=True,
-                             global_num_experts=global_num_experts,
-                             expert_map=expert_map,
-                             w1_scale=layer.w13_scale,
-                             w2_scale=layer.w2_scale)
+        return fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=True,
+            activation=activation,
+            use_int8_w8a16=True,
+            global_num_experts=global_num_experts,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            expert_map=expert_map,
+            w1_scale=layer.w13_scale,
+            w2_scale=layer.w2_scale)

     @staticmethod
     def quantizing_weight_loader(layer, weight_loader):
diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py
index 9861e0a85b3f..6b499f81c55f 100644
--- a/vllm/model_executor/layers/quantization/gguf.py
+++ b/vllm/model_executor/layers/quantization/gguf.py
@@ -338,9 +338,15 @@ def apply(
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
         activation: str = "silu",
     ):
         assert activation == "silu", "Only SiLU activation is supported."
+        if apply_router_weight_on_input:
+            raise NotImplementedError(
+                "Apply router weight on input is not supported for "
+                "fused GGUF MoE method.")
+
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
             router_logits=router_logits,
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index 9f53ffc1d7f6..0615bb4ab4df 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -592,9 +592,14 @@ def apply(
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
         activation: str = "silu",
     ) -> torch.Tensor:
         assert activation == "silu", "Only SiLU activation is supported."
+        if apply_router_weight_on_input:
+            raise NotImplementedError(
+                "Apply router weight on input is not supported for "
+                "fused Marlin MoE method.")

         # The input must currently be float16
         orig_dtype = x.dtype
diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py
index 41b75c9be05a..00c4b661ef2c 100644
--- a/vllm/model_executor/layers/quantization/moe_wna16.py
+++ b/vllm/model_executor/layers/quantization/moe_wna16.py
@@ -293,6 +293,7 @@ def apply(
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
         activation: str = "silu",
     ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe import fused_experts
@@ -312,21 +313,23 @@ def apply(
         weight_bits = self.quant_config.weight_bits
         has_zp = self.quant_config.has_zp

-        return fused_experts(x,
-                             layer.w13_qweight,
-                             layer.w2_qweight,
-                             topk_weights=topk_weights,
-                             topk_ids=topk_ids,
-                             inplace=True,
-                             use_int4_w4a16=weight_bits == 4,
-                             use_int8_w8a16=weight_bits == 8,
-                             global_num_experts=global_num_experts,
-                             expert_map=expert_map,
-                             w1_scale=layer.w13_scales,
-                             w2_scale=layer.w2_scales,
-                             w1_zp=layer.w13_qzeros if has_zp else None,
-                             w2_zp=layer.w2_qzeros if has_zp else None,
-                             block_shape=[0, layer.group_size])
+        return fused_experts(
+            x,
+            layer.w13_qweight,
+            layer.w2_qweight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=True,
+            use_int4_w4a16=weight_bits == 4,
+            use_int8_w8a16=weight_bits == 8,
+            global_num_experts=global_num_experts,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            expert_map=expert_map,
+            w1_scale=layer.w13_scales,
+            w2_scale=layer.w2_scales,
+            w1_zp=layer.w13_qzeros if has_zp else None,
+            w2_zp=layer.w2_qzeros if has_zp else None,
+            block_shape=[0, layer.group_size])

     @staticmethod
     def get_weight_loader(layer, weight_loader):
diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py
index bc26a455c6f2..d1146c0f039d 100644
--- a/vllm/model_executor/layers/quantization/quark/quark_moe.py
+++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py
@@ -202,6 +202,8 @@ def apply(
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
+        apply_router_weight_on_input: bool = False,
+        activation: str = "silu",
     ) -> torch.Tensor:
         from vllm.model_executor.layers.fused_moe import fused_experts

@@ -217,16 +219,18 @@ def apply(
             scoring_func=scoring_func,
             e_score_correction_bias=e_score_correction_bias)

-        return fused_experts(x,
-                             layer.w13_weight,
-                             layer.w2_weight,
-                             topk_weights=topk_weights,
-                             topk_ids=topk_ids,
-                             inplace=True,
-                             use_fp8_w8a8=True,
-                             global_num_experts=global_num_experts,
-                             expert_map=expert_map,
-                             w1_scale=layer.w13_weight_scale,
-                             w2_scale=layer.w2_weight_scale,
-                             a1_scale=layer.w13_input_scale,
-                             a2_scale=layer.w2_input_scale)
+        return fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=True,
+            use_fp8_w8a8=True,
+            global_num_experts=global_num_experts,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            expert_map=expert_map,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale)
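
Every hunk above threads the same new keyword through a backend's `apply()` and, where the kernel supports it, into `fused_experts`; backends that cannot honor it yet (Marlin, GGUF) raise `NotImplementedError` when it is set. The flag name describes its semantics: with `apply_router_weight_on_input=True` the top-k router weights scale each token before it enters the expert MLPs, rather than scaling the expert outputs afterwards. The following sketch is plain PyTorch with illustrative names and shapes, not vLLM's fused kernel; it only shows why the two placements are not interchangeable for a non-linear expert.

# Illustrative sketch only (assumed names/shapes; not vLLM's fused_experts kernel).
import torch
import torch.nn as nn


def moe_reference(x, experts, topk_weights, topk_ids,
                  apply_router_weight_on_input: bool = False):
    """x: [T, H]; topk_weights/topk_ids: [T, K]; experts: per-expert MLP callables."""
    outputs = []
    for t in range(x.shape[0]):
        acc = torch.zeros_like(x[t])
        for k in range(topk_ids.shape[1]):
            expert = experts[int(topk_ids[t, k])]
            w = topk_weights[t, k]
            if apply_router_weight_on_input:
                # Router weight scales the token *before* it enters the expert.
                acc = acc + expert(w * x[t])
            else:
                # Conventional MoE: router weight scales the expert *output*.
                acc = acc + w * expert(x[t])
        outputs.append(acc)
    return torch.stack(outputs)


if __name__ == "__main__":
    torch.manual_seed(0)
    T, H, E, K = 4, 8, 3, 2
    experts = [nn.Sequential(nn.Linear(H, H), nn.SiLU(), nn.Linear(H, H))
               for _ in range(E)]
    x = torch.randn(T, H)
    topk_weights, topk_ids = torch.topk(torch.randn(T, E).softmax(dim=-1), K, dim=-1)
    out_post = moe_reference(x, experts, topk_weights, topk_ids, False)
    out_pre = moe_reference(x, experts, topk_weights, topk_ids, True)
    # The expert MLP is non-linear, so the two placements generally differ.
    print(torch.allclose(out_post, out_pre))  # almost surely False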