diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index 01da5bef35..7a3d29e6db 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -198,6 +198,7 @@ def fused_experts(
     num_experts = w1.shape[0]
     dtype = hidden_states.dtype
     device = hidden_states.device
+    topk_weights = topk_weights.to(dtype)
     # assert dtype in [torch.float32, torch.float16, torch.bfloat16
     #                  ], "Only float32, float16, and bfloat16 are supported"
 
@@ -615,32 +616,16 @@ def __init__(
         self.expert_map = None
         self.activation = activation
 
-        if self.ep_size > 1:
-            # Create a tensor of size num_experts filled with -1
-            self.local_num_experts, self.expert_map = determine_expert_map(
-                self.ep_size,
-                get_ep_group().rank_in_group, self.global_num_experts)
-            if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-                self.tp_rank = get_etp_group().rank_in_group
-                self.ep_rank = get_ep_group().rank_in_group
-            else:
-                self.moe_parallel_config.tp_rank = get_etp_group(
-                ).rank_in_group
-                self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
-
+        # Create a tensor of size num_experts filled with -1
+        self.local_num_experts, self.expert_map = determine_expert_map(
+            self.ep_size,
+            get_ep_group().rank_in_group, self.global_num_experts)
+        if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
+            self.tp_rank = get_etp_group().rank_in_group
+            self.ep_rank = get_ep_group().rank_in_group
         else:
-            # Adjust TP size for DP attention
-            # haven't test its functionality yet, may remove in the future
-            if vllm_version_is("0.8.5") or vllm_version_is("0.8.5.post1"):
-                self.tp_rank = self.tp_size * self.dp_rank
-                self.ep_rank = 0
-                self.tp_size = self.tp_size * self.dp_size
-                self.ep_size = 1
-            else:
-                self.moe_parallel_config.tp_rank = self.tp_size * self.dp_rank
-                self.moe_parallel_config.ep_rank = 0
-                self.moe_parallel_config.tp_size = self.tp_size * self.dp_size
-                self.moe_parallel_config.ep_size = 1
+            self.moe_parallel_config.tp_rank = get_etp_group().rank_in_group
+            self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
 
         self.local_num_experts, self.expert_map = (self.global_num_experts,
                                                    None)
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index 5d2b442cf1..11d9b91270 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -342,6 +342,7 @@ def fused_experts(hidden_states: torch.Tensor,
     num_experts = w1.shape[0]
     dtype = hidden_states.dtype
     device = hidden_states.device
+    topk_weights = topk_weights.to(dtype)
 
     if expert_map is not None:
         # Generate token indices and flatten