2 changes: 1 addition & 1 deletion tests/ut/ops/test_token_dispatcher.py
@@ -19,7 +19,7 @@
from pytest_mock import MockerFixture

from tests.ut.base import PytestBase
- from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
+ from vllm_ascend.ops.moe_dispatcher.token_dispatcher_old import (
MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
from vllm_ascend.utils import adapt_patch # noqa E402

17 changes: 16 additions & 1 deletion vllm_ascend/ascend_forward_context.py
@@ -10,6 +10,9 @@

import vllm_ascend.envs as envs
from vllm_ascend.platform import NPUPlatform
+ from vllm_ascend.ops.moe_dispatcher.token_dispatcher_old import (
+     QuantizedTokenDispatcherWithAll2All,
+     UnquantizedTokenDispatcherWithAll2AllV)


class FusedMoEState(Enum):
@@ -77,6 +80,18 @@ def set_ascend_forward_context(
is_deepseek_v3_r1)
forward_context.fused_moe_state = fused_moe_state
forward_context.in_profile_run = in_profile_run

+     top_k = vllm_config.model_config.hf_config.num_experts_per_tok
+     num_experts = vllm_config.model_config.hf_config.n_routed_experts
+     quant_config = vllm_config.quant_config
+
+     need_param = {
+         "top_k": top_k,  # experts selected per token
+         "num_experts": num_experts  # total number of routed experts
+     }
+
+     token_dispatcher = UnquantizedTokenDispatcherWithAll2AllV(need_param)
+     forward_context.token_dispatcher = token_dispatcher

# NOTE: This cannot be set using set_forward_context
# due to multiple warmups before actual capturing
Expand Down Expand Up @@ -111,4 +126,4 @@ def set_ascend_forward_context(
try:
yield
finally:
-         pass
+         pass
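For orientation (not part of the diff): the block added to set_ascend_forward_context above builds one token dispatcher per forward pass and stores it on the forward context; the MoE code in fused_moe.py below reads it back instead of constructing its own. A minimal sketch of that round trip, assuming UnquantizedTokenDispatcherWithAll2AllV accepts the need_param dict shown above; the helper name attach_token_dispatcher is hypothetical.

# Illustrative sketch only, not the PR's exact code.
from vllm_ascend.ops.moe_dispatcher.token_dispatcher_old import (
    UnquantizedTokenDispatcherWithAll2AllV)

def attach_token_dispatcher(forward_context, vllm_config):
    hf_config = vllm_config.model_config.hf_config
    need_param = {
        "top_k": hf_config.num_experts_per_tok,     # experts selected per token
        "num_experts": hf_config.n_routed_experts,  # routed experts in the model
    }
    # One dispatcher per forward pass, shared by every MoE layer in the model.
    forward_context.token_dispatcher = UnquantizedTokenDispatcherWithAll2AllV(need_param)

# Consumers later fetch the same object, e.g. universal_fused_experts below:
#     token_dispatcher = get_forward_context().token_dispatcher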
106 changes: 57 additions & 49 deletions vllm_ascend/ops/fused_moe.py
@@ -45,9 +45,8 @@
data_parallel_reduce_scatter
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
- from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
+ from vllm_ascend.ops.moe_dispatcher.token_dispatcher_old import (
MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
- from vllm_ascend.ops.sequence_parallel import MetadataForPadding
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
get_all_reduce_merge_state,
@@ -205,9 +204,11 @@ def fused_experts_with_mc2(
group_list_type=1,
group_type=0,
group_list=group_list,
-     )[0]
+     )

-     gate_up_out = torch_npu.npu_swiglu(gate_up_out_list)
+     # TODO: Remove this in the future.
+     gate_up_out = torch.cat(gate_up_out_list, dim=0)
+     gate_up_out = torch_npu.npu_swiglu(gate_up_out)

w2 = w2.transpose(1, 2)
down_out_list = torch_npu.npu_grouped_matmul(
@@ -217,7 +218,9 @@
group_list_type=1,
group_type=0,
group_list=group_list,
-     )[0]
+     )
+
+     down_out_list = torch.cat(down_out_list, dim=0)

# moeCombine
kwargs_mc2 = {
Expand Down Expand Up @@ -308,8 +311,9 @@ def apply_mlp(
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
-     )[0]
+     )

+     hidden_states = torch.cat(hidden_states, dim=0)
hidden_states = torch_npu.npu_swiglu(hidden_states)

w2 = w2.transpose(1, 2)
@@ -320,8 +324,9 @@
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
-     )[0]
+     )

+     hidden_states = torch.cat(hidden_states, dim=0)
return hidden_states
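A recurring change in this file: the result of npu_grouped_matmul is no longer collapsed with a trailing [0]; instead the per-group output tensors are kept and stitched back together with torch.cat(..., dim=0) before the activation or the next matmul. A minimal sketch of that pattern, assuming the call returns a list of output tensors (as the torch.cat calls above imply); run_grouped_mm is a hypothetical helper and grouped_matmul stands in for torch_npu.npu_grouped_matmul.

# Illustrative sketch only.
import torch

def run_grouped_mm(grouped_matmul, x, w, group_list):
    out_list = grouped_matmul(x=[x],
                              weight=[w],
                              split_item=2,
                              group_list_type=0,
                              group_type=0,
                              group_list=group_list)
    # Previously: out_list[0] was used directly.
    # Now: concatenate the per-group outputs back into one tensor.
    return torch.cat(out_list, dim=0)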


@@ -411,19 +416,23 @@ def fused_experts_with_all2all(
group_list_type=0,
group_type=0,
group_list=expert_tokens,
-     )[0]
+     )

-     hidden_states = torch_npu.npu_swiglu(gate_up_out_list)
+     # TODO: Remove this in the future.
+     hidden_states = torch.cat(gate_up_out_list, dim=0)
+     hidden_states = torch_npu.npu_swiglu(hidden_states)

w2 = w2.transpose(1, 2)
-     hidden_states = torch_npu.npu_grouped_matmul(
+     down_out_list = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w2],
split_item=2,
group_list_type=0,
group_type=0,
group_list=expert_tokens,
-     )[0]
+     )
+
+     hidden_states = torch.cat(down_out_list, dim=0)

if expert_map is not None:
resorted_idx = torch.argsort(sorted_idx)
@@ -691,6 +700,24 @@ def fused_experts_with_all2allv(
return output


+ def universal_fused_experts(
+     hidden_states: torch.Tensor,
+     w1: torch.Tensor,
+     w2: torch.Tensor,
+     topk_weights: torch.Tensor,
+     topk_ids: torch.Tensor,
+     expert_map: torch.Tensor = None,
+     apply_router_weight_on_input: bool = False,
+     max_num_tokens: Optional[int] = None,
+ ):
+     token_dispatcher = get_forward_context().token_dispatcher
+     _, dispatched_input, tokens_per_expert = token_dispatcher.token_permutation(
+         hidden_states, topk_weights, topk_ids, expert_map)
+     expert_output = apply_mlp(dispatched_input, w1, w2, tokens_per_expert)
+     final_hidden_states = token_dispatcher.token_unpermutation(expert_output)
+     return final_hidden_states
+
+
def fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
@@ -813,9 +840,11 @@ def fused_experts(
group_list_type=0,
group_type=0,
group_list=expert_tokens,
-     )[0]
+     )

-     gate_up_out = torch_npu.npu_swiglu(gate_up_out_list)
+     # TODO: Remove this in the future.
+     gate_up_out = torch.cat(gate_up_out_list, dim=0)
+     gate_up_out = torch_npu.npu_swiglu(gate_up_out)

w2 = w2.transpose(1, 2)
down_out_list = torch_npu.npu_grouped_matmul(
@@ -825,7 +854,9 @@
group_list_type=0,
group_type=0,
group_list=expert_tokens,
-     )[0]
+     )
+
+     down_out_list = torch.cat(down_out_list, dim=0)

if expert_map is not None:
weighted_down_out = down_out_list * sorted_weights.unsqueeze(1)
@@ -1094,6 +1125,13 @@ def apply(

fused_moe_state = get_forward_context().fused_moe_state

+         return universal_fused_experts(hidden_states=x,
+                                        w1=layer.w13_weight,
+                                        w2=layer.w2_weight,
+                                        topk_weights=topk_weights,
+                                        topk_ids=topk_ids,
+                                        expert_map=expert_map)
+
if fused_moe_state == FusedMoEState.MC2:
return fused_experts_with_mc2(
hidden_states=x,
@@ -1181,27 +1219,8 @@ def __init__(
):
# TODO: This could not initialize FusedMoE baseclass,
# fixme and make __init__() of AscendFusedMoE more clear
-         super().__init__(
-             num_experts=num_experts,
-             top_k=top_k,
-             hidden_size=hidden_size,
-             intermediate_size=intermediate_size,
-             params_dtype=params_dtype,
-             reduce_results=reduce_results,
-             renormalize=renormalize,
-             use_grouped_topk=use_grouped_topk,
-             num_expert_group=num_expert_group,
-             topk_group=topk_group,
-             quant_config=quant_config,
-             tp_size=tp_size,
-             ep_size=ep_size,
-             dp_size=dp_size,
-             prefix=prefix,
-             custom_routing_function=custom_routing_function,
-             scoring_func=scoring_func,
-             e_score_correction_bias=e_score_correction_bias,
-             activation=activation,
-         )
+         super(FusedMoE, self).__init__()
+
AscendFusedMoE.moe_counter += 1
self.moe_instance_id = AscendFusedMoE.moe_counter
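The TODO above notes that AscendFusedMoE no longer runs FusedMoE's own constructor: super(FusedMoE, self).__init__() starts the method lookup after FusedMoE in the MRO, so FusedMoE's setup is skipped and only the initializers of the classes behind it (down to torch.nn.Module in a plain nn.Module hierarchy) execute. A standalone illustration of that idiom with stand-in classes, where Middle plays the role of FusedMoE and Child the role of AscendFusedMoE.

# Illustrative sketch only.
class Base:
    def __init__(self):
        print("Base.__init__")

class Middle(Base):
    def __init__(self):
        print("Middle.__init__")
        super().__init__()

class Child(Middle):
    def __init__(self):
        # Skip Middle's initializer: lookup starts after Middle in Child's MRO.
        super(Middle, self).__init__()

Child()  # prints "Base.__init__" only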

@@ -1272,7 +1291,7 @@ def __init__(
if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for "
"non-grouped topk.")
-         self.moe = FusedMoEConfig.make(
+         moe = FusedMoEConfig.make(
num_experts=self.global_num_experts,
experts_per_token=top_k,
hidden_dim=hidden_size,
@@ -1283,7 +1302,7 @@
quant_config=quant_config)

if quant_config is None:
-             self.quant_method = AscendUnquantizedFusedMoEMethod(self.moe)
+             self.quant_method = AscendUnquantizedFusedMoEMethod(moe)
else:
self.quant_method = quant_config.get_quant_method(self, prefix)

@@ -1353,8 +1372,7 @@ def forward(self,
top_k: Optional[int] = None,
shared_experts: Optional[Any] = None,
gate=None,
-                 replace_allreduce: bool = False,
-                 _metadata_for_padding: Optional[MetadataForPadding] = None):
+                 replace_allreduce: bool = False):

assert self.quant_method is not None

@@ -1388,17 +1406,7 @@
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
shared_hidden_states = shared_experts(hidden_states)

-         mc2_mask = forward_context.mc2_mask
-
-         enable_sp = _metadata_for_padding is not None and _metadata_for_padding.not_dummy_and_is_prefill
-         tp_size = get_tensor_model_parallel_world_size()
-         if enable_sp:
-             tp_rank = get_tensor_model_parallel_rank()
-             mc2_mask_sp = _metadata_for_padding.mc2_mask if _metadata_for_padding is not None else forward_context.mc2_mask
-             chunk_mc2_mask = torch.tensor_split(mc2_mask_sp, tp_size, dim=0)
-             mc2_mask = chunk_mc2_mask[tp_rank]
-             replace_allreduce = True

if (fused_moe_state not in [
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
FusedMoEState.NaiveMulticast