 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Qwen3MoE model compatible with HuggingFace weights."""
-from collections.abc import Iterable
+import typing
+from collections.abc import Callable, Iterable
 from typing import Any, Optional, Union
 
 import torch
 
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
+from vllm.distributed import (get_ep_group, get_pp_group,
                              get_tensor_model_parallel_world_size)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsLoRA, SupportsPP
-from .utils import (AutoWeightsLoader, extract_layer_index,
+from .interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP
+from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
                     is_pp_missing_parameter,
                     make_empty_intermediate_tensors_factory, make_layers,
                     maybe_prefix)
@@ -101,23 +103,47 @@ def __init__(
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        enable_eplb: bool = False,
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
 
+        self.ep_group = get_ep_group().device_group
+        self.ep_rank = self.ep_group.rank()
+        self.ep_size = self.ep_group.size()
+        self.n_routed_experts = config.num_experts
+
         if self.tp_size > config.num_experts:
             raise ValueError(
                 f"Tensor parallel size {self.tp_size} is greater than "
                 f"the number of experts {config.num_experts}.")
 
-        self.experts = FusedMoE(num_experts=config.num_experts,
+        # Load balancing settings.
+        vllm_config = get_current_vllm_config()
+        parallel_config = vllm_config.parallel_config
+        self.enable_eplb = enable_eplb
+
+        self.n_logical_experts = self.n_routed_experts
+        self.n_redundant_experts = parallel_config.num_redundant_experts
+        self.n_physical_experts = (self.n_logical_experts +
                                   self.n_redundant_experts)
+        self.n_local_physical_experts = self.n_physical_experts // self.ep_size
+
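+        # Each EP rank owns one contiguous slice of the physical experts
+        # (logical + redundant); e.g. 128 logical + 8 redundant experts over
+        # ep_size=8 gives 17 local physical experts per rank.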
+        self.physical_expert_start = (self.ep_rank *
+                                      self.n_local_physical_experts)
+        self.physical_expert_end = (self.physical_expert_start +
+                                    self.n_local_physical_experts)
+
+        self.experts = FusedMoE(num_experts=self.n_routed_experts,
                                 top_k=config.num_experts_per_tok,
                                 hidden_size=config.hidden_size,
                                 intermediate_size=config.moe_intermediate_size,
                                 reduce_results=False,
                                 renormalize=config.norm_topk_prob,
                                 quant_config=quant_config,
-                                prefix=f"{prefix}.experts")
+                                prefix=f"{prefix}.experts",
+                                enable_eplb=self.enable_eplb,
+                                num_redundant_experts=self.n_redundant_experts)
 
         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.num_experts,
@@ -246,6 +272,7 @@ def __init__(
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        enable_eplb: bool = False,
     ) -> None:
         super().__init__()
         self.hidden_size = config.hidden_size
@@ -277,7 +304,8 @@ def __init__(
                 (layer_idx + 1) % config.decoder_sparse_step == 0):
             self.mlp = Qwen3MoeSparseMoeBlock(config=config,
                                               quant_config=quant_config,
-                                              prefix=f"{prefix}.mlp")
+                                              prefix=f"{prefix}.mlp",
+                                              enable_eplb=enable_eplb)
         else:
             self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
                                    intermediate_size=config.intermediate_size,
@@ -323,6 +351,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         config = vllm_config.model_config.hf_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
+        parallel_config = vllm_config.parallel_config
+        enable_eplb = parallel_config.enable_eplb
+        self.num_redundant_experts = parallel_config.num_redundant_experts
 
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
@@ -336,7 +367,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             lambda prefix: Qwen3MoeDecoderLayer(config=config,
                                                 cache_config=cache_config,
                                                 quant_config=quant_config,
-                                                prefix=prefix),
+                                                prefix=prefix,
+                                                enable_eplb=enable_eplb),
             prefix=f"{prefix}.layers",
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -382,7 +414,8 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.num_experts)
+            num_experts=self.config.num_experts,
+            num_redundant_experts=self.num_redundant_experts)
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
@@ -433,27 +466,51 @@ def load_weights(self, weights: Iterable[tuple[str,
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
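+                # Track whether this weight matched any expert mapping so that
+                # expert weights hosted on other EP ranks can be skipped
+                # instead of falling through to the generic loader below.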
+                is_expert_weight = False
                 for mapping in expert_params_mapping:
                     param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
-                    name = name.replace(weight_name, param_name)
-                    # Skip layers on other devices.
-                    if is_pp_missing_parameter(name, self):
+
+                    # Anyway, this is an expert weight and should not be
+                    # attempted to load as other weights later
+                    is_expert_weight = True
+
+                    # Do not modify `name` since the loop may continue here
+                    # Instead, create a new variable
+                    name_mapped = name.replace(weight_name, param_name)
+
+                    if is_pp_missing_parameter(name_mapped, self):
                         continue
+
                     # Skip loading extra parameters for GPTQ/modelopt models.
-                    if name.endswith(
-                            ignore_suffixes) and name not in params_dict:
+                    if name_mapped.endswith(
+                            ignore_suffixes
+                    ) and name_mapped not in params_dict:
                         continue
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(param,
-                                  loaded_weight,
-                                  name,
-                                  shard_id=shard_id,
-                                  expert_id=expert_id)
-                    break
+
+                    param = params_dict[name_mapped]
+                    # We should ask the weight loader to return success or not
+                    # here since otherwise we may skip experts with other
+                    # available replicas.
+                    weight_loader = typing.cast(Callable[..., bool],
+                                                param.weight_loader)
+                    success = weight_loader(param,
+                                            loaded_weight,
+                                            name_mapped,
+                                            shard_id=shard_id,
+                                            expert_id=expert_id,
+                                            return_success=True)
+                    if success:
+                        name = name_mapped
+                        break
                 else:
+                    if is_expert_weight:
+                        # We've checked that this is an expert weight
+                        # However it's not mapped locally to this rank
+                        # So we simply skip it
+                        continue
+
                     # Skip loading extra parameters for GPTQ/modelopt models.
                     if name.endswith(
                             ignore_suffixes) and name not in params_dict:
@@ -482,7 +539,8 @@ def load_weights(self, weights: Iterable[tuple[str,
         return loaded_params
 
 
-class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
+class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
+                          MixtureOfExperts):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -514,6 +572,66 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
 
+        # Set MoE hyperparameters
+        self.expert_weights = []
+
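+        # Collect the FusedMoE module of every sparse layer owned by this
+        # pipeline-parallel rank; PPMissingLayer placeholders and dense MLP
+        # layers are skipped.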
+        self.moe_layers: list[FusedMoE] = []
+        example_layer = None
+        for layer in self.model.layers:
+            if isinstance(layer, PPMissingLayer):
+                continue
+
+            assert isinstance(layer, Qwen3MoeDecoderLayer)
+            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
+                example_layer = layer.mlp
+                self.moe_layers.append(layer.mlp.experts)
+
+        if example_layer is None:
+            raise RuntimeError("No Qwen3MoE layer found in the model.layers.")
+
+        self.num_moe_layers = len(self.moe_layers)
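+        # Qwen3MoE uses flat (ungrouped) routing and has no shared experts.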
+        self.num_expert_groups = 1
+        self.num_shared_experts = 0
+        self.num_logical_experts = example_layer.n_logical_experts
+        self.num_physical_experts = example_layer.n_physical_experts
+        self.num_local_physical_experts = example_layer.n_local_physical_experts
+        self.num_routed_experts = example_layer.n_routed_experts
+        self.num_redundant_experts = example_layer.n_redundant_experts
+
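+    # MixtureOfExperts interface: hand each FusedMoE layer a view into the
+    # shared EPLB tensors so per-expert load can be tracked and rebalanced.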
+    def set_eplb_state(
+        self,
+        expert_load_view: torch.Tensor,
+        logical_to_physical_map: torch.Tensor,
+        logical_replica_count: torch.Tensor,
+    ) -> None:
+        for layer_idx, layer in enumerate(self.moe_layers):
+            # Register the expert weights.
+            self.expert_weights.append(layer.get_expert_weights())
+            layer.set_eplb_state(
+                moe_layer_idx=layer_idx,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
+            )
+
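+    # Called when the number of physical experts changes; propagates the new
+    # counts to every sparse block and rebuilds the expert map inside each
+    # FusedMoE.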
+    def update_physical_experts_metadata(
+        self,
+        num_physical_experts: int,
+        num_local_physical_experts: int,
+    ) -> None:
+        assert self.num_local_physical_experts == num_local_physical_experts
+        self.num_physical_experts = num_physical_experts
+        self.num_local_physical_experts = num_local_physical_experts
+        self.num_redundant_experts = (num_physical_experts -
+                                      self.num_logical_experts)
+        for layer in self.model.layers:
+            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
+                moe = layer.mlp
+                moe.n_local_physical_experts = num_local_physical_experts
+                moe.n_physical_experts = num_physical_experts
+                moe.n_redundant_experts = self.num_redundant_experts
+                moe.experts.update_expert_map()
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 