5 changes: 3 additions & 2 deletions tests/e2e/singlecard/ops/test_fused_moe.py
@@ -26,7 +26,8 @@
import torch
from vllm.model_executor.layers.activation import SiluAndMul

from vllm_ascend.ops.fused_moe import fused_experts, select_experts
from vllm_ascend.ops.fused_moe import fused_experts
from vllm_ascend.ops.layers.experts_selector import select_experts

NUM_EXPERTS = [8, 64]
EP_SIZE = [1, 4]
@@ -142,7 +143,7 @@ def test_select_experts(
dtype=torch.int32)
custom_routing_function.return_value = (mock_weights, mock_ids)

with patch("vllm_ascend.ops.fused_moe.native_grouped_topk"
with patch("vllm_ascend.ops.layers.experts_selector._native_grouped_topk"
) as mock_native_grouped_topk:
mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
x)
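
Because the helper moved modules (and was made private), the mock has to target the new location; patching the old `vllm_ascend.ops.fused_moe` path would no longer intercept the call made inside `select_experts`. A minimal sketch of the patching pattern, assuming only the module path and helper name shown in this diff (the surrounding body is illustrative):

    from unittest.mock import patch

    import torch

    # Patch the symbol where it is defined so the call inside select_experts
    # is the one that gets replaced.
    with patch("vllm_ascend.ops.layers.experts_selector._native_grouped_topk"
               ) as mock_native_grouped_topk:
        mock_native_grouped_topk.side_effect = (
            lambda x, num_groups, k: torch.randn_like(x))
        # ... exercise select_experts(..., use_grouped_topk=True, ...) here ...
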
26 changes: 26 additions & 0 deletions tests/ut/ops/test_fused_ops.py
@@ -25,6 +25,7 @@
from vllm_ascend.ascend_forward_context import _get_fused_moe_state
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
AscendUnquantizedFusedMoEMethod)
from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.utils import AscendSocVersion, adapt_patch # noqa E402

adapt_patch(True)
@@ -389,3 +390,28 @@ def test_apply_with_expert_map(self, moe_method, mock_dist_env,
assert result.shape == (16, 2)
else:
assert result.shape == x.shape


class TestExpertsSelector:

@pytest.mark.parametrize("global_num_experts", [[256], [128]])
def test_select_experts(self, mock_dist_env, mock_moe_env,
global_num_experts):

x = torch.randn(8, 2)
router_logits = torch.randn(8, 2)
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=2,
use_grouped_topk=False,
renormalize=True,
topk_group=None,
num_expert_group=None,
custom_routing_function=None,
scoring_func="softmax",
e_score_correction_bias=None,
global_num_experts=global_num_experts)

assert topk_weights.shape == (8, 2)
assert topk_ids.shape == (8, 2)
25 changes: 13 additions & 12 deletions tests/ut/quantization/test_w8a8.py
@@ -5,12 +5,13 @@

from tests.ut.base import TestBase
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.ops.layers.experts_selector import (_native_grouped_topk,
select_experts)
from vllm_ascend.quantization.w8a8 import (AscendC8KVCacheMethod,
AscendW8A8FusedMoEMethod,
AscendW8A8LinearMethod,
fused_experts, fused_experts_310p,
native_grouped_topk,
quant_per_tensor, select_experts)
quant_per_tensor)


class TestQuantPerTensor(TestBase):
@@ -772,7 +773,7 @@ def test_grouped_topk(self, mock_topk):
self.assertEqual(ids.shape, (self.num_tokens, self.top_k))
self.assertEqual(ids.dtype, torch.int32)

@patch('vllm_ascend.quantization.w8a8.native_grouped_topk')
@patch('vllm_ascend.ops.layers.experts_selector._native_grouped_topk')
def test_grouped_topk_with_correction_bias(self, mock_grouped_topk):
"""Test grouped topk with expert score correction bias"""
mock_grouped_topk.return_value = torch.ones(self.num_tokens,
@@ -868,9 +869,9 @@ def test_basic_group_selection(self):

with patch('torch.topk',
return_value=(None, expected_topk_indices)) as mock_topk:
result = native_grouped_topk(topk_weights=topk_weights,
num_expert_group=2,
topk_group=2)
result = _native_grouped_topk(topk_weights=topk_weights,
num_expert_group=2,
topk_group=2)

mock_topk.assert_called_once()

@@ -885,9 +886,9 @@ def test_partial_group_selection(self):
expected_topk_indices = torch.tensor([[0], [1]])

with patch('torch.topk', return_value=(None, expected_topk_indices)):
result = native_grouped_topk(topk_weights=topk_weights,
num_expert_group=2,
topk_group=1)
result = _native_grouped_topk(topk_weights=topk_weights,
num_expert_group=2,
topk_group=1)

expected_result = torch.tensor(
[[0.1, 0.9, 0.2, 0.8, 0.0, 0.0, 0.0, 0.0],
@@ -900,7 +901,7 @@ def test_single_group(self):
expected_topk_indices = torch.tensor([[0], [0]])

with patch('torch.topk', return_value=(None, expected_topk_indices)):
result = native_grouped_topk(topk_weights=topk_weights,
num_expert_group=1,
topk_group=1)
result = _native_grouped_topk(topk_weights=topk_weights,
num_expert_group=1,
topk_group=1)
self.assertTrue(result.numel() > 0)
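
For reference, the expected tensors in these tests follow directly from the group-masking behaviour of the relocated helper (its previous implementation is visible verbatim in the `vllm_ascend/ops/fused_moe.py` hunk below): per token, the experts are split into `num_expert_group` groups, each group is scored by its maximum weight, the best `topk_group` groups are kept, and every expert outside them is zeroed. A small self-contained sketch of that behaviour, assuming the move did not change the algorithm:

    import torch

    def grouped_topk_mask(topk_weights: torch.Tensor, num_expert_group: int,
                          topk_group: int) -> torch.Tensor:
        # Mirror of the relocated _native_grouped_topk: keep only the experts
        # that fall inside the best `topk_group` groups, zero out the rest.
        num_token = topk_weights.shape[0]
        grouped = topk_weights.view(num_token, num_expert_group,
                                    -1).max(dim=-1).values
        group_idx = torch.topk(grouped.float(), k=topk_group, dim=-1,
                               sorted=False)[1]
        group_mask = torch.zeros_like(grouped).scatter_(1, group_idx, 1)
        expert_mask = group_mask.unsqueeze(-1).expand(
            num_token, num_expert_group,
            topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1)
        return topk_weights.masked_fill(~expert_mask.bool(), 0.0)

    # Two tokens, 8 experts in 2 groups of 4, keep only the best group.
    w = torch.tensor([[0.1, 0.9, 0.2, 0.8, 0.05, 0.05, 0.05, 0.05],
                      [0.3, 0.3, 0.3, 0.3, 0.6, 0.1, 0.1, 0.1]])
    print(grouped_topk_mask(w, num_expert_group=2, topk_group=1))
    # token 0 keeps group 0 -> [0.1, 0.9, 0.2, 0.8, 0.0, 0.0, 0.0, 0.0]
    # token 1 keeps group 1 -> [0.0, 0.0, 0.0, 0.0, 0.6, 0.1, 0.1, 0.1]
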
9 changes: 4 additions & 5 deletions vllm_ascend/ops/common_fused_moe.py
@@ -24,8 +24,8 @@
UnquantizedFusedMoEMethod

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ops.fused_moe import (fused_experts_moge, select_experts,
unified_fused_experts)
from vllm_ascend.ops.fused_moe import fused_experts_moge, unified_fused_experts
from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.utils import is_310p

original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
@@ -59,7 +59,7 @@ def forward_oot(
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
global_num_experts: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
@@ -69,7 +69,6 @@
logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:

topk_weights, topk_ids = select_experts(
global_num_experts=global_num_experts,
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
@@ -80,7 +79,7 @@
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
global_num_experts=global_num_experts)

if topk_ids.shape[1] < top_k or is_310p():
assert global_num_experts is not None
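
Note that with the default changed from `Optional[int] = None` to `int = -1`, the `assert global_num_experts is not None` guard above can no longer fail, since an `int` is never `None`. If the intent is still to catch an unconfigured value, a hypothetical tightening (not part of this PR; treating `-1` as "not set" is an assumption) might look like:

    def _check_global_num_experts(global_num_experts: int) -> int:
        # Hypothetical replacement for the now-vacuous `is not None` assert:
        # reject the -1 sentinel (assumed meaning: "not configured").
        if global_num_experts <= 0:
            raise ValueError("global_num_experts must be a positive expert "
                             f"count, got {global_num_experts}")
        return global_num_experts
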
181 changes: 14 additions & 167 deletions vllm_ascend/ops/fused_moe.py
@@ -46,6 +46,7 @@
from vllm_ascend.distributed.moe_comm_method import MoECommMethod
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.ops.layers.experts_selector import select_experts
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
from vllm_ascend.ops.sequence_parallel import MetadataForPadding
@@ -920,143 +921,6 @@ def fused_experts(
return final_hidden_states


def native_grouped_topk(
topk_weights: torch.Tensor,
num_expert_group: Optional[int],
topk_group: Optional[int],
):
topk_group = 0 if topk_group is None else topk_group
num_expert_group = 0 if num_expert_group is None else num_expert_group

num_token = topk_weights.shape[0]
grouped_weights = topk_weights.view(num_token, num_expert_group,
-1).max(dim=-1).values
topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
k=topk_group,
dim=-1,
sorted=False)[1]
topk_group_mask = torch.zeros_like(grouped_weights)
topk_group_mask.scatter_(1, topk_group_indices, 1)
topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
num_token, num_expert_group,
topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)

return topk_weights


def select_experts(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
global_num_experts: Optional[torch.Tensor] = None
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Select top-k experts based on router logits.

Args:
hidden_states: Hidden states of shape (num_tokens, hidden_size).
router_logits: Router logits of shape (num_tokens, num_experts).
top_k: Number of experts to select.
use_grouped_topk: Whether to group experts before selecting top-k.
renormalize: Whether to renormalize the routing weights.
topk_group: Number of expert groups to select from.
num_expert_group: Number of experts in each group.
custom_routing_function: Custom routing function.
scoring_func: Scoring function to use.
e_score_correction_bias: Correction bias to apply to expert scores.

Returns:
topk_weights: Routing weights of shape (num_tokens, top_k).
topk_ids: Selected expert IDs of shape (num_tokens, top_k).

Raises:
ValueError: If an unsupported scoring function is provided.
"""

def _renormalize_topk_weights(
topk_weights: torch.Tensor,
renormalize: bool,
):
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1,
keepdim=True)
return topk_weights

if scoring_func == "softmax":
# NOTE: vLLM use dtype=torch.float here
if not use_grouped_topk and custom_routing_function is None:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
x=router_logits, finished=None, k=top_k)
topk_ids = topk_ids.to(torch.int32)
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
return topk_weights, topk_ids

topk_weights = router_logits.softmax(dim=-1)
elif scoring_func == "sigmoid":
topk_weights = router_logits.sigmoid()
else:
raise ValueError(f"Unsupported scoring function: {scoring_func}")

if use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None

if e_score_correction_bias is not None:
# Store original scores before applying correction bias. We use biased
# scores for expert selection but original scores for routing weights
original_weights = topk_weights
topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0)

# TODO: Change to npu_group_topk when the latest CANN and NNAL is available
# >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
topk_weights = native_grouped_topk(topk_weights, num_expert_group,
topk_group)
# TODO bfloat16 is not supported in torch.topk with ge graph.
if e_score_correction_bias is not None:
topk_ids = torch.topk(topk_weights.to(torch.float32),
k=top_k,
dim=-1,
sorted=False)[1]
# Use original unbiased scores for the routing weights
topk_weights = original_weights.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
k=top_k,
dim=-1,
sorted=False)
topk_ids = topk_ids.to(torch.int32)
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
return topk_weights, topk_ids

if custom_routing_function is not None:
topk_weights, topk_ids = custom_routing_function(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
global_num_experts=global_num_experts)
# Required by npu_moe_init_routing
topk_ids = topk_ids.to(torch.int32)
return topk_weights, topk_ids

topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
topk_weights = topk_weights.to(hidden_states.dtype)

# Required by npu_moe_init_routing
topk_ids = topk_ids.to(torch.int32)
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)

return topk_weights, topk_ids


class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):

def __init__(self, moe: FusedMoEConfig = None):
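
The docstring and implementation removed in this hunk now live in `vllm_ascend/ops/layers/experts_selector.py`. Assuming the move keeps the documented signature (plus the `global_num_experts` argument the updated call sites pass), a typical grouped-top-k invocation with a score-correction bias would look like the sketch below; shapes and values are illustrative only:

    import torch

    from vllm_ascend.ops.layers.experts_selector import select_experts

    num_tokens, hidden_size, num_experts, top_k = 4, 128, 16, 2
    hidden_states = torch.randn(num_tokens, hidden_size)
    router_logits = torch.randn(num_tokens, num_experts)
    bias = torch.zeros(num_experts)

    # Grouped selection: 16 experts in 4 groups, keep the best 2 groups, then
    # pick top_k experts inside them; biased scores drive selection while the
    # unbiased scores become the routing weights (per the docstring above).
    topk_weights, topk_ids = select_experts(
        hidden_states=hidden_states,
        router_logits=router_logits,
        top_k=top_k,
        use_grouped_topk=True,
        renormalize=True,
        topk_group=2,
        num_expert_group=4,
        custom_routing_function=None,
        scoring_func="sigmoid",
        e_score_correction_bias=bias,
        global_num_experts=num_experts)

    # Expected: topk_weights.shape == (4, 2); topk_ids.shape == (4, 2), int32.
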
@@ -1111,36 +975,19 @@ def apply(
**kwargs,
) -> torch.Tensor:

is_deepseek_v3_r1 = global_num_experts == 256
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
if is_deepseek_v3_r1:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k, # topk currently is 8
bias=e_score_correction_bias,
k_group=topk_group, # fix: 4
group_count=num_expert_group, # fix 8
group_select_mode=
1, # 0: the maximum in the group; 1: topk2.sum(fix)
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
norm_type=1, # 0: softmax; 1: sigmoid(fix)
# out_flag=False, # todo new api; should the third output be output
# y2_flag=False, # old api; should the third output be output
routed_scaling_factor=1,
eps=float(1e-20))
else:
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts,
is_unquantized=True)

topk_weights = topk_weights.to(x.dtype)
# this is a naive implementation for experts load balance so as
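
The DeepSeek-V3/R1 fast path removed from `apply()` is not deleted outright: with `global_num_experts` and `is_unquantized=True` now forwarded to `select_experts`, the selector is presumably the one deciding when to call `torch_npu.npu_moe_gating_top_k`. A sketch of that special case, reconstructed from the branch removed above (the operator call and its arguments come from this diff; wrapping it in a standalone helper is an assumption about the new module's layout):

    import torch
    import torch_npu

    def _deepseek_v3_r1_gating(router_logits: torch.Tensor, top_k: int,
                               topk_group: int, num_expert_group: int,
                               e_score_correction_bias: torch.Tensor):
        # Same operator call as the removed branch; per the removed NOTE it is
        # only used for the 256-expert DeepSeek V3/R1 pattern.
        topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
            router_logits,
            k=top_k,                       # 8 for DeepSeek V3/R1
            bias=e_score_correction_bias,
            k_group=topk_group,            # 4
            group_count=num_expert_group,  # 8
            group_select_mode=1,           # 1: group score = sum of its top-2
            renorm=0,                      # 0: softmax -> topk
            norm_type=1,                   # 1: sigmoid scoring
            routed_scaling_factor=1,
            eps=float(1e-20))
        return topk_weights, topk_ids
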
Empty file.