diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py
index 2d11e3d3a8..23557f1b84 100644
--- a/vllm_ascend/envs.py
+++ b/vllm_ascend/envs.py
@@ -112,6 +112,11 @@
     "VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE":
     lambda: bool(int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", '0'))
                  ),
+    # MOE_ALL2ALL_BUFFER:
+    # 0: default, dispatch MoE tokens with the plain all2all path.
+    # 1: exchange fixed-size padded buffers (fused_experts_with_all2all_buffer).
+    "MOE_ALL2ALL_BUFFER":
+    lambda: bool(int(os.getenv("MOE_ALL2ALL_BUFFER", '0'))),
     # VLLM_ASCEND_ACL_OP_INIT_MODE:
     # 0: default, normal init.
     # 1: delay init until launch aclops.
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index 6036e60cb5..fd06d90e21 100644
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -15,7 +15,7 @@
 # This file is a part of the vllm-ascend project.
 # Adapted from vllm/tests/kernels/test_moe.py
 
-from typing import Callable, Optional
+from typing import Callable, List, Optional
 
 import torch
 import torch.distributed as dist
@@ -37,6 +37,71 @@
 VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
 USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM
+MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
+
+
+def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
+                     max_row_per_ep_rank: int, num_tokens: int,
+                     top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
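+    """Pack sorted token-expert rows into fixed-size per-EP-rank buffers.
+
+    Summary of the logic below: each EP rank owns a slice of
+    max_row_per_ep_rank rows; rows that overflow their rank's slice are
+    dropped, and unused slots are filled with `expert_num` as a padding
+    sentinel. Returns the padded expert ids, with shape
+    (ep_size * max_row_per_ep_rank,), plus, for every input row, its
+    position among the kept rows (-1 for dropped rows).
+
+    Example (expert_num=4, ep_size=2, max_row_per_ep_rank=2):
+        topk_ids [0, 0, 1, 2, 3] maps to EP ranks [0, 0, 0, 1, 1]; the third
+        rank-0 row overflows, so topk_ids_pad = [0, 0, 2, 3] and
+        unpad_indices = [0, 1, -1, 2, 3].
+    """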
" + "Ensure expert_num >= ep_size.") + + assigned_ep_rank = (topk_ids.float() / + experts_per_ep_rank_val).to(original_dtype) + indices_arange = torch.arange(topk_ids.shape[0], device=device) + + is_new_segment = torch.cat((torch.tensor([True], device=device), + assigned_ep_rank[1:] != assigned_ep_rank[:-1])) + temp_start_markers = torch.full_like(indices_arange, + -1, + dtype=indices_arange.dtype) + temp_start_markers[is_new_segment] = indices_arange[is_new_segment] + start_offset_for_each_token = torch.cummax(temp_start_markers, dim=0)[0] + token_intra_ep_rank_idx = indices_arange - start_offset_for_each_token + is_kept_mask = token_intra_ep_rank_idx < max_row_per_ep_rank + cumsum_kept = torch.cumsum(is_kept_mask.float(), dim=0).to(torch.long) + indices_in_rec_cond_list_for_all = cumsum_kept - 1 + unpad_indices = torch.where( + is_kept_mask, indices_in_rec_cond_list_for_all, + torch.tensor(-1, device=device, dtype=torch.long)) + output_len = ep_size * max_row_per_ep_rank + topk_ids_pad = torch.full((output_len, ), + expert_num, + dtype=original_dtype, + device=device) + if topk_ids.shape[0] > 0: + all_destination_indices = assigned_ep_rank * max_row_per_ep_rank + token_intra_ep_rank_idx + temp_pad_buffer = torch.full((output_len + 1, ), + expert_num, + dtype=original_dtype, + device=device) + output_len_tensor = torch.tensor(output_len, + dtype=torch.long, + device=device) + scatter_indices = torch.where(is_kept_mask, all_destination_indices, + output_len_tensor) + temp_pad_buffer.scatter_(0, scatter_indices, topk_ids) + topk_ids_pad = temp_pad_buffer[:output_len] + return topk_ids_pad, unpad_indices def fused_experts_with_mc2(hidden_states: torch.Tensor, @@ -146,8 +211,62 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor, return hidden_states -# currently expert parallelism implemented with all2all -# is under-optimized. +def apply_mlp(hidden_states_wrapper: List[torch.Tensor], + w1: torch.Tensor, + w2: torch.Tensor, + group_list: torch.Tensor, + group_list_type: int = 1) -> torch.Tensor: + """ + apply MLP: gate_up_proj -> swiglu -> down_proj + + Args: + hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size). + w1: expert weights1 with shape + (num_experts, hidden_size, intermediate_size * 2) + w2: expert weights2 with shape + (num_experts, intermediate_size, hidden_size) + group_list: number of tokens for each expert, follow cumsum mode, and + with shape (num_experts). + transpose_weight: + w1: (num_experts, intermediate_size * 2, hidden_size) -> + (num_experts, hidden_size, intermediate_size * 2) + w2: (num_experts, hidden_size, intermediate_size) -> + (num_experts, intermediate_size, hidden_size) + + Returns: + hidden_states: output hidden states after MLP. 
+ """ + + assert len(hidden_states_wrapper) == 1 + hidden_states = hidden_states_wrapper.pop() + + w1 = w1.transpose(1, 2) + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w1], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + ) + + hidden_states = torch.cat(hidden_states, dim=0) + hidden_states = torch_npu.npu_swiglu(hidden_states) + + w2 = w2.transpose(1, 2) + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w2], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + ) + + hidden_states = torch.cat(hidden_states, dim=0) + return hidden_states + + def fused_experts_with_all2all( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -283,6 +402,133 @@ def fused_experts_with_all2all( return final_hidden_states +# currently expert parallelism implemented with all2all +# is under-optimized. +def fused_experts_with_all2all_buffer( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + top_k: int, + max_model_len: int, + global_batch_size: int, + expert_map: torch.Tensor = None, + ep_group: GroupCoordinator = None, +): + original_shape = hidden_states.shape + if len(original_shape) == 3: + hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) + + num_tokens, _ = hidden_states.shape + device = hidden_states.device + + global_num_experts = len(expert_map) + local_num_experts = global_num_experts // ep_group.world_size + row_idx_len = num_tokens * top_k + row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32, + device=device).view(top_k, + -1).permute(1, 0).contiguous()) + hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( + hidden_states, + row_idx=row_idx, + expert_idx=topk_ids, + active_num=num_tokens) + + max_row_per_ep_rank = (-(-global_batch_size // ep_group.world_size) * + max_model_len // ep_group.world_size + + 1) * top_k * 2 + expert_idx_buffer_scatter, unpad_indices = process_topk_ids( + expanded_expert_idx, global_num_experts, ep_group.world_size, + max_row_per_ep_rank, num_tokens, top_k) + hidden_states_pad_idx = torch.zeros( + expert_idx_buffer_scatter.shape, + dtype=expert_idx_buffer_scatter.dtype, + device=expert_idx_buffer_scatter.device) + non_pad_len = torch.sum( + (expert_idx_buffer_scatter != global_num_experts).to(torch.int32)) + hidden_states_pad_idx[ + expert_idx_buffer_scatter != global_num_experts] = torch.arange( + non_pad_len, + dtype=expert_idx_buffer_scatter.dtype, + device=hidden_states.device) + + hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx] + expert_idx_buffer_gather = torch.empty_like( + expert_idx_buffer_scatter, + dtype=expert_idx_buffer_scatter.dtype, + device=expert_idx_buffer_scatter.device) + hidden_states_buffer_gather = torch.empty_like( + hidden_states_buffer_scatter, + dtype=hidden_states_buffer_scatter.dtype, + device=hidden_states_buffer_scatter.device) + dist.all_to_all_single(expert_idx_buffer_gather, + expert_idx_buffer_scatter, + group=ep_group.device_group) + dist.all_to_all_single(hidden_states_buffer_gather, + hidden_states_buffer_scatter, + group=ep_group.device_group) + mask = expert_idx_buffer_gather != global_num_experts + local_expert_idx = expert_idx_buffer_gather[mask] - ep_group.rank * ( + global_num_experts // ep_group.world_size) + hidden_states = hidden_states_buffer_gather[mask] + idx_type = local_expert_idx.dtype + sorted_local_expert_idx, sorted_idx = 
+    original_shape = hidden_states.shape
+    if len(original_shape) == 3:
+        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+
+    num_tokens, _ = hidden_states.shape
+    device = hidden_states.device
+
+    global_num_experts = len(expert_map)
+    local_num_experts = global_num_experts // ep_group.world_size
+    row_idx_len = num_tokens * top_k
+    row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32,
+                            device=device).view(top_k,
+                                                -1).permute(1, 0).contiguous())
+    hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
+        hidden_states,
+        row_idx=row_idx,
+        expert_idx=topk_ids,
+        active_num=num_tokens)
+
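+    # Worst-case number of rows a single EP rank may receive, with 2x
+    # headroom; -(-a // b) is ceil(a / b). E.g. global_batch_size=32,
+    # world_size=8, max_model_len=4096, top_k=8:
+    # (ceil(32 / 8) * 4096 // 8 + 1) * 8 * 2 = 2049 * 16 = 32784.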
+    max_row_per_ep_rank = (-(-global_batch_size // ep_group.world_size) *
+                           max_model_len // ep_group.world_size +
+                           1) * top_k * 2
+    expert_idx_buffer_scatter, unpad_indices = process_topk_ids(
+        expanded_expert_idx, global_num_experts, ep_group.world_size,
+        max_row_per_ep_rank, num_tokens, top_k)
+    hidden_states_pad_idx = torch.zeros(
+        expert_idx_buffer_scatter.shape,
+        dtype=expert_idx_buffer_scatter.dtype,
+        device=expert_idx_buffer_scatter.device)
+    non_pad_len = torch.sum(
+        (expert_idx_buffer_scatter != global_num_experts).to(torch.int32))
+    hidden_states_pad_idx[
+        expert_idx_buffer_scatter != global_num_experts] = torch.arange(
+            non_pad_len,
+            dtype=expert_idx_buffer_scatter.dtype,
+            device=hidden_states.device)
+
+    hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx]
+    expert_idx_buffer_gather = torch.empty_like(
+        expert_idx_buffer_scatter,
+        dtype=expert_idx_buffer_scatter.dtype,
+        device=expert_idx_buffer_scatter.device)
+    hidden_states_buffer_gather = torch.empty_like(
+        hidden_states_buffer_scatter,
+        dtype=hidden_states_buffer_scatter.dtype,
+        device=hidden_states_buffer_scatter.device)
+    dist.all_to_all_single(expert_idx_buffer_gather,
+                           expert_idx_buffer_scatter,
+                           group=ep_group.device_group)
+    dist.all_to_all_single(hidden_states_buffer_gather,
+                           hidden_states_buffer_scatter,
+                           group=ep_group.device_group)
+    mask = expert_idx_buffer_gather != global_num_experts
+    local_expert_idx = expert_idx_buffer_gather[mask] - ep_group.rank * (
+        global_num_experts // ep_group.world_size)
+    hidden_states = hidden_states_buffer_gather[mask]
+    idx_type = local_expert_idx.dtype
+    sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx.float())
+    sorted_local_expert_idx = sorted_local_expert_idx.to(idx_type)
+
+    expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
+        sorted_local_expert_idx, local_num_experts).to(torch.int64)
+    hidden_states = hidden_states[sorted_idx]
+    group_list_type = 0
+
+    hidden_states_wrapper = [hidden_states]
+    del hidden_states
+
+    hidden_states = apply_mlp(hidden_states_wrapper,
+                              w1,
+                              w2,
+                              expert_tokens,
+                              group_list_type=group_list_type)
+
+    resorted_idx = torch.argsort(sorted_idx.float()).to(sorted_idx.dtype)
+    hidden_states = hidden_states[resorted_idx]
+    hidden_states_scatter = torch.zeros(
+        (mask.shape[0], hidden_states.shape[1]),
+        dtype=hidden_states.dtype,
+        device=hidden_states.device)
+    hidden_states_scatter[mask] = hidden_states
+    hidden_states_gather = torch.empty_like(
+        hidden_states_scatter,
+        dtype=hidden_states_scatter.dtype,
+        device=hidden_states_scatter.device)
+    dist.all_to_all_single(hidden_states_gather,
+                           hidden_states_scatter,
+                           group=ep_group.device_group)
+    hidden_states_gather = hidden_states_gather[
+        expert_idx_buffer_scatter != global_num_experts]
+    if hidden_states_gather.shape[0] != row_idx_len:
+        hidden_states = torch.zeros((row_idx_len, hidden_states.shape[1]),
+                                    dtype=hidden_states.dtype,
+                                    device=hidden_states.device)
+        hidden_states[unpad_indices != -1] = hidden_states_gather
+    else:
+        # TODO: Reorder device memory 2 times here; replace the current
+        # implementation once a more direct approach is available.
+        hidden_states = hidden_states_gather
+    final_hidden_states = torch_npu.npu_moe_finalize_routing(
+        hidden_states,
+        skip1=None,
+        skip2=None,
+        bias=None,
+        scales=topk_weights,
+        expanded_src_to_dst_row=expanded_row_idx,
+        export_for_source_row=topk_ids,
+    )
+
+    if len(original_shape) == 3:
+        final_hidden_states = final_hidden_states.view(original_shape)
+    return final_hidden_states
+
+
 def fused_experts(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
@@ -585,6 +831,7 @@ def __init__(self, moe: MoEConfig = None):
         self.ep_size = ep_group.world_size
         self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
         self.local_batch_size = self.global_batch_size // self.ep_size
+        self.max_model_len = vllm_config.model_config.max_model_len
 
         ascend_config = get_ascend_config()
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
@@ -613,21 +860,22 @@ def apply(
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        use_grouped_topk: bool,
-        top_k: int,
         router_logits: torch.Tensor,
+        top_k: int,
         renormalize: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
+        use_grouped_topk: bool = False,
         global_num_experts: int = -1,
         expert_map: Optional[torch.Tensor] = None,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
         custom_routing_function: Optional[Callable] = None,
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
         is_prefill: bool = False,
         enable_force_load_balance: bool = False,
         **kwargs,
-    ):
+    ) -> torch.Tensor:
+
         # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
         if global_num_experts == 256:
             topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
@@ -683,11 +931,19 @@ def apply(
                 topk_ids=topk_ids,
                 top_k=top_k,
                 expert_map=expert_map)
+        elif MOE_ALL2ALL_BUFFER:
+            return fused_experts_with_all2all_buffer(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                top_k=top_k,
+                max_model_len=self.max_model_len,
+                global_batch_size=self.global_batch_size,
+                expert_map=expert_map,
+                ep_group=get_ep_group())
         else:
-            # The current implementation of deepseek moe splits hidden_states
-            # according to tp_size before they are feed into fused_moe module.
-            # Therefore, all2all is needed no matter how dp/tp is set so as to
-            # dispatch/combine tokens.
             return fused_experts_with_all2all(hidden_states=x,
                                               w1=layer.w13_weight,
                                               w2=layer.w2_weight,
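
Usage sketch (illustrative, not part of the patch): MOE_ALL2ALL_BUFFER is read
when vllm_ascend.ops.fused_moe is imported, so it must be set in the
environment before vLLM starts. The model name and parallel settings below are
placeholders.

    import os
    os.environ["MOE_ALL2ALL_BUFFER"] = "1"  # opt in to the buffered path

    from vllm import LLM  # with the vllm-ascend plugin installed

    llm = LLM(model="deepseek-ai/DeepSeek-V2-Lite",
              tensor_parallel_size=4,
              enable_expert_parallel=True)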