Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions vllm_ascend/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@
"VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE":
lambda: bool(int(os.getenv("VLLM_ASCEND_MODEL_EXECUTE_TIME_OBSERVE", '0'))
),
# MOE_ALL2ALL_BUFFER:
# 0: default, normal init.
# 1: enable moe_all2all_buffer.
"MOE_ALL2ALL_BUFFER":
lambda: bool(int(os.getenv("MOE_ALL2ALL_BUFFER", '0'))),
# VLLM_ASCEND_ACL_OP_INIT_MODE:
# 0: default, normal init.
# 1: delay init until launch aclops.
Expand Down
280 changes: 268 additions & 12 deletions vllm_ascend/ops/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/kernels/test_moe.py

from typing import Callable, Optional
from typing import Callable, List, Optional

import torch
import torch.distributed as dist
Expand All @@ -37,6 +37,71 @@

VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER


def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int,
max_row_per_ep_rank: int, num_tokens: int,
top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
original_total_elements = num_tokens * top_k
device = topk_ids.device
original_dtype = topk_ids.dtype

if original_total_elements == 0:
output_len = ep_size * max_row_per_ep_rank
topk_ids_pad = torch.full((output_len, ),
expert_num,
dtype=original_dtype,
device=device)
unpad_indices = torch.full((original_total_elements, ),
-1,
dtype=torch.long,
device=device)
return topk_ids_pad, unpad_indices

experts_per_ep_rank_val = expert_num // ep_size
if experts_per_ep_rank_val == 0:
raise ValueError(
"expert_num // ep_size is 0, which leads to division by zero in ep_rank calculation. "
"Ensure expert_num >= ep_size.")

assigned_ep_rank = (topk_ids.float() /
experts_per_ep_rank_val).to(original_dtype)
indices_arange = torch.arange(topk_ids.shape[0], device=device)

is_new_segment = torch.cat((torch.tensor([True], device=device),
assigned_ep_rank[1:] != assigned_ep_rank[:-1]))
temp_start_markers = torch.full_like(indices_arange,
-1,
dtype=indices_arange.dtype)
temp_start_markers[is_new_segment] = indices_arange[is_new_segment]
start_offset_for_each_token = torch.cummax(temp_start_markers, dim=0)[0]
token_intra_ep_rank_idx = indices_arange - start_offset_for_each_token
is_kept_mask = token_intra_ep_rank_idx < max_row_per_ep_rank
cumsum_kept = torch.cumsum(is_kept_mask.float(), dim=0).to(torch.long)
indices_in_rec_cond_list_for_all = cumsum_kept - 1
unpad_indices = torch.where(
is_kept_mask, indices_in_rec_cond_list_for_all,
torch.tensor(-1, device=device, dtype=torch.long))
output_len = ep_size * max_row_per_ep_rank
topk_ids_pad = torch.full((output_len, ),
expert_num,
dtype=original_dtype,
device=device)
if topk_ids.shape[0] > 0:
all_destination_indices = assigned_ep_rank * max_row_per_ep_rank + token_intra_ep_rank_idx
temp_pad_buffer = torch.full((output_len + 1, ),
expert_num,
dtype=original_dtype,
device=device)
output_len_tensor = torch.tensor(output_len,
dtype=torch.long,
device=device)
scatter_indices = torch.where(is_kept_mask, all_destination_indices,
output_len_tensor)
temp_pad_buffer.scatter_(0, scatter_indices, topk_ids)
topk_ids_pad = temp_pad_buffer[:output_len]
return topk_ids_pad, unpad_indices


def fused_experts_with_mc2(hidden_states: torch.Tensor,
Expand Down Expand Up @@ -146,8 +211,62 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
return hidden_states


# currently expert parallelism implemented with all2all
# is under-optimized.
def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
w1: torch.Tensor,
w2: torch.Tensor,
group_list: torch.Tensor,
group_list_type: int = 1) -> torch.Tensor:
"""
apply MLP: gate_up_proj -> swiglu -> down_proj

Args:
hidden_states_wrapper: wrapper of input hidden states with shape (num_tokens, hidden_size).
w1: expert weights1 with shape
(num_experts, hidden_size, intermediate_size * 2)
w2: expert weights2 with shape
(num_experts, intermediate_size, hidden_size)
group_list: number of tokens for each expert, follow cumsum mode, and
with shape (num_experts).
transpose_weight:
w1: (num_experts, intermediate_size * 2, hidden_size) ->
(num_experts, hidden_size, intermediate_size * 2)
w2: (num_experts, hidden_size, intermediate_size) ->
(num_experts, intermediate_size, hidden_size)

Returns:
hidden_states: output hidden states after MLP.
"""

assert len(hidden_states_wrapper) == 1
hidden_states = hidden_states_wrapper.pop()

w1 = w1.transpose(1, 2)
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w1],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
)

hidden_states = torch.cat(hidden_states, dim=0)
hidden_states = torch_npu.npu_swiglu(hidden_states)

w2 = w2.transpose(1, 2)
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w2],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
)

hidden_states = torch.cat(hidden_states, dim=0)
return hidden_states


