From ba2e08fdf543bc376d3187edabc9f18190d21221 Mon Sep 17 00:00:00 2001
From: SlightwindSec
Date: Tue, 27 May 2025 23:29:06 +0800
Subject: [PATCH 1/2] [perf] Improve Prefill Performance by Optimizing Alltoall Communication

Signed-off-by: SlightwindSec
---
 vllm_ascend/envs.py                      |   2 +
 vllm_ascend/ops/fused_moe.py             |   1 +
 vllm_ascend/quantization/w8a8_dynamic.py | 204 +++++++++++++++++++++++
 3 files changed, 207 insertions(+)

diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py
index 8e1cc1c162..39d81f4ee4 100644
--- a/vllm_ascend/envs.py
+++ b/vllm_ascend/envs.py
@@ -36,6 +36,8 @@
     lambda: bool(int(os.getenv("COMPILE_CUSTOM_KERNELS", "1"))),
     "VLLM_ENABLE_MC2":
     lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))),
+    "VLLM_ENABLE_FIXED_ALL_TO_ALL_BUFFER":
+    lambda: bool(int(os.getenv("VLLM_ENABLE_FIXED_ALL_TO_ALL_BUFFER", '0'))),
     "USING_LCCL_COM":
     lambda: bool(int(os.getenv("USING_LCCL_COM", '0'))),
     "SOC_VERSION":
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index bc3b86b653..f2214a9f6d 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -767,6 +767,7 @@ def __init__(
         self.e_score_correction_bias = e_score_correction_bias
         self.expert_map = None
         self.activation = activation
+        self.max_model_len = vllm_config.model_config.max_model_len
 
         if self.ep_size > 1:
             # Create a tensor of size num_experts filled with -1
diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py
index 0f54b012f1..0b00ae2e29 100644
--- a/vllm_ascend/quantization/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/w8a8_dynamic.py
@@ -27,6 +27,71 @@
 from vllm_ascend.ops.fused_moe import select_experts
 
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
+VLLM_ENABLE_FIXED_ALL_TO_ALL_BUFFER: bool = envs_ascend.VLLM_ENABLE_FIXED_ALL_TO_ALL_BUFFER
+
+
+def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
+                     max_row_per_ep_rank: int, num_tokens: int,
+                     top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
+    original_total_elements = num_tokens * top_k
+    device = topk_ids.device
+    original_dtype = topk_ids.dtype
+
+    if original_total_elements == 0:
+        output_len = ep_size * max_row_per_ep_rank
+        topk_ids_pad = torch.full((output_len, ),
+                                  expert_num,
+                                  dtype=original_dtype,
+                                  device=device)
+        unpad_indices = torch.full((original_total_elements, ),
+                                   -1,
+                                   dtype=torch.long,
+                                   device=device)
+        return topk_ids_pad, unpad_indices
+
+    experts_per_ep_rank_val = expert_num // ep_size
+    if experts_per_ep_rank_val == 0:
+        raise ValueError(
+            "expert_num // ep_size is 0, which leads to division by zero in ep_rank calculation. "
+            "Ensure expert_num >= ep_size.")
+
+    assigned_ep_rank = (topk_ids.float() /
+                        experts_per_ep_rank_val).to(original_dtype)
+    indices_arange = torch.arange(topk_ids.shape[0], device=device)
+
+    is_new_segment = torch.cat((torch.tensor([True], device=device),
+                                assigned_ep_rank[1:] != assigned_ep_rank[:-1]))
+    temp_start_markers = torch.full_like(indices_arange,
+                                         -1,
+                                         dtype=indices_arange.dtype)
+    temp_start_markers[is_new_segment] = indices_arange[is_new_segment]
+    start_offset_for_each_token = torch.cummax(temp_start_markers, dim=0)[0]
+    token_intra_ep_rank_idx = indices_arange - start_offset_for_each_token
+    is_kept_mask = token_intra_ep_rank_idx < max_row_per_ep_rank
+    cumsum_kept = torch.cumsum(is_kept_mask.float(), dim=0).to(torch.long)
+    indices_in_rec_cond_list_for_all = cumsum_kept - 1
+    unpad_indices = torch.where(
+        is_kept_mask, indices_in_rec_cond_list_for_all,
+        torch.tensor(-1, device=device, dtype=torch.long))
+    output_len = ep_size * max_row_per_ep_rank
+    topk_ids_pad = torch.full((output_len, ),
+                              expert_num,
+                              dtype=original_dtype,
+                              device=device)
+    if topk_ids.shape[0] > 0:
+        all_destination_indices = assigned_ep_rank * max_row_per_ep_rank + token_intra_ep_rank_idx
+        temp_pad_buffer = torch.full((output_len + 1, ),
+                                     expert_num,
+                                     dtype=original_dtype,
+                                     device=device)
+        output_len_tensor = torch.tensor(output_len,
+                                         dtype=torch.long,
+                                         device=device)
+        scatter_indices = torch.where(is_kept_mask, all_destination_indices,
+                                      output_len_tensor)
+        temp_pad_buffer.scatter_(0, scatter_indices, topk_ids)
+        topk_ids_pad = temp_pad_buffer[:output_len]
+    return topk_ids_pad, unpad_indices
 
 
 def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
@@ -325,6 +390,132 @@ def fused_experts_with_all2all(
     return final_hidden_states
 
 
+def fused_experts_with_all2all_with_fixed_buffer(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2: torch.Tensor,
+    w2_scale: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    top_k: int,
+    max_model_len: int,
+    expert_map: torch.Tensor = None,
+    ep_group: GroupCoordinator = None,
+):
+    original_shape = hidden_states.shape
+    if len(original_shape) == 3:
+        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+
+    num_tokens, _ = hidden_states.shape
+    device = hidden_states.device
+
+    global_num_experts = len(expert_map)
+    local_num_experts = global_num_experts // ep_group.world_size
+    row_idx_len = num_tokens * top_k
+    row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32,
+                            device=device).view(top_k,
+                                                -1).permute(1, 0).contiguous())
+    hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
+        hidden_states,
+        row_idx=row_idx,
+        expert_idx=topk_ids,
+        active_num=num_tokens)
+
+    max_row_per_ep_rank = (max_model_len // ep_group.world_size +
+                           1) * top_k * 2
+    expert_idx_buffer_scatter, unpad_indices = process_topk_ids(
+        expanded_expert_idx, global_num_experts, ep_group.world_size,
+        max_row_per_ep_rank, num_tokens, top_k)
+    hidden_states_pad_idx = torch.zeros(
+        expert_idx_buffer_scatter.shape,
+        dtype=expert_idx_buffer_scatter.dtype,
+        device=expert_idx_buffer_scatter.device)
+    non_pad_len = torch.sum(
+        (expert_idx_buffer_scatter != global_num_experts).to(torch.int32))
+    hidden_states_pad_idx[
+        expert_idx_buffer_scatter != global_num_experts] = torch.arange(
+            non_pad_len,
+            dtype=expert_idx_buffer_scatter.dtype,
+            device=hidden_states.device)
+
+    hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx]
+    expert_idx_buffer_gather = torch.empty_like(
+        expert_idx_buffer_scatter,
+        dtype=expert_idx_buffer_scatter.dtype,
+        device=expert_idx_buffer_scatter.device)
+    hidden_states_buffer_gather = torch.empty_like(
+        hidden_states_buffer_scatter,
+        dtype=hidden_states_buffer_scatter.dtype,
+        device=hidden_states_buffer_scatter.device)
+    dist.all_to_all_single(expert_idx_buffer_gather,
+                           expert_idx_buffer_scatter,
+                           group=ep_group.device_group)
+    dist.all_to_all_single(hidden_states_buffer_gather,
+                           hidden_states_buffer_scatter,
+                           group=ep_group.device_group)
+    mask = expert_idx_buffer_gather != global_num_experts
+    local_expert_idx = expert_idx_buffer_gather[mask] - ep_group.rank * (
+        global_num_experts // ep_group.world_size)
+    hidden_states = hidden_states_buffer_gather[mask]
+    idx_type = local_expert_idx.dtype
+    sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx.float())
+    sorted_local_expert_idx = sorted_local_expert_idx.to(idx_type)
+
+    expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
+        sorted_local_expert_idx, local_num_experts).to(torch.int64)
+    hidden_states = hidden_states[sorted_idx]
+    group_list_type = 0
+
+    hidden_states_wrapper = [hidden_states]
+    del hidden_states
+
+    hidden_states = apply_mlp(hidden_states_wrapper,
+                              w1,
+                              w1_scale,
+                              w2,
+                              w2_scale,
+                              expert_tokens,
+                              group_list_type=group_list_type)
+
+    resorted_idx = torch.argsort(sorted_idx.float()).to(sorted_idx.dtype)
+    hidden_states = hidden_states[resorted_idx]
+    hidden_states_scatter = torch.zeros(
+        (mask.shape[0], hidden_states.shape[1]),
+        dtype=hidden_states.dtype,
+        device=hidden_states.device)
+    hidden_states_scatter[mask] = hidden_states
+    hidden_states_gatter = torch.empty_like(
+        hidden_states_scatter,
+        dtype=hidden_states_scatter.dtype,
+        device=hidden_states_scatter.device)
+    dist.all_to_all_single(hidden_states_gatter,
+                           hidden_states_scatter,
+                           group=ep_group.device_group)
+    hidden_states_gatter = hidden_states_gatter[
+        expert_idx_buffer_scatter != global_num_experts]
+    if hidden_states_gatter.shape[0] != row_idx_len:
+        hidden_states = torch.zeros((row_idx_len, hidden_states.shape[1]),
+                                    dtype=hidden_states.dtype,
+                                    device=hidden_states.device)
+        hidden_states[unpad_indices != -1] = hidden_states_gatter
+    else:
+        hidden_states = hidden_states_gatter
+    final_hidden_states = torch_npu.npu_moe_finalize_routing(
+        hidden_states,
+        skip1=None,
+        skip2=None,
+        bias=None,
+        scales=topk_weights,
+        expanded_src_to_dst_row=expanded_row_idx,
+        export_for_source_row=topk_ids,
+    )
+
+    if len(original_shape) == 3:
+        final_hidden_states = final_hidden_states.view(original_shape)
+    return final_hidden_states
+
+
 def fused_experts(hidden_states: torch.Tensor,
                   w1: torch.Tensor,
                   w1_scale: torch.Tensor,
@@ -644,6 +835,19 @@ def apply(
                                 topk_ids=topk_ids,
                                 top_k=top_k,
                                 expert_map=expert_map)
+        elif VLLM_ENABLE_FIXED_ALL_TO_ALL_BUFFER and expert_map is not None:
+            return fused_experts_with_all2all_with_fixed_buffer(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w1_scale=layer.w13_weight_scale,
+                w2=layer.w2_weight,
+                w2_scale=layer.w2_weight_scale,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                top_k=top_k,
+                max_model_len=layer.max_model_len,
+                expert_map=expert_map,
+                ep_group=self.ep_group)
         else:
             # The current implementation of deepseek moe splits hidden_states
             # according to tp_size before they are feed into fused_moe module.
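
# Reviewer note (illustrative sketch, not part of the patch): with
# VLLM_ENABLE_FIXED_ALL_TO_ALL_BUFFER=1 every EP rank scatters and gathers a
# constant number of rows, so the dist.all_to_all_single calls above can run
# with equal splits and no per-step exchange of row counts. process_topk_ids()
# pads unused slots with the sentinel expert id `expert_num` and drops any rows
# beyond max_row_per_ep_rank for a rank (their output rows stay zero). The
# values below are made up, purely to show the sizing used by this first patch;
# PATCH 2/2 below revises the formula.
top_k, ep_world_size, max_model_len = 8, 16, 16384
max_row_per_ep_rank = (max_model_len // ep_world_size + 1) * top_k * 2
print(max_row_per_ep_rank)  # 16400 rows per EP rank, independent of routing
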
From 7998d8fd37d7fa4e2c83b2f0c0eb2bc39ec620a9 Mon Sep 17 00:00:00 2001 From: SlightwindSec Date: Wed, 28 May 2025 21:29:33 +0800 Subject: [PATCH 2/2] fix Signed-off-by: SlightwindSec --- vllm_ascend/ops/fused_moe.py | 2 ++ vllm_ascend/quantization/w8a8_dynamic.py | 5 ++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index f2214a9f6d..1c54685c49 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -768,6 +768,8 @@ def __init__( self.expert_map = None self.activation = activation self.max_model_len = vllm_config.model_config.max_model_len + self.global_batch_size = vllm_config.scheduler_config.max_num_seqs * ( + dp_size if dp_size is not None else get_dp_group().world_size) if self.ep_size > 1: # Create a tensor of size num_experts filled with -1 diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 0b00ae2e29..cc733910cd 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -400,6 +400,7 @@ def fused_experts_with_all2all_with_fixed_buffer( topk_ids: torch.Tensor, top_k: int, max_model_len: int, + global_batch_size: int, expert_map: torch.Tensor = None, ep_group: GroupCoordinator = None, ): @@ -422,7 +423,8 @@ def fused_experts_with_all2all_with_fixed_buffer( expert_idx=topk_ids, active_num=num_tokens) - max_row_per_ep_rank = (max_model_len // ep_group.world_size + + max_row_per_ep_rank = (-(-global_batch_size // ep_group.world_size) * + max_model_len // ep_group.world_size + 1) * top_k * 2 expert_idx_buffer_scatter, unpad_indices = process_topk_ids( expanded_expert_idx, global_num_experts, ep_group.world_size, @@ -846,6 +848,7 @@ def apply( topk_ids=topk_ids, top_k=top_k, max_model_len=layer.max_model_len, + global_batch_size=layer.global_batch_size, expert_map=expert_map, ep_group=self.ep_group) else:
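
# Reviewer note (illustrative sketch, not part of the patch series): a quick
# sanity check of the revised sizing from PATCH 2/2. In Python, -(-a // b) is
# ceiling division, so the fixed buffer now also scales with the number of
# sequences an EP rank may host at once. Example values are assumptions.
def fixed_buffer_rows(global_batch_size: int, max_model_len: int,
                      ep_world_size: int, top_k: int) -> int:
    # Mirrors the expression used in fused_experts_with_all2all_with_fixed_buffer().
    return (-(-global_batch_size // ep_world_size) * max_model_len //
            ep_world_size + 1) * top_k * 2

# e.g. global_batch_size=256, max_model_len=16384, ep_world_size=16, top_k=8:
#   ceil(256 / 16) = 16 -> (16 * 16384 // 16 + 1) * 8 * 2 = 262160 rows per rank.
assert fixed_buffer_rows(256, 16384, 16, 8) == 262160

# To exercise this path, expert parallelism must be active (expert_map is set,
# per the `elif` guard in apply()) and the flag exported before launching vLLM:
#   export VLLM_ENABLE_FIXED_ALL_TO_ALL_BUFFER=1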