diff --git a/vllm_ascend/models/deepseek_dbo.py b/vllm_ascend/models/deepseek_dbo.py index 9469e99999..f62b17cc96 100644 --- a/vllm_ascend/models/deepseek_dbo.py +++ b/vllm_ascend/models/deepseek_dbo.py @@ -75,7 +75,7 @@ from vllm_ascend.multistream.metadata import (MultiStreamConfig, MultiStreamStepMetadata, make_multistream_metadata_ds) -from vllm_ascend.ops.fused_moe import AscendFusedMoE +from vllm_ascend.ops.common_fused_moe import AscendFusedMoE from vllm_ascend.utils import dispose_tensor VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index 59273ed88a..8e605c6ebd 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -70,7 +70,7 @@ from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.models.layers.mla import AscendMLAModules -from vllm_ascend.ops.fused_moe import AscendFusedMoE +from vllm_ascend.ops.common_fused_moe import AscendFusedMoE from vllm_ascend.quantization.quant_config import AscendLinearMethod from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod from vllm_ascend.utils import dispose_tensor diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py index 33073f4301..53a401efdc 100644 --- a/vllm_ascend/ops/common_fused_moe.py +++ b/vllm_ascend/ops/common_fused_moe.py @@ -14,8 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # - -from typing import Callable, Optional +import os +from typing import Any, Callable, Optional import torch import torch_npu @@ -24,268 +24,232 @@ tensor_model_parallel_all_reduce) from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe.config import \ - FusedMoEParallelConfig # isort: skip + FusedMoEConfig # isort: skip from vllm.model_executor.layers.fused_moe.layer import ( - FusedMoE, UnquantizedFusedMoEMethod) + FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map) +from vllm.model_executor.layers.quantization.base_config import \ + QuantizationConfig from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.parallel_state import get_mc2_group +from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer from vllm_ascend.ops.moe.experts_selector import select_experts from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl, NaiveMulticastCommImpl) -from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p - -original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__ - - -def fused_experts_moge( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - moe_parallel_config: FusedMoEParallelConfig, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - top_k: int, - global_num_experts: int, - expert_map: torch.Tensor = None, - apply_router_weight_on_input: bool = False, -) -> torch.Tensor: - """ - - Args: - hidden_states: Hidden states of shape (num_tokens, hidden_size). - w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size). - w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size). - topk_weights: Routing weights of shape (num_tokens, top_k). - topk_ids: Selected expert IDs of shape (num_tokens, top_k). - top_k: Number of experts to select. - expert_map: Expert mapping of shape (num_experts,). - - Returns: - hidden_states: Hidden states after routing. 
- """ - ep_size = moe_parallel_config.ep_size - local_num_experts = global_num_experts // ep_size - local_num_group = top_k // ep_size - - bsz, _ = hidden_states.shape - flatten_topk_ids = topk_ids.view(-1) - sorted_topk_ids = torch.argsort(flatten_topk_ids.float()) - sorted_topk_ids = sorted_topk_ids.to(torch.int32) - sorted_hidden_states = hidden_states.index_select( - 0, sorted_topk_ids // local_num_group) - - experts_id = torch.arange(0, - local_num_experts, - dtype=topk_ids.dtype, - device=topk_ids.device) - num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to( - torch.float32).sum(0) - topk_scales = topk_weights.view(-1).index_select( - 0, sorted_topk_ids).unsqueeze(-1) - group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64) - - gate_up_out = torch_npu.npu_grouped_matmul( - x=[sorted_hidden_states], - weight=[w1], - split_item=2, - group_list_type=0, - group_type=0, - group_list=group_list, - )[0] - - if is_310p(): - gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to( - torch.float16) - else: - gate_up_out = torch_npu.npu_swiglu(gate_up_out) - gate_up_out *= topk_scales - - down_out_list = torch_npu.npu_grouped_matmul( - x=[gate_up_out], - weight=[w2], - split_item=2, - group_list_type=0, - group_type=0, - group_list=group_list, - )[0] - - unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32) - unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids) - final_hidden_states = unsorted_hidden_states.reshape( - bsz, top_k // ep_size, -1).sum(1) - - return final_hidden_states - - -def unquantized_fused_moe_init_func(self, *args, **kwargs): - original_unquantized_fused_moe_init_func(self, *args, **kwargs) - - # NOTE: Currently, this self.use_aclgraph is only used in - # UnquantizedFusedMoEMethod.forward_oot to decide whether to use in - # ops/fused_moe.py:568 to circumvent torch.randint_like not supported issue. - # Once torch.randint_like is supported or removed, this flag can be removed. 
- vllm_config = get_current_vllm_config() - ascend_config = get_ascend_config() - if ascend_config.torchair_graph_config.enabled: - self.use_aclgraph = False - else: - self.use_aclgraph = (vllm_config.compilation_config.level - == CompilationLevel.PIECEWISE - and not vllm_config.model_config.enforce_eager) - self.transpose = True - - -def forward_oot( - self, - layer: torch.nn.Module, - x: torch.Tensor, - use_grouped_topk: bool, - top_k: int, - router_logits: torch.Tensor, - renormalize: bool, - topk_group: Optional[int] = None, - num_expert_group: Optional[int] = None, - custom_routing_function: Optional[Callable] = None, - scoring_func: str = "softmax", - routed_scaling_factor: float = 1.0, - e_score_correction_bias: Optional[torch.Tensor] = None, - global_num_experts: int = -1, - expert_map: Optional[torch.Tensor] = None, - apply_router_weight_on_input: bool = False, - activation: str = "silu", - enable_eplb: bool = False, - expert_load_view: Optional[torch.Tensor] = None, - logical_to_physical_map: Optional[torch.Tensor] = None, - logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor: - - topk_weights, topk_ids, row_idx = select_experts( - hidden_states=x, - router_logits=router_logits, - top_k=top_k, - use_grouped_topk=use_grouped_topk, - renormalize=renormalize, - topk_group=topk_group, - num_expert_group=num_expert_group, - custom_routing_function=custom_routing_function, - scoring_func=scoring_func, - routed_scaling_factor=routed_scaling_factor, - e_score_correction_bias=e_score_correction_bias, - global_num_experts=global_num_experts) - - if topk_ids.shape[1] < top_k or is_310p(): - assert global_num_experts is not None - return fused_experts_moge( +from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, + get_all_reduce_merge_state, + get_rm_router_logits_state, is_310p) + + +class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): + + def __init__(self, moe: FusedMoEConfig = None): + + super().__init__(moe=moe) + + # NOTE: Currently, this self.use_aclgraph is only used in + # UnquantizedFusedMoEMethod.forward_oot to decide whether to use in + # ops/fused_moe.py:568 to circumvent torch.randint_like not supported issue. + # Once torch.randint_like is supported or removed, this flag can be removed. 
+ vllm_config = get_current_vllm_config() + ascend_config = get_ascend_config() + if ascend_config.torchair_graph_config.enabled: + self.use_aclgraph = False + else: + self.use_aclgraph = (vllm_config.compilation_config.level + == CompilationLevel.PIECEWISE and + not vllm_config.model_config.enforce_eager) + self.transpose = True + + def process_weights_after_loading(self, layer): + super(UnquantizedFusedMoEMethod, + self).process_weights_after_loading(layer) + if self.transpose: + w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose( + 1, 2).contiguous() + layer.w13_weight = torch.nn.Parameter(w13_data, + requires_grad=False) + + w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose( + 1, 2).contiguous() + layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False) + + self.transpose = False + else: + w13_data = self._maybe_pad_weight(layer.w13_weight.data) + layer.w13_weight = torch.nn.Parameter(w13_data, + requires_grad=False) + + w2_data = self._maybe_pad_weight(layer.w2_weight.data) + layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False) + + if not is_310p(): + layer.w13_weight.data = torch_npu.npu_format_cast( + layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ) + layer.w2_weight.data = torch_npu.npu_format_cast( + layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + use_grouped_topk: bool, + top_k: int, + router_logits: torch.Tensor, + renormalize: bool, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + routed_scaling_factor: float = 1.0, + e_score_correction_bias: Optional[torch.Tensor] = None, + global_num_experts: int = -1, + expert_map: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + enable_force_load_balance: bool = False, + shared_experts: Optional[Any] = None, + **kwargs) -> torch.Tensor: + + topk_weights, topk_ids, row_idx = select_experts( + hidden_states=x, + router_logits=router_logits, + top_k=top_k, + use_grouped_topk=use_grouped_topk, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + global_num_experts=global_num_experts) + + # this is a naive implementation for experts load balance so as + # to avoid accumulating too much tokens on a single rank. + # currently it is only activated when doing profile runs. 
+ if enable_force_load_balance and not self.use_aclgraph: + topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) + + moe_comm_method = get_forward_context().moe_comm_method + return moe_comm_method.fused_experts( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, - moe_parallel_config=self.moe.moe_parallel_config, topk_weights=topk_weights, topk_ids=topk_ids, - top_k=top_k, + row_idx=row_idx, global_num_experts=global_num_experts, expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input) - - moe_comm_method = get_forward_context().moe_comm_method - return moe_comm_method.fused_experts(hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - row_idx=row_idx, - global_num_experts=global_num_experts, - expert_map=expert_map) - - -def process_weights_after_loading(self, layer): - super(UnquantizedFusedMoEMethod, self).process_weights_after_loading(layer) - if self.transpose: - w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose( - 1, 2).contiguous() - layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False) - - w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose( - 1, 2).contiguous() - layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False) - - self.transpose = False - else: - w13_data = self._maybe_pad_weight(layer.w13_weight.data) - layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False) - - w2_data = self._maybe_pad_weight(layer.w2_weight.data) - layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False) - - if not is_310p(): - layer.w13_weight.data = torch_npu.npu_format_cast( - layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ) - layer.w2_weight.data = torch_npu.npu_format_cast( - layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ) + shared_experts=shared_experts) class AscendFusedMoE(FusedMoE): + # The moe_counter parameter is required during the initialization of EPLB + # to identify the current layer index within the MOE model. 
+ moe_counter = -1 def __init__( self, - num_experts, - top_k, - hidden_size, - intermediate_size, - params_dtype=None, - reduce_results=False, - renormalize=True, - use_grouped_topk=False, - num_expert_group=None, - topk_group=None, - quant_config=None, - tp_size=None, - ep_size=None, - dp_size=None, - prefix="", - custom_routing_function=None, - scoring_func="softmax", + num_experts: int, # Global number of experts + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = False, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + ep_size: Optional[int] = None, + dp_size: Optional[int] = None, + prefix: str = "", + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", routed_scaling_factor: float = 1.0, - e_score_correction_bias=None, - apply_router_weight_on_input=False, - activation="silu", - enable_eplb=False, - num_redundant_experts=0, - has_bias=False, + e_score_correction_bias: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False, + activation: str = "silu", + enable_eplb: bool = False, + num_redundant_experts: int = 0, + has_bias: bool = False, ): super().__init__( - num_experts, - top_k, - hidden_size, - intermediate_size, - params_dtype, - reduce_results, - renormalize, - use_grouped_topk, - num_expert_group, - topk_group, - quant_config, - tp_size, - ep_size, - dp_size, - prefix, - custom_routing_function, - scoring_func, - routed_scaling_factor, - e_score_correction_bias, - apply_router_weight_on_input, - activation, - enable_eplb, - num_redundant_experts, - has_bias, + num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=reduce_results, + renormalize=renormalize, + use_grouped_topk=use_grouped_topk, + num_expert_group=num_expert_group, + topk_group=topk_group, + quant_config=quant_config, + tp_size=tp_size, + ep_size=ep_size, + dp_size=dp_size, + prefix=prefix, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + apply_router_weight_on_input=apply_router_weight_on_input, + activation=activation, + enable_eplb=enable_eplb, + num_redundant_experts=num_redundant_experts, + has_bias=has_bias, ) + AscendFusedMoE.moe_counter += 1 + self.moe_instance_id = AscendFusedMoE.moe_counter + + self.expert_map = None + self.log2phy = None + self.global_redundant_expert_num = 0 + + is_deepseek_v3_r1 = self.global_num_experts == 256 + self.rm_router_logits = get_rm_router_logits_state( + self.moe_parallel_config.ep_size, self.dp_size, is_deepseek_v3_r1) + self.all_reduce_merge = get_all_reduce_merge_state( + self.moe_parallel_config.ep_size, is_deepseek_v3_r1) + + ascend_config = get_ascend_config() + expert_map_path = ascend_config.expert_map_path + if expert_map_path and os.path.exists(expert_map_path): + # moe expert load balance + expert_load_balancer = ExpertLoadBalancer(expert_map_path, + self.global_num_experts) + self.local_num_experts, self.expert_map = \ + expert_load_balancer.get_rank_placement_map( + self.moe_instance_id, + get_ep_group().rank_in_group) + self.log2phy = expert_load_balancer.get_rank_log2phy_map( + self.moe_instance_id, + 
get_ep_group().rank_in_group) + self.global_redundant_expert_num = \ + expert_load_balancer.get_global_redundant_expert_num() + else: + # Create a tensor of size num_experts filled with -1 + self.local_num_experts, self.expert_map = determine_expert_map( + self.ep_size, + get_ep_group().rank_in_group, self.global_num_experts) + + self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp + + if quant_config is None: + self.quant_method = AscendUnquantizedFusedMoEMethod( + self.moe_config) + else: + self.quant_method = quant_config.get_quant_method(self, prefix) + + assert self.quant_method is not None - self.hidden_size = hidden_size self.moe_config.tp_group = get_tp_group() self.moe_config.dp_group = get_dp_group() self.moe_config.ep_group = get_ep_group() self.moe_config.mc2_group = get_mc2_group() + self.moe_config.num_global_redundant_experts = self.global_redundant_expert_num for method in { AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl, @@ -351,6 +315,34 @@ def forward_impl(self, hidden_states: torch.Tensor, return final_hidden_states + def _forward_ms_fused_moe_comp( + self, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + is_prefill: bool, + real_top_k, + enable_force_load_balance: bool = False, + ): + hidden_states = self.quant_method.apply( + layer=self, + x=hidden_states, + router_logits=router_logits, + top_k=real_top_k, + renormalize=self.renormalize, + use_grouped_topk=self.use_grouped_topk, + global_num_experts=self.global_num_experts, + expert_map=self.expert_map, + topk_group=self.topk_group, + num_expert_group=self.num_expert_group, + custom_routing_function=self.custom_routing_function, + scoring_func=self.scoring_func, + e_score_correction_bias=self.e_score_correction_bias, + is_prefill=is_prefill, + enable_force_load_balance=enable_force_load_balance, + ) + + return hidden_states + def transpose_weight(self, loaded_weight, expert_data, shard_dim): # Ensure training and inference weight shapes match during RL weight updates if ( @@ -437,8 +429,3 @@ def forward( router_logits=router_logits, ) return shared_out, fused_out - - -UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func -UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading -UnquantizedFusedMoEMethod.forward_oot = forward_oot diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 76b677a40a..06db2d9713 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -20,7 +20,7 @@ import torch import torch_npu -from vllm.config import get_current_vllm_config +from vllm.config import CompilationLevel, get_current_vllm_config from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, @@ -28,8 +28,6 @@ from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe.config import \ FusedMoEConfig # isort: skip -from vllm.model_executor.layers.fused_moe.config import \ - FusedMoEParallelConfig # isort: skip from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map) from vllm.model_executor.layers.quantization.base_config import \ @@ -55,19 +53,33 @@ def __init__(self, moe: FusedMoEConfig = None): super().__init__(moe=moe) vllm_config = get_current_vllm_config() - self.global_batch_size = vllm_config.scheduler_config.max_num_seqs - self.max_model_len = vllm_config.model_config.max_model_len - 
get_ascend_config() - - try: - device_group = get_mc2_group().device_group - # TODO: Try local_rank = ep_group.rank_in_group - local_rank = torch.distributed.get_rank(group=device_group) - backend = device_group._get_backend(torch.device("npu")) - self.moe_all_to_all_group_name = backend.get_hccl_comm_name( - local_rank) - except AttributeError: - self.moe_all_to_all_group_name = None + # self.global_batch_size = vllm_config.scheduler_config.max_num_seqs + # self.max_model_len = vllm_config.model_config.max_model_len + # get_ascend_config() + + # try: + # device_group = get_mc2_group().device_group + # # TODO: Try local_rank = ep_group.rank_in_group + # local_rank = torch.distributed.get_rank(group=device_group) + # backend = device_group._get_backend(torch.device("npu")) + # self.moe_all_to_all_group_name = backend.get_hccl_comm_name( + # local_rank) + # except AttributeError: + # self.moe_all_to_all_group_name = None + + # NOTE: Currently, this self.use_aclgraph is only used in + # UnquantizedFusedMoEMethod.forward_oot to decide whether to use in + # ops/fused_moe.py:568 to circumvent torch.randint_like not supported issue. + # Once torch.randint_like is supported or removed, this flag can be removed. + vllm_config = get_current_vllm_config() + ascend_config = get_ascend_config() + if ascend_config.torchair_graph_config.enabled: + self.use_aclgraph = False + else: + self.use_aclgraph = (vllm_config.compilation_config.level + == CompilationLevel.PIECEWISE and + not vllm_config.model_config.enforce_eager) + self.transpose = True def process_weights_after_loading(self, layer): super(UnquantizedFusedMoEMethod, @@ -195,35 +207,6 @@ def __init__( AscendFusedMoE.moe_counter += 1 self.moe_instance_id = AscendFusedMoE.moe_counter - if params_dtype is None: - params_dtype = torch.get_default_dtype() - - vllm_config = get_current_vllm_config() - - self.moe_parallel_config = FusedMoEParallelConfig.make( - tp_size_=(tp_size if tp_size is not None else - get_tensor_model_parallel_world_size()), - dp_size_=(dp_size - if dp_size is not None else get_dp_group().world_size), - vllm_parallel_config=vllm_config.parallel_config) - - self.top_k = top_k - self.num_experts = num_experts - self.global_num_experts = num_experts - assert intermediate_size % self.tp_size == 0 - self.intermediate_size_per_partition = intermediate_size // self.tp_size - self.reduce_results = reduce_results - self.renormalize = renormalize - self.use_grouped_topk = use_grouped_topk - if self.use_grouped_topk: - assert num_expert_group is not None and topk_group is not None - self.num_expert_group = num_expert_group - self.topk_group = topk_group - self.custom_routing_function = custom_routing_function - self.scoring_func = scoring_func - self.e_score_correction_bias = e_score_correction_bias - self.expert_map = None - self.activation = activation self.log2phy = None self.global_redundant_expert_num = 0 @@ -256,49 +239,14 @@ def __init__( self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp - if self.scoring_func != "softmax" and not self.use_grouped_topk: - raise ValueError("Only softmax scoring function is supported for " - "non-grouped topk.") - moe = FusedMoEConfig.make( - num_experts=self.global_num_experts, - experts_per_token=top_k, - hidden_dim=hidden_size, - num_local_experts=self.local_num_experts, - moe_parallel_config=self.moe_parallel_config, - # TODO (bnell): this needs to be fixed for quantized types. 
- in_dtype=params_dtype, - quant_config=quant_config) - - self.moe_config = moe - if quant_config is None: - self.quant_method = AscendUnquantizedFusedMoEMethod(moe) + self.quant_method = AscendUnquantizedFusedMoEMethod( + self.moe_config) else: self.quant_method = quant_config.get_quant_method(self, prefix) assert self.quant_method is not None - local_num_experts = torch.sum(self.expert_map != -1) \ - if self.expert_map is not None else num_experts - - moe_quant_params = { - "num_experts": local_num_experts, - "hidden_size": hidden_size, - "intermediate_size_per_partition": - self.intermediate_size_per_partition, - "params_dtype": params_dtype, - "weight_loader": self.weight_loader, - } - # need full intermediate size pre-sharding for WNA16 act order - if (self.quant_method.__class__.__name__ - in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")): - moe_quant_params["intermediate_size_full"] = intermediate_size - - self.ep_group = get_ep_group() - # NOTE: self.tp_group is not expert_tp_group - self.tp_group = get_tp_group().device_group - self.quant_method.create_weights(layer=self, **moe_quant_params) - self.moe_config.tp_group = get_tp_group() self.moe_config.dp_group = get_dp_group() self.moe_config.ep_group = get_ep_group() @@ -313,22 +261,6 @@ def __init__( self, method.__name__.lower(), method(moe_config=self.moe_config)) # type: ignore[abstract] - def naive_multicast(self, x: torch.Tensor, - cu_tokens_across_dp_cpu: torch.Tensor): - assert (len(x.shape) == 2) - buffer = torch.empty((cu_tokens_across_dp_cpu[-1], x.size(1)), - device=x.device, - dtype=x.dtype) - start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ - self.dp_rank - 1] - end = cu_tokens_across_dp_cpu[self.dp_rank] - buffer[start:end, :].copy_(x) - for idx in range(self.dp_size): - start = 0 if idx == 0 else cu_tokens_across_dp_cpu[idx - 1] - end = cu_tokens_across_dp_cpu[idx] - get_dp_group().broadcast(buffer[start:end, :], idx) - return buffer - def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor, diff --git a/vllm_ascend/ops/moe/moe_comm_method.py b/vllm_ascend/ops/moe/moe_comm_method.py index 2194f4f720..cd594a6072 100644 --- a/vllm_ascend/ops/moe/moe_comm_method.py +++ b/vllm_ascend/ops/moe/moe_comm_method.py @@ -18,6 +18,7 @@ from typing import Any, Optional import torch +from vllm.config import get_current_vllm_config from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe import FusedMoEConfig @@ -26,7 +27,8 @@ FusedMoEPrepareAndFinalizeWithAllGather, FusedMoEPrepareAndFinalizeWithMC2, FusedMoEPrepareAndFinalizeWithNaiveMulticast) from vllm_ascend.ops.moe.moe_mlp import unified_apply_mlp -from vllm_ascend.ops.moe.token_dispatcher import (TokenDispatcherWithAll2AllV, +from vllm_ascend.ops.moe.token_dispatcher import (TokenDispatcherMoge, + TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather, TokenDispatcherWithMC2) @@ -35,6 +37,8 @@ class MoECommMethod(ABC): """Base class for MoE communication methods.""" def __init__(self, moe_config: FusedMoEConfig): + self.model_type = get_current_vllm_config( + ).model_config.hf_config.model_type self.moe_config = moe_config self.mc2_mask = None @@ -112,8 +116,8 @@ def fused_experts( apply_router_weight_on_input=apply_router_weight_on_input, with_quant=use_int8_w8a8 or use_int4_w4a8) - permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type = \ - results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"] + permuted_hidden_states, 
expert_tokens, dynamic_scale, group_list_type, topk_scales = \ + results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales") mlp_output = unified_apply_mlp(hidden_states=permuted_hidden_states, w1=w1, @@ -125,6 +129,7 @@ def fused_experts( group_list_type=group_list_type, w1_scale_bias=w1_scale_bias, w2_scale_bias=w2_scale_bias, + topk_scales=topk_scales, with_quant=use_int8_w8a8 or use_int4_w4a8, fusion=use_int8_w8a8, @@ -166,94 +171,21 @@ class AllGatherCommImpl(MoECommMethod): """ def _get_token_dispatcher(self): - return TokenDispatcherWithAllGather( - top_k=self.moe_config.experts_per_token, - num_experts=self.moe_config.num_experts, - num_local_experts=self.moe_config.num_local_experts) + if self.model_type == "PanguProMoE": + return TokenDispatcherMoge( + top_k=self.moe_config.experts_per_token, + num_experts=self.moe_config.num_experts, + num_local_experts=self.moe_config.num_local_experts) + else: + return TokenDispatcherWithAllGather( + top_k=self.moe_config.experts_per_token, + num_experts=self.moe_config.num_experts, + num_local_experts=self.moe_config.num_local_experts) def _get_fused_moe_prepare_finalize(self): return FusedMoEPrepareAndFinalizeWithAllGather(self.moe_config) -class NativeAllGatherCommImpl(AllGatherCommImpl): - """This implementation should be compatible with all scenarios. - - Note that this implementation purely consists of native PyTorch ops - and does not use any NPU-specific ops. So the performance may not be optimal. - But it is a good fallback for scenarios where NPU-specific ops are not available. - """ - - def permute( - self, - hidden_states: torch.Tensor, - topk_ids: torch.Tensor, - topk_weights: torch.Tensor, - expert_map: torch.Tensor, - num_experts: int, - apply_a8_quantization: bool, - ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]: - num_tokens = hidden_states.shape[0] - - # Generate token indices and flatten - token_indices = torch.arange(num_tokens, - device=hidden_states.device, - dtype=torch.int64) - token_indices = (token_indices.unsqueeze(1).expand( - -1, self.moe_config.experts_per_token).reshape(-1)) - - # Flatten token-to-expert mappings and map to local experts - weights_flat = topk_weights.view(-1) - experts_flat = topk_ids.view(-1) - local_experts_flat = (expert_map[experts_flat] - if expert_map is not None else experts_flat) - - # Filter valid token-expert pairs - mask = local_experts_flat != -1 - # FIXME: npu_grouped_matmul output random values at [num_valid_tokens:, ...] - # So we need to filter out invalid tokens by zeroing their weights. 
- # This is a workaround and should be removed after the issue is fixed - filtered_weights = torch.where(mask, weights_flat, - torch.zeros_like(weights_flat)).to( - topk_weights.dtype) - filtered_experts = torch.where( - mask, - local_experts_flat, - torch.full_like(local_experts_flat, num_experts), - ).to(topk_ids.dtype) - - # Sort by local expert IDs - sort_indices = torch.argsort(filtered_experts.view(torch.float32)) - self.sorted_token_indices = token_indices[sort_indices] - self.sorted_weights = filtered_weights[sort_indices] - - # Compute token counts with minlength of num_experts - # This is equivalent to but faster than: - # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1] - token_counts = torch.zeros(num_experts + 1, - device=hidden_states.device, - dtype=torch.int64) - ones = torch.ones_like(filtered_experts, dtype=torch.int64) - token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones) - expert_tokens = token_counts[:num_experts] - - # Rearrange hidden_states - permuted_hidden_states = hidden_states[self.sorted_token_indices] - - group_list_type = 1 # `count` mode - - return permuted_hidden_states, expert_tokens, None, group_list_type - - def unpermute(self, mlp_output: torch.Tensor, - hidden_states: torch.Tensor) -> None: - mlp_output = mlp_output * self.sorted_weights.unsqueeze(1) - - final_hidden_states = torch.zeros_like(hidden_states) - final_hidden_states.index_add_(0, self.sorted_token_indices, - mlp_output) - - hidden_states[:] = final_hidden_states - - class MC2CommImpl(MoECommMethod): """This implementation is for the scenarios listed below: 1. `enable_expert_parallel=True`. diff --git a/vllm_ascend/ops/moe/token_dispatcher.py b/vllm_ascend/ops/moe/token_dispatcher.py index c1a16d029b..04ad185790 100644 --- a/vllm_ascend/ops/moe/token_dispatcher.py +++ b/vllm_ascend/ops/moe/token_dispatcher.py @@ -377,14 +377,14 @@ def token_combine(self, # mypy: disable-error-code="override" -class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher): +class TokenDispatcherMoge(MoETokenDispatcher): def __init__(self, **kwargs): super().__init__(**kwargs) self.apply_router_weight_on_input = False - self.local_ep = 1 - self.local_num_experts = self.num_experts // self.local_ep - self.local_num_group = self.top_k // self.local_ep + # self.local_ep = 1 + self.local_num_experts = self.num_experts // self.ep_size + self.local_num_group = self.top_k // self.ep_size self.bsz = None def token_dispatch(self, @@ -401,17 +401,6 @@ def token_dispatch(self, mc2_mask: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, with_quant: bool = False): - self.apply_router_weight_on_input = apply_router_weight_on_input - if self.apply_router_weight_on_input: - assert (topk_weights.dim() == 2 - ), "`topk_weights` should be in shape (num_tokens, topk)" - _, topk = topk_weights.shape - assert ( - topk == 1 - ), "Only support topk=1 when `apply_router_weight_on_input` is True" - hidden_states = hidden_states * \ - topk_weights.to(hidden_states.dtype) - self.bsz, _ = hidden_states.shape flatten_topk_ids = topk_ids.view(-1) self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float()) @@ -445,7 +434,7 @@ def token_combine(self, unsorted_hidden_states = hidden_states.index_select( 0, unsorted_topk_ids) final_hidden_states = unsorted_hidden_states.reshape( - self.bsz, self.top_k // self.local_ep, -1).sum(1) + self.bsz, self.top_k // self.ep_size, -1).sum(1) return final_hidden_states diff --git a/vllm_ascend/quantization/quant_config.py 
b/vllm_ascend/quantization/quant_config.py
index 6124fcb2d1..04833d225b 100644
--- a/vllm_ascend/quantization/quant_config.py
+++ b/vllm_ascend/quantization/quant_config.py
@@ -37,7 +37,7 @@
 from vllm_ascend.distributed.parallel_state import (get_mlp_tp_group,
                                                     get_otp_group)
-from vllm_ascend.ops.fused_moe import AscendUnquantizedFusedMoEMethod
+from vllm_ascend.ops.common_fused_moe import AscendUnquantizedFusedMoEMethod
 from vllm_ascend.utils import (ASCEND_QUANTIZATION_METHOD, mlp_tp_enable,
                                oproj_tp_enable)
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index a409bd3e01..5394185eb4 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1687,6 +1687,7 @@ def _select_moe_comm_method(self, num_tokens: int,
         Returns:
             str: The selected MoE communication method, either "allgather", "mc2", or "alltoall".
         """
+        model_type = self.vllm_config.model_config.hf_config.model_type
         soc_version = get_ascend_soc_version()
         quant_type = getattr(self.vllm_config.model_config.hf_config,
                              'moe_quantize', None)
@@ -1710,6 +1711,9 @@ def _select_moe_comm_method(self, num_tokens: int,
         if moe_comm_method == "allgather" and with_prefill:
             moe_comm_method = "naivemulticast"

+        if model_type == "PanguProMoE":
+            moe_comm_method = "allgather"
+
         if is_global_first_rank():
             logger.debug(f"num_tokens: {num_tokens}, "
                          f"moe_comm_method: {moe_comm_method}")
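
For reference, the expert-map fallback that AscendFusedMoE.__init__ now uses when no expert_map_path is configured (the determine_expert_map branch) can be illustrated with a minimal, self-contained sketch. build_expert_map below is a hypothetical stand-in, not the vLLM helper itself; it only mirrors the behaviour described in the diff comment: a tensor of size global_num_experts filled with -1, with the experts hosted on the current EP rank remapped to local indices.

# Minimal sketch of the expert-map idea used above. build_expert_map is a
# hypothetical helper that assumes a plain contiguous split of the global
# experts across EP ranks; it is not the actual vLLM determine_expert_map.
import torch


def build_expert_map(ep_size: int, ep_rank: int, global_num_experts: int):
    """Return (local_num_experts, expert_map) for a contiguous expert split."""
    base = global_num_experts // ep_size
    remainder = global_num_experts % ep_size
    # Earlier ranks absorb the remainder, one extra expert each.
    local_num_experts = base + (1 if ep_rank < remainder else 0)
    start = ep_rank * base + min(ep_rank, remainder)

    # Tensor of size global_num_experts filled with -1; experts owned by this
    # EP rank are remapped to their local index so routed token ids can be
    # translated to local expert ids during dispatch.
    expert_map = torch.full((global_num_experts,), -1, dtype=torch.int32)
    expert_map[start:start + local_num_experts] = torch.arange(
        local_num_experts, dtype=torch.int32)
    return local_num_experts, expert_map


if __name__ == "__main__":
    # 256 global experts (the DeepSeek-V3/R1 case checked in __init__) on 8 EP ranks.
    local_n, emap = build_expert_map(ep_size=8, ep_rank=3, global_num_experts=256)
    print(local_n)        # 32 local experts on rank 3
    print(emap[96:100])   # tensor([0, 1, 2, 3], dtype=torch.int32): rank 3 owns experts 96..127
    print(emap[0].item()) # -1: expert 0 lives on another rank

The ExpertLoadBalancer path in the diff serves the same purpose but loads a precomputed, possibly redundant placement from expert_map_path instead of deriving a uniform split.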