def fused_experts_with_all2all(
hidden_states: torch.Tensor,
w1: torch.Tensor,
Expand Down Expand Up @@ -283,6 +402,133 @@ def fused_experts_with_all2all(
return final_hidden_states


# currently expert parallelism implemented with all2all
# is under-optimized.
def fused_experts_with_all2all_buffer(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
max_model_len: int,
global_batch_size: int,
expert_map: torch.Tensor = None,
ep_group: GroupCoordinator = None,
):
original_shape = hidden_states.shape
if len(original_shape) == 3:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

num_tokens, _ = hidden_states.shape
device = hidden_states.device

global_num_experts = len(expert_map)
local_num_experts = global_num_experts // ep_group.world_size
row_idx_len = num_tokens * top_k
row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32,
device=device).view(top_k,
-1).permute(1, 0).contiguous())
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states,
row_idx=row_idx,
expert_idx=topk_ids,
active_num=num_tokens)

max_row_per_ep_rank = (-(-global_batch_size // ep_group.world_size) *
max_model_len // ep_group.world_size +
1) * top_k * 2
expert_idx_buffer_scatter, unpad_indices = process_topk_ids(
expanded_expert_idx, global_num_experts, ep_group.world_size,
max_row_per_ep_rank, num_tokens, top_k)
hidden_states_pad_idx = torch.zeros(
expert_idx_buffer_scatter.shape,
dtype=expert_idx_buffer_scatter.dtype,
device=expert_idx_buffer_scatter.device)
non_pad_len = torch.sum(
(expert_idx_buffer_scatter != global_num_experts).to(torch.int32))
hidden_states_pad_idx[
expert_idx_buffer_scatter != global_num_experts] = torch.arange(
non_pad_len,
dtype=expert_idx_buffer_scatter.dtype,
device=hidden_states.device)

hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx]
expert_idx_buffer_gather = torch.empty_like(
expert_idx_buffer_scatter,
dtype=expert_idx_buffer_scatter.dtype,
device=expert_idx_buffer_scatter.device)
hidden_states_buffer_gather = torch.empty_like(
hidden_states_buffer_scatter,
dtype=hidden_states_buffer_scatter.dtype,
device=hidden_states_buffer_scatter.device)
dist.all_to_all_single(expert_idx_buffer_gather,
expert_idx_buffer_scatter,
group=ep_group.device_group)
dist.all_to_all_single(hidden_states_buffer_gather,
hidden_states_buffer_scatter,
group=ep_group.device_group)
mask = expert_idx_buffer_gather != global_num_experts
local_expert_idx = expert_idx_buffer_gather[mask] - ep_group.rank * (
global_num_experts // ep_group.world_size)
hidden_states = hidden_states_buffer_gather[mask]
idx_type = local_expert_idx.dtype
sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx.float())
sorted_local_expert_idx = sorted_local_expert_idx.to(idx_type)

expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
sorted_local_expert_idx, local_num_experts).to(torch.int64)
hidden_states = hidden_states[sorted_idx]
group_list_type = 0

hidden_states_wrapper = [hidden_states]
del hidden_states

hidden_states = apply_mlp(hidden_states_wrapper,
w1,
w2,
expert_tokens,
group_list_type=group_list_type)

resorted_idx = torch.argsort(sorted_idx.float()).to(sorted_idx.dtype)
hidden_states = hidden_states[resorted_idx]
hidden_states_scatter = torch.zeros(
(mask.shape[0], hidden_states.shape[1]),
dtype=hidden_states.dtype,
device=hidden_states.device)
hidden_states_scatter[mask] = hidden_states
hidden_states_gatter = torch.empty_like(
hidden_states_scatter,
dtype=hidden_states_scatter.dtype,
device=hidden_states_scatter.device)
dist.all_to_all_single(hidden_states_gatter,
hidden_states_scatter,
group=ep_group.device_group)
hidden_states_gatter = hidden_states_gatter[
expert_idx_buffer_scatter != global_num_experts]
if hidden_states_gatter.shape[0] != row_idx_len:
hidden_states = torch.zeros((row_idx_len, hidden_states.shape[1]),
dtype=hidden_states.dtype,
device=hidden_states.device)
hidden_states[unpad_indices != -1] = hidden_states_gatter
else:
# TODO: Reorder device memory 2 times here, replace the current
hidden_states = hidden_states_gatter
final_hidden_states = torch_npu.npu_moe_finalize_routing(
hidden_states,
skip1=None,
skip2=None,
bias=None,
scales=topk_weights,
expanded_src_to_dst_row=expanded_row_idx,
export_for_source_row=topk_ids,
)

if len(original_shape) == 3:
final_hidden_states = final_hidden_states.view(original_shape)
return final_hidden_states


def fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
Expand Down Expand Up @@ -585,6 +831,7 @@ def __init__(self, moe: MoEConfig = None):
self.ep_size = ep_group.world_size
self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
self.local_batch_size = self.global_batch_size // self.ep_size
self.max_model_len = vllm_config.model_config.max_model_len

ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
Expand Down Expand Up @@ -613,21 +860,22 @@ def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
use_grouped_topk: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: bool = False,
enable_force_load_balance: bool = False,
**kwargs,
):
) -> torch.Tensor:

# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
if global_num_experts == 256:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
Expand Down Expand Up @@ -683,11 +931,19 @@ def apply(
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map)
elif MOE_ALL2ALL_BUFFER:
return fused_experts_with_all2all_buffer(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
max_model_len=self.max_model_len,
global_batch_size=self.global_batch_size,
expert_map=expert_map,
ep_group=get_ep_group())
else:
# The current implementation of deepseek moe splits hidden_states
# according to tp_size before they are feed into fused_moe module.
# Therefore, all2all is needed no matter how dp/tp is set so as to
# dispatch/combine tokens.
return fused_experts_with_all2all(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
Expand Down