Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 97 additions & 1 deletion tests/singlecard/ops/test_fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@
# here to make the test pass.
import vllm_ascend.patch.worker.patch_common.patch_utils # type: ignore[import] # isort: skip # noqa

from unittest.mock import MagicMock, patch

import pytest
import torch
from vllm.model_executor.layers.activation import SiluAndMul

from vllm_ascend.ops.fused_moe import fused_experts
from vllm_ascend.ops.fused_moe import fused_experts, select_experts

NUM_EXPERTS = [8, 64]
EP_SIZE = [1, 4]
Expand Down Expand Up @@ -98,3 +100,97 @@ def test_fused_experts(
# TODO: The native params are: atol=2e-2, rtol=0, maybe related to the nan problem
torch.testing.assert_close(output, torch_output, atol=4e-2, rtol=1)
torch.npu.empty_cache()


@pytest.mark.parametrize("m", [1, 33, 64])
@pytest.mark.parametrize("n", [128, 1024, 2048])
@pytest.mark.parametrize("e", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("use_grouped_topk", [True, False])
@pytest.mark.parametrize("renormalize", [True, False])
@pytest.mark.parametrize("with_e_correction", [True, False])
@pytest.mark.parametrize("custom_routing", [True, False])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("device", DEVICE)
def test_select_experts(
m: int,
n: int,
e: int,
topk: int,
scoring_func: str,
use_grouped_topk: bool,
renormalize: bool,
with_e_correction: bool,
custom_routing: bool,
dtype: torch.dtype,
device: str,
):
topk_group = 4 if use_grouped_topk else None
num_expert_group = e // 4 if use_grouped_topk else None

hidden_states = torch.randn(m, n, device=device, dtype=dtype)
router_logits = torch.randn(m, e, device=device, dtype=dtype)

e_score_correction_bias = (torch.randn(e, device=device, dtype=dtype)
if with_e_correction else None)

custom_routing_function = None
if custom_routing:
custom_routing_function = MagicMock()
mock_weights = torch.randn(m, topk, device=device, dtype=dtype)
mock_ids = torch.randint(0,
e, (m, topk),
device=device,
dtype=torch.int32)
custom_routing_function.return_value = (mock_weights, mock_ids)

with patch("vllm_ascend.ops.fused_moe.native_grouped_topk"
) as mock_native_grouped_topk:
mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(
x)

topk_weights, topk_ids = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=topk,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)

if use_grouped_topk:
mock_native_grouped_topk.assert_called_once()
else:
mock_native_grouped_topk.assert_not_called()

assert topk_weights.shape == (m, topk)
assert topk_ids.shape == (m, topk)
assert topk_ids.dtype == torch.int32


@pytest.mark.parametrize("device", DEVICE)
def test_select_experts_invalid_scoring_func(device: str):
with pytest.raises(ValueError,
match="Unsupported scoring function: invalid"):
select_experts(hidden_states=torch.randn(1, 128, device=device),
router_logits=torch.randn(1, 8, device=device),
top_k=2,
use_grouped_topk=False,
renormalize=False,
scoring_func="invalid")


@pytest.mark.parametrize("device", DEVICE)
def test_select_experts_missing_group_params(device: str):
with pytest.raises(AssertionError):
select_experts(hidden_states=torch.randn(1, 128, device=device),
router_logits=torch.randn(1, 64, device=device),
top_k=2,
use_grouped_topk=True,
renormalize=False,
scoring_func="softmax")
31 changes: 24 additions & 7 deletions vllm_ascend/ops/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,8 +877,23 @@ def select_experts(
ValueError: If an unsupported scoring function is provided.
"""

def _renormalize_topk_weights(
topk_weights: torch.Tensor,
renormalize: bool,
):
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1,
keepdim=True)
return topk_weights

if scoring_func == "softmax":
# NOTE: vLLM use dtype=torch.float here
if not use_grouped_topk and custom_routing_function is None:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k_softmax(
x=router_logits, finished=None, k=top_k)
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
return topk_weights, topk_ids

topk_weights = router_logits.softmax(dim=-1)
elif scoring_func == "sigmoid":
topk_weights = router_logits.sigmoid()
Expand Down Expand Up @@ -912,10 +927,11 @@ def select_experts(
k=top_k,
dim=-1,
sorted=False)
elif custom_routing_function is None:
topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
topk_weights = topk_weights.to(hidden_states.dtype)
else:
topk_ids = topk_ids.to(torch.int32)
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
return topk_weights, topk_ids

if custom_routing_function is not None:
topk_weights, topk_ids = custom_routing_function(
hidden_states=hidden_states,
gating_output=router_logits,
Expand All @@ -926,11 +942,12 @@ def select_experts(
topk_ids = topk_ids.to(torch.int32)
return topk_weights, topk_ids

topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
topk_weights = topk_weights.to(hidden_states.dtype)

# Required by npu_moe_init_routing
topk_ids = topk_ids.to(torch.int32)

if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)

return topk_weights, topk_ids

Expand Down