diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py
index a5f40e4fce28..25cd40a939bb 100644
--- a/vllm/model_executor/models/transformers.py
+++ b/vllm/model_executor/models/transformers.py
@@ -40,7 +40,7 @@
 )
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.config.utils import getattr_iter
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import get_pp_group, get_tp_group
 from vllm.distributed.utils import get_pp_indices
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -506,9 +506,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.quant_config: Optional[QuantizationConfig] = vllm_config.quant_config
 
         self.pp_group = get_pp_group()
-        self.pp_size = self.pp_group.world_size
-        self.pp_rank = self.pp_group.rank_in_group
-        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_group = get_tp_group()
 
         # Weights to skip in `self.load_weights`
         self.skip_prefixes: list[str] = []
@@ -576,7 +574,7 @@ def pipeline_parallel(self):
         """
         Apply the model's pipeline parallelization plan.
         """
-        if self.pp_size <= 1:
+        if self.pp_group.world_size <= 1:
             return
 
         if not self.model.supports_pp_plan:
@@ -613,7 +611,9 @@ def pipeline_parallel(self):
 
         # Module list
         start_layer, end_layer = get_pp_indices(
-            self.text_config.num_hidden_layers, self.pp_rank, self.pp_size
+            self.text_config.num_hidden_layers,
+            self.pp_group.rank_in_group,
+            self.pp_group.world_size,
         )
         layers_name = pp_plan[module_list_idx]
         layers = getattr(self.model, layers_name)
@@ -638,7 +638,7 @@ def recursive_replace(self):
         """
         tp_plan = self.model.tp_plan
 
-        if not tp_plan and self.tp_size > 1:
+        if not tp_plan and self.tp_group.world_size > 1:
             tip = get_feature_request_tip(
                 self.model_config.model, self.model_config.trust_remote_code
             )
@@ -687,7 +687,9 @@ def create_attention_instances(
         head_size = self.model_config.get_head_size()
         num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
         start, end = get_pp_indices(
-            self.text_config.num_hidden_layers, self.pp_rank, self.pp_size
+            self.text_config.num_hidden_layers,
+            self.pp_group.rank_in_group,
+            self.pp_group.world_size,
         )
 
         attention_instances = {}
@@ -749,7 +751,7 @@ def forward(
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
-        if not get_pp_group().is_first_rank:
+        if not self.pp_group.is_first_rank:
             assert intermediate_tensors is not None
             input_ids = None
             inputs_embeds = intermediate_tensors["hidden_states"]
@@ -773,7 +775,7 @@ def forward(
             return_dict=False,
         )[0][0, ...]
         # we remove batch dimension for now
-        if not get_pp_group().is_last_rank:
+        if not self.pp_group.is_last_rank:
             return IntermediateTensors({"hidden_states": hidden_states})
 
         return hidden_states
@@ -811,7 +813,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         if self.text_config.tie_word_embeddings:
             self.skip_prefixes.append("lm_head.")
 
-        if get_pp_group().is_last_rank:
+        if self.pp_group.is_last_rank:
             self.unpadded_vocab_size = self.text_config.vocab_size
             self.lm_head = ParallelLMHead(
                 self.text_config.vocab_size,
diff --git a/vllm/model_executor/models/transformers_moe.py b/vllm/model_executor/models/transformers_moe.py
index 9fba80dd1db8..eb135f050e8c 100644
--- a/vllm/model_executor/models/transformers_moe.py
+++ b/vllm/model_executor/models/transformers_moe.py
@@ -30,6 +30,7 @@
 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 
+from .interfaces import MixtureOfExperts
 from .transformers import (
     TransformersBase,
     TransformersForCausalLM,
@@ -116,17 +117,41 @@ def transformers_moe_forward_fake(
 )
 
 
-class TransformersMoEBase(TransformersBase):
+class TransformersMoEBase(TransformersBase, MixtureOfExperts):
     def __init__(self, *, vllm_config, prefix=""):
         self.check_version("4.57.0.dev0", "MoE models support")
+        self.ep_group = get_ep_group()
         super().__init__(vllm_config=vllm_config, prefix=prefix)
 
-        if self.parallel_config.enable_eplb:
-            raise NotImplementedError(
-                "Transformers backend does not support expert parallel load "
-                "balancing yet."
+    def set_eplb_state(
+        self,
+        expert_load_view: torch.Tensor,
+        logical_to_physical_map: torch.Tensor,
+        logical_replica_count: torch.Tensor,
+    ):
+        for moe_layer_idx, mlp_layer in enumerate(self.mlp_layers):
+            mlp_layer.experts.set_eplb_state(
+                moe_layer_idx=moe_layer_idx,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
             )
 
+    def update_physical_experts_metadata(
+        self,
+        num_physical_experts: int,
+        num_local_physical_experts: int,
+    ):
+        assert self.num_local_physical_experts == num_local_physical_experts
+        self.num_physical_experts = num_physical_experts
+        self.num_local_physical_experts = num_local_physical_experts
+        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
+        for mlp in self.mlp_layers:
+            mlp.n_local_physical_experts = num_local_physical_experts
+            mlp.n_physical_experts = num_physical_experts
+            mlp.n_redundant_experts = self.num_redundant_experts
+            mlp.experts.update_expert_map()
+
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         """
         Params for weights, fp8 weight scales, fp8 activation scales
@@ -138,6 +163,8 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
             ("w1", "w2", "w3"),  # Granite, Mixtral, Phi MoE style
             ("linear", "linear_1", "linear_v"),  # Grok1 style
         ]
+        num_experts = self.model_config.get_num_experts()
+        num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts
         expert_mapping = []
         for gate_proj, down_proj, up_proj in ckpt_names:
             expert_mapping.extend(
@@ -145,8 +172,8 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
                     ckpt_gate_proj_name=gate_proj,
                     ckpt_down_proj_name=down_proj,
                     ckpt_up_proj_name=up_proj,
-                    num_experts=self.model_config.get_num_experts(),
-                    num_redundant_experts=0,  # TODO: enable EPLB
+                    num_experts=num_experts,
+                    num_redundant_experts=num_redundant_experts,
                 )
             )
         return expert_mapping
@@ -167,12 +194,15 @@ def recursive_replace(self):
 
         # If there are shared experts, the results are
         # reduced after mlp.forward() not inside FusedMoE
-        num_experts_shared = getattr_iter(
+        num_shared_experts = getattr_iter(
             text_config,
-            ["num_experts_shared", "n_shared_experts", "moe_num_shared_experts"],
+            [
+                "n_shared_experts",  # DeepSeek, Docs, GLM
+                "moe_num_shared_experts",  # Aria, Ernie
+            ],
             0,
         )
-        reduce_results = num_experts_shared == 0
+        reduce_results = num_shared_experts == 0
 
         def add_all_reduce(mlp: nn.Module):
             """Adds an all-reduce to the output of `mlp.forward()`."""
@@ -207,13 +237,23 @@ def forward(self, *args, **kwargs):
         # Expert mapping for `AutoWeightsLoader`
         expert_mapping = self.get_expert_mapping()
 
-        # Configs
-        parallel_config = self.parallel_config
-        eplb_config = parallel_config.eplb_config
-
         # Expert parallel load balancing kwargs
-        enable_eplb = parallel_config.enable_eplb
-        num_redundant_experts = eplb_config.num_redundant_experts
+        enable_eplb = self.parallel_config.enable_eplb
+        num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts
+
+        # MixtureOfExperts mixin settings
+        ep_size = self.ep_group.world_size
+
+        self.mlp_layers = []  # Used for MixtureOfExperts methods
+        self.expert_weights = []
+        self.num_moe_layers = 0
+        self.num_expert_groups = 1 if num_expert_group is None else num_expert_group
+        self.num_logical_experts = num_experts
+        self.num_physical_experts = num_experts + num_redundant_experts
+        self.num_local_physical_experts = self.num_physical_experts // ep_size
+        self.num_routed_experts = num_experts
+        self.num_shared_experts = num_shared_experts
+        self.num_redundant_experts = num_redundant_experts
 
         # Recursively fuse MoE layers
         def _recursive_replace(module: nn.Module, prefix: str):
@@ -235,6 +275,9 @@ def _recursive_replace(module: nn.Module, prefix: str):
                 for mlp_param_name, _ in mlp.named_parameters():
                     if "shared_expert" in mlp_param_name:
                         reduce_results = False
+                        # If the config does not specify num_shared_experts, but
+                        # the model has shared experts, we assume there is one.
+                        self.num_shared_experts = 1
                         break
                 # Replace experts module with FusedMoE
                 fused_experts = TransformersFusedMoE(
@@ -258,6 +301,10 @@ def _recursive_replace(module: nn.Module, prefix: str):
                 )
                 mlp.experts = fused_experts
                 log_replacement(qual_name, experts, fused_experts)
+                # Update MixtureOfExperts mixin state
+                self.mlp_layers.append(mlp)
+                self.expert_weights.append(fused_experts.get_expert_weights())
+                self.num_moe_layers += 1
                 # If results are not all-reduced in FusedMoE, ensure they
                 # are all-reduced at the end of mlp.forward() if tensor
                 # parallel or expert parallel is enabled
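Below is a minimal, standalone sketch (not part of the patch, and not a vLLM API) of the expert-count bookkeeping that `recursive_replace` now records for the `MixtureOfExperts` mixin: the physical experts are the checkpoint's logical experts plus the EPLB redundant replicas, sharded evenly across the expert-parallel group. All names in the sketch are illustrative.

```python
# Illustrative sketch only: mirrors the arithmetic behind
# num_physical_experts and num_local_physical_experts set in
# recursive_replace(); the class and field names are hypothetical.
from dataclasses import dataclass


@dataclass
class ExpertCounts:
    num_logical_experts: int    # routed experts defined by the checkpoint
    num_redundant_experts: int  # extra replicas introduced by EPLB
    ep_size: int                # expert-parallel world size

    @property
    def num_physical_experts(self) -> int:
        # num_physical_experts = num_experts + num_redundant_experts
        return self.num_logical_experts + self.num_redundant_experts

    @property
    def num_local_physical_experts(self) -> int:
        # num_local_physical_experts = num_physical_experts // ep_size
        return self.num_physical_experts // self.ep_size


# Example: 64 routed experts, 16 redundant replicas, 8 EP ranks.
counts = ExpertCounts(num_logical_experts=64, num_redundant_experts=16, ep_size=8)
assert counts.num_physical_experts == 80
assert counts.num_local_physical_experts == 10
```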