 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op

+from .interfaces import MixtureOfExperts
 from .transformers import (
     TransformersBase,
     TransformersForCausalLM,
@@ -116,17 +117,41 @@ def transformers_moe_forward_fake(
 )


-class TransformersMoEBase(TransformersBase):
+class TransformersMoEBase(TransformersBase, MixtureOfExperts):
     def __init__(self, *, vllm_config, prefix=""):
         self.check_version("4.57.0.dev0", "MoE models support")
+        self.ep_group = get_ep_group()
         super().__init__(vllm_config=vllm_config, prefix=prefix)

-        if self.parallel_config.enable_eplb:
-            raise NotImplementedError(
-                "Transformers backend does not support expert parallel load "
-                "balancing yet."
+    def set_eplb_state(
+        self,
+        expert_load_view: torch.Tensor,
+        logical_to_physical_map: torch.Tensor,
+        logical_replica_count: torch.Tensor,
+    ):
+        for moe_layer_idx, mlp_layer in enumerate(self.mlp_layers):
+            mlp_layer.experts.set_eplb_state(
+                moe_layer_idx=moe_layer_idx,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
             )

+    def update_physical_experts_metadata(
+        self,
+        num_physical_experts: int,
+        num_local_physical_experts: int,
+    ):
+        assert self.num_local_physical_experts == num_local_physical_experts
+        self.num_physical_experts = num_physical_experts
+        self.num_local_physical_experts = num_local_physical_experts
+        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
+        for mlp in self.mlp_layers:
+            mlp.n_local_physical_experts = num_local_physical_experts
+            mlp.n_physical_experts = num_physical_experts
+            mlp.n_redundant_experts = self.num_redundant_experts
+            mlp.experts.update_expert_map()
+
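The two methods above are the `MixtureOfExperts` hooks that EPLB calls into. The sketch below, which is not part of the diff, shows the rough call pattern; `model` stands for a loaded TransformersMoEBase instance and the tensor shapes are illustrative assumptions rather than the exact EPLB buffer layout.

# Sketch only, not part of this patch.
import torch

L = model.num_moe_layers            # number of fused MoE layers found
P = model.num_physical_experts      # logical + redundant expert slots
G = model.num_logical_experts

model.set_eplb_state(
    expert_load_view=torch.zeros(L, P, dtype=torch.int64),
    logical_to_physical_map=torch.full((L, G, P - G + 1), -1, dtype=torch.int64),
    logical_replica_count=torch.ones(L, G, dtype=torch.int64),
)

# After rebalancing, the global physical layout may change; the per-rank expert
# count must stay the same, per the assert above.
model.update_physical_experts_metadata(
    num_physical_experts=model.num_physical_experts,
    num_local_physical_experts=model.num_local_physical_experts,
)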
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         """
         Params for weights, fp8 weight scales, fp8 activation scales
@@ -138,15 +163,17 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
138163 ("w1" , "w2" , "w3" ), # Granite, Mixtral, Phi MoE style
139164 ("linear" , "linear_1" , "linear_v" ), # Grok1 style
140165 ]
166+ num_experts = self .model_config .get_num_experts ()
167+ num_redundant_experts = self .parallel_config .eplb_config .num_redundant_experts
141168 expert_mapping = []
142169 for gate_proj , down_proj , up_proj in ckpt_names :
143170 expert_mapping .extend (
144171 FusedMoE .make_expert_params_mapping (
145172 ckpt_gate_proj_name = gate_proj ,
146173 ckpt_down_proj_name = down_proj ,
147174 ckpt_up_proj_name = up_proj ,
148- num_experts = self . model_config . get_num_experts () ,
149- num_redundant_experts = 0 , # TODO: enable EPLB
175+ num_experts = num_experts ,
176+ num_redundant_experts = num_redundant_experts ,
150177 )
151178 )
152179 return expert_mapping
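Per the method's return annotation, each mapping entry is a `(param_name, weight_name, expert_id, shard_id)` tuple. The sketch below, not part of the diff, shows how a weights loader might consume it; the matching logic and the `ckpt_name` variable are assumptions, not vLLM's actual `AutoWeightsLoader` code.

# Sketch only: routing one checkpoint tensor (named `ckpt_name`) via the mapping.
for param_name, weight_name, expert_id, shard_id in model.get_expert_mapping():
    if weight_name not in ckpt_name:
        continue
    # Copy the checkpoint tensor into the fused parameter identified by
    # `param_name`, at expert slot `expert_id`, shard `shard_id` (w1/w2/w3).
    ...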
@@ -167,12 +194,15 @@ def recursive_replace(self):

         # If there are shared experts, the results are
         # reduced after mlp.forward() not inside FusedMoE
-        num_experts_shared = getattr_iter(
+        num_shared_experts = getattr_iter(
             text_config,
-            ["num_experts_shared", "n_shared_experts", "moe_num_shared_experts"],
+            [
+                "n_shared_experts",  # DeepSeek, Docs, GLM
+                "moe_num_shared_experts",  # Aria, Ernie
+            ],
             0,
         )
-        reduce_results = num_experts_shared == 0
+        reduce_results = num_shared_experts == 0

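For readers unfamiliar with the helper: `getattr_iter` is assumed to return the first attribute present on `text_config` among the candidate names, falling back to the default (`0`, i.e. no shared experts). A minimal sketch of that assumed behavior, not the actual vLLM implementation:

# Assumed behavior of getattr_iter (illustrative sketch).
def getattr_iter(obj, names, default):
    for name in names:
        if hasattr(obj, name):
            return getattr(obj, name)
    return default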
         def add_all_reduce(mlp: nn.Module):
             """Adds an all-reduce to the output of `mlp.forward()`."""
@@ -207,13 +237,23 @@ def forward(self, *args, **kwargs):
         # Expert mapping for `AutoWeightsLoader`
         expert_mapping = self.get_expert_mapping()

-        # Configs
-        parallel_config = self.parallel_config
-        eplb_config = parallel_config.eplb_config
-
         # Expert parallel load balancing kwargs
-        enable_eplb = parallel_config.enable_eplb
-        num_redundant_experts = eplb_config.num_redundant_experts
+        enable_eplb = self.parallel_config.enable_eplb
+        num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts
+
+        # MixtureOfExperts mixin settings
+        ep_size = self.ep_group.world_size
+
+        self.mlp_layers = []  # Used for MixtureOfExperts methods
+        self.expert_weights = []
+        self.num_moe_layers = 0
+        self.num_expert_groups = 1 if num_expert_group is None else num_expert_group
+        self.num_logical_experts = num_experts
+        self.num_physical_experts = num_experts + num_redundant_experts
+        self.num_local_physical_experts = self.num_physical_experts // ep_size
+        self.num_routed_experts = num_experts
+        self.num_shared_experts = num_shared_experts
+        self.num_redundant_experts = num_redundant_experts

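A worked example with illustrative numbers makes the bookkeeping above concrete (the values are arbitrary, not defaults):

# Illustrative numbers only: 64 routed experts, 8 redundant replicas, EP size 4.
num_experts = 64
num_redundant_experts = 8
ep_size = 4
num_physical_experts = num_experts + num_redundant_experts    # 72 physical slots
num_local_physical_experts = num_physical_experts // ep_size  # 18 experts per EP rank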
         # Recursively fuse MoE layers
         def _recursive_replace(module: nn.Module, prefix: str):
@@ -235,6 +275,9 @@ def _recursive_replace(module: nn.Module, prefix: str):
                     for mlp_param_name, _ in mlp.named_parameters():
                         if "shared_expert" in mlp_param_name:
                             reduce_results = False
+                            # If the config does not specify num_shared_experts, but
+                            # the model has shared experts, we assume there is one.
+                            self.num_shared_experts = 1
                             break
                     # Replace experts module with FusedMoE
                     fused_experts = TransformersFusedMoE(
@@ -258,6 +301,10 @@ def _recursive_replace(module: nn.Module, prefix: str):
                     )
                     mlp.experts = fused_experts
                     log_replacement(qual_name, experts, fused_experts)
+                    # Update MixtureOfExperts mixin state
+                    self.mlp_layers.append(mlp)
+                    self.expert_weights.append(fused_experts.get_expert_weights())
+                    self.num_moe_layers += 1
                     # If results are not all-reduced in FusedMoE, ensure they
                     # are all-reduced at the end of mlp.forward() if tensor
                     # parallel or expert parallel is enabled
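The all-reduce referred to in this comment is attached by the `add_all_reduce` helper defined earlier. A hedged sketch of the idea (not the actual implementation, which lies outside this hunk); `tensor_model_parallel_all_reduce` is vLLM's tensor-parallel all-reduce helper, and the single-tensor output is an assumption.

# Sketch only: wrap mlp.forward() so its output is reduced across ranks when
# FusedMoE was constructed with reduce_results=False.
from vllm.distributed import tensor_model_parallel_all_reduce

def add_all_reduce(mlp):
    original_forward = mlp.forward

    def forward(*args, **kwargs):
        return tensor_model_parallel_all_reduce(original_forward(*args, **kwargs))

    mlp.forward = forward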