161 changes: 156 additions & 5 deletions vllm_ascend/ops/fused_moe.py
@@ -18,9 +18,11 @@
from typing import Callable, Optional

import torch
import torch.distributed as dist
import torch_npu
from vllm.config import get_current_vllm_config
from vllm.distributed import (get_tensor_model_parallel_world_size,
from vllm.distributed import (GroupCoordinator,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import get_dp_group
from vllm.model_executor.layers.fused_moe.layer import (
@@ -154,6 +156,143 @@ def fused_experts_with_mc2(
return hidden_states


# NOTE: The all2all-based expert parallelism implemented here is currently
# under-optimized.
def fused_experts_with_all2all(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None,
ep_group: GroupCoordinator = None,
):
original_shape = hidden_states.shape
if len(original_shape) == 3:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

num_tokens, _ = hidden_states.shape
num_experts = w1.shape[0]
device = hidden_states.device

if expert_map is not None:
global_num_experts = len(expert_map)
local_num_experts = global_num_experts // ep_group.world_size
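# Experts are partitioned contiguously and evenly across EP ranks: rank r owns
# global expert ids [r * local_num_experts, (r + 1) * local_num_experts).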
row_idx_len = num_tokens * top_k
row_idx = (torch.arange(0,
row_idx_len,
dtype=torch.int32,
device=device).view(top_k, -1).permute(
1, 0).contiguous())
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states,
row_idx=row_idx,
expert_idx=topk_ids,
active_num=num_tokens)

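# Count expanded tokens per global expert, then aggregate per EP rank to get
# the number of tokens to send to each peer.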
global_expert_tokens = torch.bincount(expanded_expert_idx,
minlength=global_num_experts)
scatter_sizes = global_expert_tokens.view(ep_group.world_size,
-1).sum(-1)

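# Exchange the per-rank send sizes so every rank learns how many tokens it
# will receive from each peer.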
gather_sizes = torch.empty_like(scatter_sizes)
dist.all_to_all_single(gather_sizes,
scatter_sizes,
group=ep_group.device_group)
scatter_size_list = scatter_sizes.cpu().tolist()
gather_size_list = gather_sizes.cpu().tolist()

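# Map global expert ids to local ids on the owning rank, then dispatch the
# tokens and their expert ids with a variable-size all_to_all.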
expanded_expert_idx = expanded_expert_idx % local_num_experts
hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
scatter_size_list,
gather_size_list)
local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
scatter_size_list,
gather_size_list)

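# Sort the received tokens by local expert id so each expert's tokens form a
# contiguous block for the grouped matmul.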
sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx)

expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
sorted_local_expert_idx, local_num_experts).to(torch.int64)

hidden_states = hidden_states[sorted_idx]
else:
row_idx_len = num_tokens * top_k
row_idx = torch.arange(0,
row_idx_len,
dtype=torch.int32,
device=topk_weights.device).view(
top_k, -1).permute(1, 0).contiguous()
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states,
row_idx=row_idx,
expert_idx=topk_ids,
active_num=num_tokens)

expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
expanded_expert_idx, num_experts)
expert_tokens = expert_tokens.to(torch.int64)

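# Gate/up projection: npu_grouped_matmul runs one GEMM per local expert over
# its contiguous block of tokens, driven by the per-expert boundaries in
# expert_tokens, followed by SwiGLU.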
w1 = w1.transpose(1, 2)
gate_up_out_list = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w1],
split_item=2,
group_list_type=0,
group_type=0,
group_list=expert_tokens,
)

# TODO: Remove this in the future.
hidden_states = torch.cat(gate_up_out_list, dim=0)
hidden_states = torch_npu.npu_swiglu(hidden_states)

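# Down projection back to the hidden size, again grouped per expert.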
w2 = w2.transpose(1, 2)
down_out_list = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w2],
split_item=2,
group_list_type=0,
group_type=0,
group_list=expert_tokens,
)

hidden_states = torch.cat(down_out_list, dim=0)

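# Undo the earlier sort and all_to_all so each token's expert outputs return
# to the source rank in the original expanded order, then weight and combine
# the top_k results.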
if expert_map is not None:
resorted_idx = torch.argsort(sorted_idx)
hidden_states = hidden_states[resorted_idx]
hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
gather_size_list,
scatter_size_list)

final_hidden_states = torch_npu.npu_moe_finalize_routing(
hidden_states,
skip1=None,
skip2=None,
bias=None,
scales=topk_weights,
expanded_src_to_dst_row=expanded_row_idx,
export_for_source_row=topk_ids,
)
else:
# TODO: The current implementation reorders device memory twice here; replace
# it once suitable operators become available.
final_hidden_states = torch_npu.npu_moe_finalize_routing(
hidden_states,
skip1=None,
skip2=None,
bias=None,
scales=topk_weights,
expanded_src_to_dst_row=expanded_row_idx,
export_for_source_row=topk_ids,
)
if len(original_shape) == 3:
final_hidden_states = final_hidden_states.view(original_shape)
return final_hidden_states


def fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
@@ -494,7 +633,7 @@ def apply(
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill=False,
is_prefill: bool = False,
**kwargs,
):
# NOTE: npu_moe_gating_top_k currently only supports the `group_count=256` pattern
@@ -536,14 +675,27 @@ def apply(
top_k=top_k,
expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name)
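# With a single EP rank every expert is local, so no cross-rank token exchange
# is needed.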
else:
elif get_ep_group().world_size == 1:
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map)
else:
# The current implementation of DeepSeek MoE splits hidden_states
# according to tp_size before it is fed into the fused_moe module.
# Therefore, all2all is needed to dispatch/combine tokens regardless
# of how dp/tp is set.
return fused_experts_with_all2all(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map,
ep_group=get_ep_group())


class AscendFusedMoE(FusedMoE):
@@ -721,8 +873,7 @@ def forward(self,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
is_prefill=is_prefill,
enable_force_load_balance=enable_force_load_balance,
dp_size=self.dp_size)
enable_force_load_balance=enable_force_load_balance)

if VLLM_ENABLE_MC2 and not is_prefill:
...
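For reference, the dispatch step added above boils down to a two-phase variable-size all_to_all: exchange per-rank token counts first, then move the tokens themselves. Below is a minimal sketch of that pattern using plain torch.distributed calls; the helper name dispatch_tokens and the even, contiguous expert partitioning are illustrative assumptions, not part of this PR.

import torch
import torch.distributed as dist


def dispatch_tokens(tokens, expert_ids, num_global_experts, world_size, group=None):
    # Tokens are assumed to be pre-sorted by expert id (as npu_moe_init_routing
    # produces them) and experts are assumed to be split contiguously and evenly
    # across EP ranks, so the per-rank send segments are already contiguous.
    experts_per_rank = num_global_experts // world_size
    counts_per_expert = torch.bincount(expert_ids, minlength=num_global_experts)
    send_sizes = counts_per_expert.view(world_size, experts_per_rank).sum(-1)

    # Phase 1: exchange counts so every rank knows how much it will receive.
    recv_sizes = torch.empty_like(send_sizes)
    dist.all_to_all_single(recv_sizes, send_sizes, group=group)
    send_list = send_sizes.tolist()
    recv_list = recv_sizes.tolist()

    # Phase 2: move the tokens with a variable-size all_to_all.
    recv_tokens = tokens.new_empty((sum(recv_list), tokens.shape[-1]))
    dist.all_to_all_single(recv_tokens, tokens,
                           output_split_sizes=recv_list,
                           input_split_sizes=send_list,
                           group=group)

    # The combine step after the expert computation swaps send/recv sizes to
    # return the expert outputs to their source ranks.
    return recv_tokens, send_list, recv_list

fused_experts_with_all2all follows the same pattern through ep_group.all_to_all, additionally carrying the expanded expert ids alongside the tokens.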
3 changes: 1 addition & 2 deletions vllm_ascend/quantization/quant_config.py
@@ -323,14 +323,13 @@ def apply(
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: bool = True,
enable_force_load_balance: bool = False,
dp_size: int = 1,
**kwargs,
) -> torch.Tensor:
return self.quant_method.apply(
layer, x, router_logits, top_k, renormalize, use_grouped_topk,
global_num_experts, expert_map, topk_group, num_expert_group,
custom_routing_function, scoring_func, e_score_correction_bias,
is_prefill, enable_force_load_balance, dp_size)
is_prefill, enable_force_load_balance)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if hasattr(self.quant_method, "process_weights_after_loading"):
7 changes: 5 additions & 2 deletions vllm_ascend/quantization/w8a8_dynamic.py
@@ -582,7 +582,6 @@ def apply(
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: bool = True,
enable_force_load_balance: bool = True,
dp_size: int = 1,
**kwargs,
) -> torch.Tensor:
assert router_logits.shape[
@@ -635,7 +634,7 @@ def apply(
top_k=top_k,
expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name)
elif dp_size == 1:
elif self.ep_group.world_size == 1:
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
w1_scale=layer.w13_weight_scale,
@@ -646,6 +645,10 @@
top_k=top_k,
expert_map=expert_map)
else:
# The current implementation of DeepSeek MoE splits hidden_states
# according to tp_size before it is fed into the fused_moe module.
# Therefore, all2all is needed to dispatch/combine tokens regardless
# of how dp/tp is set.
return fused_experts_with_all2all(hidden_states=x,
w1=layer.w13_weight,
w1_scale=layer.w13_weight_scale,