From 2199fff269cf96938e1c62f3b6d0f04d32c3e2f3 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 17 Oct 2025 19:34:55 +0000 Subject: [PATCH 01/16] [Kernel] Isolate modular kernel dispatching Signed-off-by: Bill Nell --- .../base_device_communicator.py | 2 +- vllm/model_executor/layers/fused_moe/layer.py | 123 +++++++++++++++++- 2 files changed, 122 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 9566dbac7f22..d775f06f4894 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -273,7 +273,7 @@ def prepare_communication_buffer_for_model(self, model: torch.nn.Module) -> None ) ] for module in moe_modules: - module.quant_method.init_prepare_finalize(module) + module.init_prepare_finalize() def dispatch( self, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 46d351b48c5e..78e45577c0ad 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -247,7 +247,9 @@ def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None: # Note: init_prepare_finalize should only be called by # prepare_communication_buffer_for_model. - def init_prepare_finalize(self, layer: torch.nn.Module): + def init_prepare_finalize( + self, layer: torch.nn.Module + ) -> FusedMoEModularKernel | None: assert self.moe is not None # We must get the quant config here so that the layer is @@ -267,12 +269,14 @@ def init_prepare_finalize(self, layer: torch.nn.Module): ) self.topk_indices_dtype = prepare_finalize.topk_indices_dtype() experts = self.select_gemm_impl(prepare_finalize, layer) - self.fused_experts = FusedMoEModularKernel( + return FusedMoEModularKernel( prepare_finalize, experts, layer.shared_experts, ) + return None + def select_gemm_impl( self, prepare_finalize: FusedMoEPrepareAndFinalize, @@ -322,6 +326,113 @@ def apply( raise NotImplementedError +@CustomOp.register("modular_fused_moe") +class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): + def __init__( + self, old_moe_method: FusedMoEMethodBase, fused_experts: FusedMoEModularKernel + ): + super().__init__(old_moe_method.moe) + # Find better way to copy attributes + # self.__dict__.update(old_moe_method.__dict__) + + self.moe_quant_config = old_moe_method.moe_quant_config + self.fused_experts = fused_experts + self.topk_indices_dtype = old_moe_method.topk_indices_dtype + + if isinstance(old_moe_method, torch.nn.Module): + self.load_state_dict(old_moe_method.state_dict()) + logger.debug("Swapping out %s", old_moe_method.__class__.__name__) + + def create_weights( + self, + layer: torch.nn.Module, + num_experts: int, + hidden_size: int, + intermediate_size_per_partition: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + raise NotImplementedError + + def get_fused_moe_quant_config( + self, layer: torch.nn.Module + ) -> FusedMoEQuantConfig | None: + return self.moe_quant_config + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: int | None = None, + num_expert_group: int | None = None, + global_num_experts: int = -1, + expert_map: torch.Tensor | None = None, + custom_routing_function: Callable | None = None, + scoring_func: str = "softmax", + 
routed_scaling_factor: float = 1.0, + e_score_correction_bias: torch.Tensor | None = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + expert_load_view: torch.Tensor | None = None, + logical_to_physical_map: torch.Tensor | None = None, + logical_replica_count: torch.Tensor | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + # Is getattr needed? + zero_expert_num = getattr(layer, "zero_expert_num", 0) + zero_expert_type = getattr(layer, "zero_expert_type", None) + + select_result = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype, + enable_eplb=enable_eplb, + expert_map=expert_map, + expert_load_view=expert_load_view, + logical_to_physical_map=logical_to_physical_map, + logical_replica_count=logical_replica_count, + global_num_experts=global_num_experts, + zero_expert_num=zero_expert_num, + zero_expert_type=zero_expert_type, + ) + + topk_weights, topk_ids, zero_expert_result = select_result + + result = self.fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + global_num_experts=global_num_experts, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_map=expert_map, + ) + + if zero_expert_num != 0 and zero_expert_type is not None: + assert not isinstance(result, tuple), ( + "Shared + zero experts are mutually exclusive not yet supported" + ) + return result, zero_expert_result + else: + return result + + @CustomOp.register("unquantized_fused_moe") class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): """MoE method without quantization.""" @@ -1353,6 +1464,14 @@ def __init__( self.batched_hidden_states: torch.Tensor | None = None self.batched_router_logits: torch.Tensor | None = None + def init_prepare_finalize(self) -> None: + mk = self.quant_method.init_prepare_finalize(self) + if mk is not None: + new_quant_method = FusedMoEModularMethod(self.quant_method, mk) + if isinstance(self.quant_method, torch.nn.Module): + self.set_submodule(self.quant_method.name, new_quant_method) + self.quant_method = new_quant_method + @property def shared_experts(self) -> torch.nn.Module | None: return None From 7bfd9c371e7a1a61db8a5972753c9aad4cba7f56 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 17 Oct 2025 19:42:01 +0000 Subject: [PATCH 02/16] cleanup Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 78e45577c0ad..996993bf9762 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -245,8 +245,6 @@ def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None: else: return None - # Note: init_prepare_finalize should only be called by - # prepare_communication_buffer_for_model. 
def init_prepare_finalize( self, layer: torch.nn.Module ) -> FusedMoEModularKernel | None: @@ -274,8 +272,8 @@ def init_prepare_finalize( experts, layer.shared_experts, ) - - return None + else: + return None def select_gemm_impl( self, @@ -329,7 +327,9 @@ def apply( @CustomOp.register("modular_fused_moe") class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): def __init__( - self, old_moe_method: FusedMoEMethodBase, fused_experts: FusedMoEModularKernel + self, + old_moe_method: FusedMoEMethodBase, + fused_experts: FusedMoEModularKernel, ): super().__init__(old_moe_method.moe) # Find better way to copy attributes @@ -382,6 +382,8 @@ def apply( logical_to_physical_map: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + assert self.fused_experts is not None + # Is getattr needed? zero_expert_num = getattr(layer, "zero_expert_num", 0) zero_expert_type = getattr(layer, "zero_expert_type", None) @@ -1464,6 +1466,8 @@ def __init__( self.batched_hidden_states: torch.Tensor | None = None self.batched_router_logits: torch.Tensor | None = None + # Note: init_prepare_finalize should only be called by + # prepare_communication_buffer_for_model. def init_prepare_finalize(self) -> None: mk = self.quant_method.init_prepare_finalize(self) if mk is not None: From 2f3f631e10e30ea8a477708b74c733f26535b6b5 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 17 Oct 2025 20:51:16 +0000 Subject: [PATCH 03/16] remove uses of self.fused_experts Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 55 +++++----- .../layers/fused_moe/modular_kernel.py | 9 ++ .../layers/quantization/awq_marlin.py | 2 - .../layers/quantization/bitsandbytes.py | 3 +- .../compressed_tensors_moe.py | 47 -------- .../layers/quantization/experts_int8.py | 2 - .../model_executor/layers/quantization/fp8.py | 31 +----- .../layers/quantization/gguf.py | 2 - .../layers/quantization/gptq_marlin.py | 2 - .../layers/quantization/modelopt.py | 47 ++------ .../layers/quantization/moe_wna16.py | 1 - .../layers/quantization/mxfp4.py | 101 +++--------------- .../layers/quantization/quark/quark_moe.py | 49 ++++----- .../model_executor/layers/quantization/rtn.py | 2 - 14 files changed, 89 insertions(+), 264 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 996993bf9762..fa0c93d6c3ab 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -119,7 +119,6 @@ def __init__(self, moe: FusedMoEConfig): super().__init__() self.moe = moe self.moe_quant_config: FusedMoEQuantConfig | None = None - self.fused_experts: FusedMoEModularKernel | None = None self.topk_indices_dtype = None @abstractmethod @@ -262,9 +261,6 @@ def init_prepare_finalize( "%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self) ) assert self.topk_indices_dtype is None - assert self.fused_experts is None, ( - f"Attempt to override experts for {id(self)}!" 
- ) self.topk_indices_dtype = prepare_finalize.topk_indices_dtype() experts = self.select_gemm_impl(prepare_finalize, layer) return FusedMoEModularKernel( @@ -295,7 +291,11 @@ def get_fused_moe_quant_config( @property def using_modular_kernel(self) -> bool: - return self.fused_experts is not None + return False + + @property + def supports_eplb(self) -> bool: + return False @abstractmethod def apply( @@ -338,10 +338,21 @@ def __init__( self.moe_quant_config = old_moe_method.moe_quant_config self.fused_experts = fused_experts self.topk_indices_dtype = old_moe_method.topk_indices_dtype - + self.disable_expert_map = not fused_experts.supports_expert_map() + self.old_method_name = old_moe_method.__class__.__name__ + self._supports_eplb = old_moe_method.supports_eplb if isinstance(old_moe_method, torch.nn.Module): self.load_state_dict(old_moe_method.state_dict()) - logger.debug("Swapping out %s", old_moe_method.__class__.__name__) + logger.debug("Swapping out %s", self.old_method_name) + + @property + def using_modular_kernel(self) -> bool: + return True + + @property + @abstractmethod + def supports_eplb(self) -> bool: + return self._supports_eplb def create_weights( self, @@ -382,12 +393,21 @@ def apply( logical_to_physical_map: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.fused_experts is not None - # Is getattr needed? zero_expert_num = getattr(layer, "zero_expert_num", 0) zero_expert_type = getattr(layer, "zero_expert_type", None) + if enable_eplb: + if not self.supports_eplb: + assert expert_load_view is not None + assert logical_to_physical_map is not None + assert logical_replica_count is not None + assert isinstance(layer, FusedMoE) + else: + raise NotImplementedError( + f"EPLB is not supported for {self.old_method_name}" + ) + select_result = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, @@ -423,7 +443,7 @@ def apply( activation=activation, global_num_experts=global_num_experts, apply_router_weight_on_input=apply_router_weight_on_input, - expert_map=expert_map, + expert_map=None if self.disable_expert_map else expert_map, ) if zero_expert_num != 0 and zero_expert_type is not None: @@ -763,7 +783,6 @@ def forward_cuda( ) if self.rocm_aiter_moe_enabled: - assert self.fused_experts is None result = self.rocm_aiter_fused_experts( hidden_states=x, w1=layer.w13_weight, @@ -784,21 +803,7 @@ def forward_cuda( activation=activation, apply_router_weight_on_input=apply_router_weight_on_input, ) - elif self.fused_experts is not None: - result = self.fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - global_num_experts=global_num_experts, - expert_map=expert_map, - ) else: - assert fused_experts is not None result = fused_experts( hidden_states=x, w1=layer.w13_weight, diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 3b5916f8ccaf..8ea5b6b34e43 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -707,6 +707,15 @@ def __init__( f"{fused_experts.activation_formats[0]}" ) + def supports_expert_map(self) -> bool: + """ + A flag indicating whether or not this class supports expert maps + """ + return ( + 
self.prepare_finalize.num_dispatchers() <= 1 + and self.fused_experts.supports_expert_map() + ) + def output_is_reduced(self) -> bool: """ Indicates whether or not the output of fused MoE kernel diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index daf7422963f3..3e1f87b59a34 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -617,8 +617,6 @@ def apply( logical_to_physical_map: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.fused_experts is None - if enable_eplb: raise NotImplementedError("EPLB not supported for `AWQMoEMethod` yet.") diff --git a/vllm/model_executor/layers/quantization/bitsandbytes.py b/vllm/model_executor/layers/quantization/bitsandbytes.py index ccd9b311cc93..e5a741e639ad 100644 --- a/vllm/model_executor/layers/quantization/bitsandbytes.py +++ b/vllm/model_executor/layers/quantization/bitsandbytes.py @@ -518,12 +518,11 @@ def apply( ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: from vllm.model_executor.layers.fused_moe import fused_experts - assert self.fused_experts is None - if enable_eplb: raise NotImplementedError( "EPLB not supported for `BitsAndBytesMoEMethod` yet." ) + topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py index bf38c15b4701..d95d49eddfe3 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -462,12 +462,7 @@ def apply( indices_type=self.topk_indices_dtype, ) - # - # Note: the order here is important. self.fused_experts can override - # flashinfer cutlass, cutlass fp4 or fused_experts but not marlin. - # if self.use_marlin: - assert self.fused_experts is None return fused_marlin_moe( x, layer.w13_weight, @@ -488,24 +483,6 @@ def apply( workspace=layer.workspace, ) - elif self.fused_experts is not None: - assert is_valid_flashinfer_cutlass_fused_moe( - x, layer.w13_weight, layer.w2_weight - ), "Flashinfer CUTLASS Fused MoE not applicable!" - - return self.fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=False, # TODO(shuw): fix later, now output is high prec - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - # FlashInfer fused experts path elif self.allow_flashinfer: from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 @@ -1066,13 +1043,8 @@ def apply( per_act_token = self.input_quant.strategy == QuantizationStrategy.TOKEN per_channel_quant = self.weight_quant.strategy == QuantizationStrategy.CHANNEL - # - # Note: the order here is important. self.fused_experts can override - # cutlass fp8 or fused_experts but not marlin or rocm. - # if self.use_marlin: assert activation == "silu", f"{activation} not supported for Marlin MoE." 
- assert self.fused_experts is None return fused_marlin_moe( x, layer.w13_weight, @@ -1098,7 +1070,6 @@ def apply( assert per_act_token == per_channel_quant assert self.moe_quant_config is not None - assert self.fused_experts is None return rocm_aiter_fused_experts( hidden_states=x, w1=layer.w13_weight, @@ -1111,18 +1082,6 @@ def apply( quant_config=self.moe_quant_config, ) - elif self.fused_experts is not None: - return self.fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - activation=activation, - global_num_experts=global_num_experts, - expert_map=None if self.disable_expert_map else expert_map, - ) - # cutlass path elif self.use_cutlass: assert self.moe_quant_config is not None @@ -1318,8 +1277,6 @@ def apply( logical_to_physical_map: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.fused_experts is None - if enable_eplb: raise NotImplementedError( "EPLB not supported for `CompressedTensorsW8A8Int8MoEMethod` yet." @@ -1636,8 +1593,6 @@ def apply( logical_to_physical_map: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.fused_experts is None - if enable_eplb: raise NotImplementedError( "EPLB not supported for `CompressedTensorsWNA16MarlinMoEMethod` yet." @@ -1901,8 +1856,6 @@ def apply( logical_to_physical_map: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.fused_experts is None - if enable_eplb: raise NotImplementedError( "EPLB not supported for `CompressedTensorsWNA16MoEMethod` yet." diff --git a/vllm/model_executor/layers/quantization/experts_int8.py b/vllm/model_executor/layers/quantization/experts_int8.py index 754608af97c6..5241f9a2301b 100644 --- a/vllm/model_executor/layers/quantization/experts_int8.py +++ b/vllm/model_executor/layers/quantization/experts_int8.py @@ -158,8 +158,6 @@ def apply( logical_to_physical_map: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.fused_experts is None - if enable_eplb: raise NotImplementedError( "EPLB not supported for `ExpertsInt8MoEMethod` yet." 
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index f82eccb88ce0..e01ee8356502 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -703,9 +703,6 @@ def __init__(self, quant_config: Fp8Config, layer: torch.nn.Module): self.quant_config = quant_config self.weight_block_size = self.quant_config.weight_block_size self.block_quant: bool = self.weight_block_size is not None - - self.fused_experts: mk.FusedMoEModularKernel | None = None # type: ignore - self.fp8_backend = get_fp8_moe_backend(self.block_quant) self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN @@ -1181,6 +1178,10 @@ def get_fused_moe_quant_config( block_shape=self.weight_block_size, ) + @property + def supports_eplb(self) -> bool: + return True + def apply( self, layer: torch.nn.Module, @@ -1210,10 +1211,7 @@ def apply( assert logical_replica_count is not None assert isinstance(layer, FusedMoE) - if ( - self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM - and self.fused_experts is None - ): + if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: assert activation == "silu", ( f"Expected 'silu' activation but got {activation}" ) @@ -1290,10 +1288,6 @@ def apply( num_fused_shared_experts=layer.num_fused_shared_experts, ) - # - # Note: the order of checks is important since self.fused_experts - # can override fused_experts or cutlass but not rocm or marlin. - # topk_weights, topk_ids, zero_expert_result = select_result if self.rocm_aiter_moe_enabled: @@ -1301,7 +1295,6 @@ def apply( rocm_aiter_fused_experts, ) - assert self.fused_experts is None result = rocm_aiter_fused_experts( x, layer.w13_weight, @@ -1315,7 +1308,6 @@ def apply( ) elif self.use_marlin: assert activation == "silu", f"{activation} not supported for Marlin MoE." 
- assert self.fused_experts is None result = fused_marlin_moe( x, layer.w13_weight, @@ -1333,19 +1325,6 @@ def apply( expert_map=expert_map, workspace=layer.workspace, ) - elif self.fused_experts: - result = self.fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=activation, - global_num_experts=global_num_experts, - apply_router_weight_on_input=apply_router_weight_on_input, - expert_map=expert_map, - ) elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: assert not self.block_quant assert not renormalize and custom_routing_function is not None diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 8a914c57a9f7..caabcd0ca0ee 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -585,8 +585,6 @@ def apply( logical_to_physical_map: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.fused_experts is None - if enable_eplb: raise NotImplementedError("EPLB not supported for `GGUFMoEMethod` yet.") diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index 0d5439357fda..42a569e7770c 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -742,8 +742,6 @@ def apply( logical_to_physical_map: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.fused_experts is None - if enable_eplb: raise NotImplementedError( "EPLB not supported for `GPTQMarlinMoEMethod` yet." diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 37b682984fc3..67dc4966892a 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -605,7 +605,6 @@ def apply( ) if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: - assert self.fused_experts is None assert activation == "silu", ( f"Expected 'silu' activation but got {activation}" ) @@ -638,24 +637,7 @@ def apply( indices_type=self.topk_indices_dtype, ) - # - # Note: the order here is important. self.fused_experts can override - # cutlass or fused_experts. - # - if self.fused_experts is not None: - return self.fused_experts( - x, - layer.w13_weight, - layer.w2_weight, - topk_weights, - topk_ids, - inplace=False, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: + if self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: assert not renormalize assert activation == "silu", ( f"Expected 'silu' activation but got {activation}" @@ -1647,8 +1629,6 @@ def apply( from vllm.model_executor.models.llama4 import Llama4MoE - assert self.fused_experts is None - a1_gscale = layer.w13_input_scale_quant (hidden_states_fp4, hidden_states_scale_linear_fp4) = ( flashinfer.fp4_quantize( @@ -1720,13 +1700,7 @@ def apply( indices_type=self.topk_indices_dtype, ) - # - # Note: the order here is important. 
self.fused_experts can override - # flashinfer cutlass, cutlass fp4 or fused_experts but not marlin or - # trtllm. - # if self.use_marlin: - assert self.fused_experts is None return fused_marlin_moe( x, layer.w13_weight, @@ -1747,23 +1721,24 @@ def apply( workspace=layer.workspace, ) - elif self.fused_experts is not None: - assert ( - self.allow_flashinfer - and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS + elif ( + self.allow_flashinfer + and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS + ): + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 + flashinfer_cutlass_moe_fp4, ) - assert is_valid_flashinfer_cutlass_fused_moe( - x, layer.w13_weight, layer.w2_weight - ), "Flashinfer CUTLASS Fused MoE not applicable!" + assert self.moe_quant_config is not None - return self.fused_experts( + return flashinfer_cutlass_moe_fp4( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, - inplace=False, # TODO(shuw): fix later, now output is high prec + quant_config=self.moe_quant_config, + inplace=False, activation=activation, global_num_experts=global_num_experts, expert_map=expert_map, diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index b0a268b9950b..08b92a6b99a3 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -381,7 +381,6 @@ def apply( logical_to_physical_map: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.fused_experts is None if enable_eplb: raise NotImplementedError("EPLB not supported for `MoeWNA16Method` yet.") diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 597ee1b6bafe..cb73e44c7b0a 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -815,6 +815,18 @@ def select_gemm_impl( "EP batched experts format" ) else: + layer.w13_weight = ( + self.w13_weight_triton_tensor + if layer.w13_weight is None + else layer.w13_weight + ) + layer.w2_weight = ( + self.w2_weight_triton_tensor + if layer.w2_weight is None + else layer.w2_weight + ) + assert all([w is not None for w in [layer.w13_weight, layer.w2_weight]]) + assert self.moe_quant_config is not None if ( self.mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_TRTLLM @@ -838,72 +850,6 @@ def select_gemm_impl( f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for EP" ) - def _route_and_experts( - self, - layer: torch.nn.Module, - x: torch.Tensor, - router_logits: torch.Tensor, - top_k: int, - renormalize: bool, - use_grouped_topk: bool = False, - topk_group: int | None = None, - num_expert_group: int | None = None, - global_num_experts: int = -1, - expert_map: torch.Tensor | None = None, - custom_routing_function: Callable | None = None, - scoring_func: str = "softmax", - e_score_correction_bias: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: torch.Tensor | None = None, - logical_to_physical_map: torch.Tensor | None = None, - logical_replica_count: torch.Tensor | None = None, - ) -> torch.Tensor: - assert isinstance(self.fused_experts, mk.FusedMoEModularKernel) - - topk_weights, topk_ids, _ = FusedMoE.select_experts( - hidden_states=x, - 
router_logits=router_logits, - use_grouped_topk=use_grouped_topk, - top_k=top_k, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - e_score_correction_bias=e_score_correction_bias, - indices_type=self.topk_indices_dtype, - enable_eplb=enable_eplb, - expert_map=expert_map, - expert_load_view=expert_load_view, - logical_to_physical_map=logical_to_physical_map, - logical_replica_count=logical_replica_count, - ) - - w13_weight = ( - self.w13_weight_triton_tensor - if layer.w13_weight is None - else layer.w13_weight - ) - w2_weight = ( - self.w2_weight_triton_tensor if layer.w2_weight is None else layer.w2_weight - ) - assert all([w is not None for w in [w13_weight, w2_weight]]) - - return self.fused_experts( - hidden_states=x, - w1=w13_weight, - w2=w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - def apply( self, layer: torch.nn.Module, @@ -930,29 +876,6 @@ def apply( if enable_eplb: raise NotImplementedError("EPLB is not supported for mxfp4") - if self.fused_experts is not None: - return self._route_and_experts( - layer, - x, - router_logits, - top_k, - renormalize, - use_grouped_topk, - topk_group, - num_expert_group, - global_num_experts, - expert_map, - custom_routing_function, - scoring_func, - e_score_correction_bias, - apply_router_weight_on_input, - activation, - enable_eplb, - expert_load_view, - logical_to_physical_map, - logical_replica_count, - ) - if self.mxfp4_backend == Mxfp4Backend.MARLIN: topk_weights, topk_ids, _ = FusedMoE.select_experts( hidden_states=x, diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index a8f4b1b0db68..1631dfc4dfe4 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -310,7 +310,6 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: # Property to determine if AITER is used if self.rocm_aiter_moe_enabled: from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa E501 - rocm_aiter_fused_experts, shuffle_weights, ) @@ -322,17 +321,11 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False) layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False) - self.rocm_aiter_fused_experts_func = rocm_aiter_fused_experts elif self.use_marlin: prepare_moe_fp8_layer_for_marlin(layer, False) # Activations not quantized for marlin. del layer.w13_input_scale del layer.w2_input_scale - self.fused_experts_func = None - else: - from vllm.model_executor.layers.fused_moe import fused_experts - - self.fused_experts_func = fused_experts def get_fused_moe_quant_config( self, layer: torch.nn.Module @@ -369,8 +362,6 @@ def apply( logical_to_physical_map: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.fused_experts is None - if enable_eplb: raise NotImplementedError( "EPLB not supported for `QuarkW8A8Fp8MoEMethod` yet." 
@@ -392,7 +383,11 @@ def apply( ) if self.rocm_aiter_moe_enabled: - return self.rocm_aiter_fused_experts_func( + from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( + rocm_aiter_fused_experts, + ) + + return rocm_aiter_fused_experts( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -403,7 +398,7 @@ def apply( quant_config=self.moe_quant_config, expert_map=expert_map, ) - if self.use_marlin: + elif self.use_marlin: assert activation == "silu", f"{activation} not supported for Marlin MoE." return fused_marlin_moe( x, @@ -421,22 +416,22 @@ def apply( global_num_experts=global_num_experts, expert_map=expert_map, ) + else: + from vllm.model_executor.layers.fused_moe import fused_experts - assert self.fused_experts_func is not None - - return self.fused_experts_func( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - inplace=True, - activation=activation, - apply_router_weight_on_input=apply_router_weight_on_input, - global_num_experts=global_num_experts, - expert_map=expert_map, - quant_config=self.moe_quant_config, - ) + return fused_experts( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + global_num_experts=global_num_experts, + expert_map=expert_map, + quant_config=self.moe_quant_config, + ) class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): @@ -624,8 +619,6 @@ def apply( logical_to_physical_map: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.fused_experts is None - if enable_eplb: raise NotImplementedError( "EPLB not supported for `QuarkOCP_MX_MoEMethod` yet." 
diff --git a/vllm/model_executor/layers/quantization/rtn.py b/vllm/model_executor/layers/quantization/rtn.py index e4f7ff833956..52656263a601 100644 --- a/vllm/model_executor/layers/quantization/rtn.py +++ b/vllm/model_executor/layers/quantization/rtn.py @@ -377,8 +377,6 @@ def apply( logical_to_physical_map: torch.Tensor | None = None, logical_replica_count: torch.Tensor | None = None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - assert self.fused_experts is None - if enable_eplb: raise NotImplementedError("EPLB not supported for `RTNMoEMethod` yet.") From 9d1f2fbe4d43b9a4a7e2b396357f6ade9c0c7384 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 17 Oct 2025 20:52:20 +0000 Subject: [PATCH 04/16] remove uses of self.fused_experts Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index fa0c93d6c3ab..f6b1ad0b1468 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -350,7 +350,6 @@ def using_modular_kernel(self) -> bool: return True @property - @abstractmethod def supports_eplb(self) -> bool: return self._supports_eplb From 38799ff0de5435deeac657c29c8ba1f491b9168e Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 17 Oct 2025 20:53:07 +0000 Subject: [PATCH 05/16] comment Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index f6b1ad0b1468..f13514315218 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -332,7 +332,7 @@ def __init__( fused_experts: FusedMoEModularKernel, ): super().__init__(old_moe_method.moe) - # Find better way to copy attributes + # Find better way to copy attributes? # self.__dict__.update(old_moe_method.__dict__) self.moe_quant_config = old_moe_method.moe_quant_config From ad9e3b6fe3bb3fc2cbe2979c7355bf492af43ab0 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 17 Oct 2025 21:06:04 +0000 Subject: [PATCH 06/16] cleanups Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 23 ++++++------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index f13514315218..a5dd0308cc1a 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -289,10 +289,6 @@ def get_fused_moe_quant_config( ) -> FusedMoEQuantConfig | None: raise NotImplementedError - @property - def using_modular_kernel(self) -> bool: - return False - @property def supports_eplb(self) -> bool: return False @@ -345,10 +341,6 @@ def __init__( self.load_state_dict(old_moe_method.state_dict()) logger.debug("Swapping out %s", self.old_method_name) - @property - def using_modular_kernel(self) -> bool: - return True - @property def supports_eplb(self) -> bool: return self._supports_eplb @@ -1472,13 +1464,12 @@ def __init__( # Note: init_prepare_finalize should only be called by # prepare_communication_buffer_for_model. + # This is called after all weight loading and post-processing, so it + # should be safe to swap out the quant_method. 
def init_prepare_finalize(self) -> None: mk = self.quant_method.init_prepare_finalize(self) if mk is not None: - new_quant_method = FusedMoEModularMethod(self.quant_method, mk) - if isinstance(self.quant_method, torch.nn.Module): - self.set_submodule(self.quant_method.name, new_quant_method) - self.quant_method = new_quant_method + self.quant_method = FusedMoEModularMethod(self.quant_method, mk) @property def shared_experts(self) -> torch.nn.Module | None: @@ -2294,7 +2285,7 @@ def must_reduce_shared_expert_outputs(self) -> bool: """ assert self.quant_method is not None return ( - self.quant_method.fused_experts is not None + isinstance(self.quant_method, FusedMoEModularMethod) and self.quant_method.fused_experts.output_is_reduced() ) @@ -2530,7 +2521,7 @@ def forward_impl( self.ensure_dp_chunking_init() has_separate_shared_experts = ( - not isinstance(self.quant_method.fused_experts, FusedMoEModularKernel) + not isinstance(self.quant_method, FusedMoEModularMethod) and self.shared_experts is not None ) @@ -2557,8 +2548,8 @@ def forward_impl( hidden_states, router_logits, has_separate_shared_experts ) - do_naive_dispatch_combine: bool = ( - self.dp_size > 1 and not self.quant_method.using_modular_kernel + do_naive_dispatch_combine: bool = self.dp_size > 1 and not isinstance( + self.quant_method, FusedMoEModularMethod ) # If there are shared experts but we are not using a modular kernel, the From b12ed234a6ceeba3e8f3363306d217d07919cdea Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 17 Oct 2025 21:07:49 +0000 Subject: [PATCH 07/16] fix Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index a5dd0308cc1a..0ca4efaf2fd9 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -389,7 +389,7 @@ def apply( zero_expert_type = getattr(layer, "zero_expert_type", None) if enable_eplb: - if not self.supports_eplb: + if self.supports_eplb: assert expert_load_view is not None assert logical_to_physical_map is not None assert logical_replica_count is not None From 9b0af526a753df755e1e356b2e1ce6d89c377cf2 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 18 Oct 2025 19:58:04 +0000 Subject: [PATCH 08/16] comment Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 0ca4efaf2fd9..1edd3bac898d 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -430,7 +430,7 @@ def apply( w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, - inplace=True, + inplace=True, # TODO(bnell): make sure this is handled properly activation=activation, global_num_experts=global_num_experts, apply_router_weight_on_input=apply_router_weight_on_input, From 3048d6e76f3476f234400ef4dab5dfc4425ba061 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 18 Oct 2025 20:09:12 +0000 Subject: [PATCH 09/16] fix inplace Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 19 ++++++++++++++++++- .../model_executor/layers/quantization/fp8.py | 4 ++++ .../layers/quantization/mxfp4.py | 4 ++++ .../layers/quantization/quark/quark_moe.py | 4 ++++ 4 files changed, 30 insertions(+), 1 deletion(-) diff --git 
a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 1edd3bac898d..a0ce6c515019 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -293,6 +293,10 @@ def get_fused_moe_quant_config( def supports_eplb(self) -> bool: return False + @property + def allow_inplace(self) -> bool: + return False + @abstractmethod def apply( self, @@ -337,6 +341,7 @@ def __init__( self.disable_expert_map = not fused_experts.supports_expert_map() self.old_method_name = old_moe_method.__class__.__name__ self._supports_eplb = old_moe_method.supports_eplb + self._allow_inplace = old_moe_method.allow_inplace if isinstance(old_moe_method, torch.nn.Module): self.load_state_dict(old_moe_method.state_dict()) logger.debug("Swapping out %s", self.old_method_name) @@ -345,6 +350,10 @@ def __init__( def supports_eplb(self) -> bool: return self._supports_eplb + @property + def allow_inplace(self) -> bool: + return self._allow_inplace + def create_weights( self, layer: torch.nn.Module, @@ -430,7 +439,7 @@ def apply( w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, - inplace=True, # TODO(bnell): make sure this is handled properly + inplace=self.allow_inplace, activation=activation, global_num_experts=global_num_experts, apply_router_weight_on_input=apply_router_weight_on_input, @@ -502,6 +511,14 @@ def __init__(self, moe: FusedMoEConfig): ) self.flashinfer_cutlass_moe = None # type: ignore + @property + def supports_eplb(self) -> bool: + return True + + @property + def allow_inplace(self) -> bool: + return True + def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None: if self.rocm_aiter_moe_enabled: return None diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index e01ee8356502..03eca199d536 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1182,6 +1182,10 @@ def get_fused_moe_quant_config( def supports_eplb(self) -> bool: return True + @property + def allow_inplace(self) -> bool: + return True + def apply( self, layer: torch.nn.Module, diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index cb73e44c7b0a..651bc8844887 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -850,6 +850,10 @@ def select_gemm_impl( f"Incompatible Mxfp4 backend ({self.mxfp4_backend}) for EP" ) + @property + def allow_inplace(self) -> bool: + return True + def apply( self, layer: torch.nn.Module, diff --git a/vllm/model_executor/layers/quantization/quark/quark_moe.py b/vllm/model_executor/layers/quantization/quark/quark_moe.py index 1631dfc4dfe4..8825611051e5 100644 --- a/vllm/model_executor/layers/quantization/quark/quark_moe.py +++ b/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -596,6 +596,10 @@ def get_fused_moe_quant_config( block_shape=None, ) + @property + def allow_inplace(self) -> bool: + return True + def apply( self, layer: torch.nn.Module, From 9d6f340db94fc2361a37460897b9998ee6ba6fb5 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 18 Oct 2025 21:42:44 +0000 Subject: [PATCH 10/16] fix warmup code Signed-off-by: Bill Nell --- vllm/model_executor/warmup/deep_gemm_warmup.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/warmup/deep_gemm_warmup.py 
b/vllm/model_executor/warmup/deep_gemm_warmup.py index 78cbcd8e5427..bdcebd498ef0 100644 --- a/vllm/model_executor/warmup/deep_gemm_warmup.py +++ b/vllm/model_executor/warmup/deep_gemm_warmup.py @@ -13,7 +13,7 @@ from vllm.distributed.parallel_state import get_dp_group from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts from vllm.model_executor.layers.fused_moe.deep_gemm_utils import compute_aligned_M -from vllm.model_executor.layers.fused_moe.layer import FusedMoE +from vllm.model_executor.layers.fused_moe.layer import FusedMoE, FusedMoEModularMethod from vllm.model_executor.layers.fused_moe.modular_kernel import FusedMoEModularKernel from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( TritonOrDeepGemmExperts, @@ -160,8 +160,8 @@ def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool: ): return False - if not isinstance(module.quant_method.fused_experts, FusedMoEModularKernel): - # fused_experts could invoke deep_gemm_moe_fp8 + if not isinstance(module.quant_method, FusedMoEModularMethod): + # modular kernels could invoke deep_gemm_moe_fp8 return True mk: FusedMoEModularKernel = module.quant_method.fused_experts From f1d0b24ca696120068b36e06da94dda9542beda0 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Sat, 18 Oct 2025 22:42:44 +0000 Subject: [PATCH 11/16] remove unused import Signed-off-by: Bill Nell --- vllm/model_executor/layers/quantization/modelopt.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 67dc4966892a..f61d2a52925d 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -18,9 +18,6 @@ fp8_w8a8_moe_quant_config, nvfp4_moe_quant_config, ) -from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( - is_valid_flashinfer_cutlass_fused_moe, -) from vllm.model_executor.layers.fused_moe.fused_marlin_moe import fused_marlin_moe from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, From e8be0304bfb31a9eef58cb4329bdcb9b5012791a Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 20 Oct 2025 21:49:17 +0000 Subject: [PATCH 12/16] clean up object types and initialization Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 55 +++++++++---------- 1 file changed, 27 insertions(+), 28 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index a0ce6c515019..49cde8f55fe3 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1393,7 +1393,7 @@ def __init__( "Only softmax scoring function is supported for non-grouped topk." ) - moe = FusedMoEConfig( + self.moe_config: FusedMoEConfig = FusedMoEConfig( num_experts=self.global_num_experts, experts_per_token=top_k, hidden_dim=hidden_size, @@ -1405,24 +1405,26 @@ def __init__( is_act_and_mul=is_act_and_mul, is_lora_enabled=vllm_config.lora_config is not None, ) - self.moe_config: FusedMoEConfig = moe + self.moe_quant_config: FusedMoEQuantConfig | None = None self.quant_config = quant_config + def _get_quant_method() -> FusedMoEMethodBase: + """ + Helper method to ensure self.quant_method is never None and + of the proper type. 
+ """ + quant_method = None + if self.quant_config is not None: + quant_method = self.quant_config.get_quant_method(self, prefix) + if quant_method is None: + quant_method = UnquantizedFusedMoEMethod(self.moe_config) + assert isinstance(quant_method, FusedMoEMethodBase) + return quant_method + # Note: get_quant_method will look at the layer's local_num_experts # for heuristic purposes, so it must be initialized first. - quant_method: QuantizeMethodBase | None = None - quant_method = ( - UnquantizedFusedMoEMethod(moe) - if quant_config is None - else quant_config.get_quant_method(self, prefix) - ) - if quant_method is None: - quant_method = UnquantizedFusedMoEMethod(moe) - - assert quant_method is not None - assert isinstance(quant_method, FusedMoEMethodBase) - self.quant_method = quant_method + self.quant_method: FusedMoEMethodBase = _get_quant_method() if not self.moe_config.is_act_and_mul: # Avoid circular import @@ -1442,20 +1444,17 @@ def __init__( "is_act_and_mul=False is supported only for CUDA for now" ) - if self.enable_eplb: - from vllm.model_executor.layers.quantization.fp8 import Fp8MoEMethod - - if not isinstance(quant_method, (Fp8MoEMethod, UnquantizedFusedMoEMethod)): - # TODO: Add support for additional quantization methods. - # The implementation for other quantization methods does not - # contain essential differences, but the current quant API - # design causes duplicated work when extending to new - # quantization methods, so I'm leaving it for now. - # If you plan to add support for more quantization methods, - # please refer to the implementation in `Fp8MoEMethod`. - raise NotImplementedError( - "EPLB is only supported for FP8 quantization for now." - ) + if self.enable_eplb and not self.quant_method.supports_eplb: + # TODO: Add support for additional quantization methods. + # The implementation for other quantization methods does not + # contain essential differences, but the current quant API + # design causes duplicated work when extending to new + # quantization methods, so I'm leaving it for now. + # If you plan to add support for more quantization methods, + # please refer to the implementation in `Fp8MoEMethod`. + raise NotImplementedError( + "EPLB is only supported for FP8 quantization for now." + ) moe_quant_params = { "num_experts": self.local_num_experts, From bc0cf464a55800a12521d4e59ecb34a6bf783727 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 24 Oct 2025 22:10:14 +0000 Subject: [PATCH 13/16] review comments Signed-off-by: Bill Nell --- .../base_device_communicator.py | 4 +- vllm/model_executor/layers/fused_moe/layer.py | 59 ++++++++++--------- .../layers/fused_moe/modular_kernel.py | 7 +-- .../layers/quantization/moe_wna16.py | 1 - .../layers/quantization/mxfp4.py | 2 - 5 files changed, 35 insertions(+), 38 deletions(-) diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index d775f06f4894..3a849da70e4c 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -266,14 +266,14 @@ def prepare_communication_buffer_for_model(self, model: torch.nn.Module) -> None module for module in model.modules() # TODO(bnell): Should use isinstance but can't. Maybe search for - # presence of quant_method.init_prepare_finalize? + # presence of quant_method.maybe_init_modular_kernel? 
if ( module.__class__.__name__ == "FusedMoE" or module.__class__.__name__ == "SharedFusedMoE" ) ] for module in moe_modules: - module.init_prepare_finalize() + module.maybe_init_modular_kernel() def dispatch( self, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 49cde8f55fe3..ab1df2f79f4d 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -119,7 +119,6 @@ def __init__(self, moe: FusedMoEConfig): super().__init__() self.moe = moe self.moe_quant_config: FusedMoEQuantConfig | None = None - self.topk_indices_dtype = None @abstractmethod def create_weights( @@ -244,7 +243,7 @@ def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None: else: return None - def init_prepare_finalize( + def maybe_init_modular_kernel( self, layer: torch.nn.Module ) -> FusedMoEModularKernel | None: assert self.moe is not None @@ -260,8 +259,6 @@ def init_prepare_finalize( logger.debug( "%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self) ) - assert self.topk_indices_dtype is None - self.topk_indices_dtype = prepare_finalize.topk_indices_dtype() experts = self.select_gemm_impl(prepare_finalize, layer) return FusedMoEModularKernel( prepare_finalize, @@ -289,6 +286,10 @@ def get_fused_moe_quant_config( ) -> FusedMoEQuantConfig | None: raise NotImplementedError + @property + def topk_indices_dtype(self) -> torch.dtype | None: + return None + @property def supports_eplb(self) -> bool: return False @@ -328,31 +329,33 @@ def apply( class FusedMoEModularMethod(FusedMoEMethodBase, CustomOp): def __init__( self, - old_moe_method: FusedMoEMethodBase, + old_quant_method: FusedMoEMethodBase, fused_experts: FusedMoEModularKernel, ): - super().__init__(old_moe_method.moe) - # Find better way to copy attributes? - # self.__dict__.update(old_moe_method.__dict__) - - self.moe_quant_config = old_moe_method.moe_quant_config + super().__init__(old_quant_method.moe) + # Find better way to copy attributes? Should we even copy attributes? 
+ # self.__dict__.update(old_quant_method.__dict__) + self.moe_quant_config = old_quant_method.moe_quant_config self.fused_experts = fused_experts - self.topk_indices_dtype = old_moe_method.topk_indices_dtype - self.disable_expert_map = not fused_experts.supports_expert_map() - self.old_method_name = old_moe_method.__class__.__name__ - self._supports_eplb = old_moe_method.supports_eplb - self._allow_inplace = old_moe_method.allow_inplace - if isinstance(old_moe_method, torch.nn.Module): - self.load_state_dict(old_moe_method.state_dict()) - logger.debug("Swapping out %s", self.old_method_name) + self.disable_expert_map = getattr( + old_quant_method, + "disable_expert_map", + not fused_experts.supports_expert_map(), + ) + self.old_quant_method = old_quant_method + logger.debug("Swapping out %s", self.old_quant_method.__class__.__name__) + + @property + def topk_indices_dtype(self) -> torch.dtype | None: + return self.fused_experts.prepare_finalize.topk_indices_dtype() @property def supports_eplb(self) -> bool: - return self._supports_eplb + return self.old_quant_method.supports_eplb @property def allow_inplace(self) -> bool: - return self._allow_inplace + return self.old_quant_method.allow_inplace def create_weights( self, @@ -405,10 +408,11 @@ def apply( assert isinstance(layer, FusedMoE) else: raise NotImplementedError( - f"EPLB is not supported for {self.old_method_name}" + "EPLB is not supported for " + f"{self.old_quant_method.__class__.__name__}." ) - select_result = FusedMoE.select_experts( + topk_weights, topk_ids, zero_expert_result = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, use_grouped_topk=use_grouped_topk, @@ -431,8 +435,6 @@ def apply( zero_expert_type=zero_expert_type, ) - topk_weights, topk_ids, zero_expert_result = select_result - result = self.fused_experts( hidden_states=x, w1=layer.w13_weight, @@ -1433,7 +1435,7 @@ def _get_quant_method() -> FusedMoEMethodBase: ) if not isinstance( - quant_method, (UnquantizedFusedMoEMethod, ModelOptFp8MoEMethod) + self.quant_method, (UnquantizedFusedMoEMethod, ModelOptFp8MoEMethod) ): raise NotImplementedError( "is_act_and_mul=False is supported only for unquantized " @@ -1453,6 +1455,7 @@ def _get_quant_method() -> FusedMoEMethodBase: # If you plan to add support for more quantization methods, # please refer to the implementation in `Fp8MoEMethod`. raise NotImplementedError( + f"EPLB is not supported {self.quant_method.__class__.__name__}. " "EPLB is only supported for FP8 quantization for now." ) @@ -1478,12 +1481,12 @@ def _get_quant_method() -> FusedMoEMethodBase: self.batched_hidden_states: torch.Tensor | None = None self.batched_router_logits: torch.Tensor | None = None - # Note: init_prepare_finalize should only be called by + # Note: maybe_init_modular_kernel should only be called by # prepare_communication_buffer_for_model. # This is called after all weight loading and post-processing, so it # should be safe to swap out the quant_method. 
- def init_prepare_finalize(self) -> None: - mk = self.quant_method.init_prepare_finalize(self) + def maybe_init_modular_kernel(self) -> None: + mk = self.quant_method.maybe_init_modular_kernel(self) if mk is not None: self.quant_method = FusedMoEModularMethod(self.quant_method, mk) diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 8ea5b6b34e43..b5fa2c71bec5 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -709,12 +709,9 @@ def __init__( def supports_expert_map(self) -> bool: """ - A flag indicating whether or not this class supports expert maps + A flag indicating whether or not this class supports expert maps. """ - return ( - self.prepare_finalize.num_dispatchers() <= 1 - and self.fused_experts.supports_expert_map() - ) + return self.fused_experts.supports_expert_map() def output_is_reduced(self) -> bool: """ diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index 08b92a6b99a3..2090c86f78dc 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -226,7 +226,6 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ): - self.moe = layer layer.quant_config = self.quant_config bit8_pack_factor = self.quant_config.bit8_pack_factor group_size = self.quant_config.group_size diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py index 651bc8844887..8e5c75e663d8 100644 --- a/vllm/model_executor/layers/quantization/mxfp4.py +++ b/vllm/model_executor/layers/quantization/mxfp4.py @@ -197,8 +197,6 @@ def get_quant_method( class Mxfp4MoEMethod(FusedMoEMethodBase): def __init__(self, moe: FusedMoEConfig): super().__init__(moe) - self.topk_indices_dtype = None - self.moe = moe self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled) self.max_capture_size = ( get_current_vllm_config().compilation_config.max_cudagraph_capture_size From d1b98d5c8e4a3a69095f0b8d3104448e5e5b1ea5 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Fri, 24 Oct 2025 22:12:38 +0000 Subject: [PATCH 14/16] add type annotation Signed-off-by: Bill Nell --- vllm/model_executor/layers/fused_moe/layer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index ab1df2f79f4d..d085663e3bbb 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -117,7 +117,7 @@ class FusedMoeWeightScaleSupported(Enum): class FusedMoEMethodBase(QuantizeMethodBase): def __init__(self, moe: FusedMoEConfig): super().__init__() - self.moe = moe + self.moe: FusedMoEConfig = moe self.moe_quant_config: FusedMoEQuantConfig | None = None @abstractmethod From 5830b98f742a30efc01c7831ec659146f63b97df Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 3 Nov 2025 19:58:12 +0000 Subject: [PATCH 15/16] rebase Signed-off-by: Bill Nell --- .../layers/quantization/modelopt.py | 24 +------------------ 1 file changed, 1 insertion(+), 23 deletions(-) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index f61d2a52925d..21281db60cb5 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1718,34 
+1718,12 @@ def apply( workspace=layer.workspace, ) - elif ( - self.allow_flashinfer - and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS - ): - from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 - flashinfer_cutlass_moe_fp4, - ) - - assert self.moe_quant_config is not None - - return flashinfer_cutlass_moe_fp4( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - quant_config=self.moe_quant_config, - inplace=False, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) else: # If no modular kernel is provided, use cutlass_moe_fp4 for TP case # only (no EP). from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 + assert self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS assert self.moe_quant_config is not None return cutlass_moe_fp4( a=x, From a0621a365ab48caf7f2bc6ed73ff982eabdf6fe5 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Mon, 3 Nov 2025 21:50:12 +0000 Subject: [PATCH 16/16] fix merge Signed-off-by: Bill Nell --- .../layers/quantization/modelopt.py | 24 ++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 21281db60cb5..f61d2a52925d 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1718,12 +1718,34 @@ def apply( workspace=layer.workspace, ) + elif ( + self.allow_flashinfer + and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS + ): + from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 + flashinfer_cutlass_moe_fp4, + ) + + assert self.moe_quant_config is not None + + return flashinfer_cutlass_moe_fp4( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + quant_config=self.moe_quant_config, + inplace=False, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ) else: # If no modular kernel is provided, use cutlass_moe_fp4 for TP case # only (no EP). from vllm.model_executor.layers.fused_moe.cutlass_moe import cutlass_moe_fp4 - assert self.flashinfer_moe_backend != FlashinferMoeBackend.CUTLASS assert self.moe_quant_config is not None return cutlass_moe_fp4( a=x,
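
Illustrative note (not part of the patches): the series above removes the per-quant-method `self.fused_experts` plumbing and instead swaps the layer's quant method for a modular-kernel-backed wrapper once, after weight loading. The toy sketch below shows only that swap pattern in plain Python; every name here (ToyModularKernel, ToyQuantMethod, ToyModularMethod, ToyFusedMoELayer) is a hypothetical stand-in and does not reproduce the real vLLM classes or signatures.

# Minimal sketch of the "swap quant_method for a modular-kernel wrapper" pattern.
# Assumptions: toy names only, no vLLM imports, behavior reduced to the swap itself.

class ToyModularKernel:
    """Stand-in for FusedMoEModularKernel: prepare/finalize + fused experts."""
    def supports_expert_map(self) -> bool:
        return True

    def __call__(self, hidden_states, **kwargs):
        # A real kernel would dispatch, run the expert GEMMs, then combine.
        return hidden_states


class ToyQuantMethod:
    """Stand-in for a FusedMoEMethodBase subclass (e.g. an fp8 MoE method)."""
    supports_eplb = True
    allow_inplace = True

    def maybe_make_prepare_finalize(self):
        # Return None when no all-to-all backend is active.
        return object()

    def maybe_init_modular_kernel(self, layer):
        pf = self.maybe_make_prepare_finalize()
        return ToyModularKernel() if pf is not None else None

    def apply(self, layer, x, **kwargs):
        return x  # fallback, non-modular path


class ToyModularMethod:
    """Stand-in for FusedMoEModularMethod: delegates apply() to the kernel."""
    def __init__(self, old_method, kernel):
        # Forward capability flags from the wrapped method.
        self.supports_eplb = old_method.supports_eplb
        self.allow_inplace = old_method.allow_inplace
        self.disable_expert_map = not kernel.supports_expert_map()
        self.fused_experts = kernel

    def apply(self, layer, x, **kwargs):
        return self.fused_experts(x, inplace=self.allow_inplace)


class ToyFusedMoELayer:
    def __init__(self):
        self.quant_method = ToyQuantMethod()

    def maybe_init_modular_kernel(self) -> None:
        # Called once after weight loading (by the device communicator in the
        # patches); safe point to replace the quant method.
        mk = self.quant_method.maybe_init_modular_kernel(self)
        if mk is not None:
            self.quant_method = ToyModularMethod(self.quant_method, mk)


if __name__ == "__main__":
    layer = ToyFusedMoELayer()
    layer.maybe_init_modular_kernel()
    assert isinstance(layer.quant_method, ToyModularMethod)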