Merged
16 changes: 15 additions & 1 deletion tests/ut/ops/test_fused_ops.py
@@ -20,6 +20,7 @@
import torch.nn as nn
import torch_npu
from pytest_mock import MockerFixture
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase

from vllm_ascend.ascend_forward_context import _get_fused_moe_state
from vllm_ascend.ops.fused_moe import (AscendFusedMoE,
@@ -59,6 +60,7 @@ def mock_dist_env(mocker: MockerFixture):
patch('vllm_ascend.ops.fused_moe.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \
patch('torch.distributed.all_gather', return_value=MagicMock(return_value=torch.randn(10,32))), \
patch('torch.distributed.all_to_all_single', return_value=torch.randn(8, 32)), \
patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce',
@@ -180,6 +182,18 @@ def __init__(self, shared_experts, num_tokens):
self.apply = MagicMock(return_value=(torch.randn(num_tokens, 32)))


class MockFusedMoEMethod(FusedMoEMethodBase):

def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
pass

def apply(self, hidden_states: torch.Tensor,
expert_weights: torch.Tensor) -> torch.Tensor:
pass


class TestAscendFusedMoe:

def test_init_no_quant(self, mock_dist_env, default_moe_config):
Expand Down Expand Up @@ -213,7 +227,7 @@ def test_init_no_quant(self, mock_dist_env, default_moe_config):

def test_init_with_quant(self, mock_dist_env, default_moe_config):
mock_quant_config = MagicMock()
mock_quant_method = MagicMock()
mock_quant_method = MockFusedMoEMethod()
mock_quant_config.get_quant_method.return_value = mock_quant_method

moe = AscendFusedMoE(**default_moe_config,
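The switch above from a bare MagicMock to MockFusedMoEMethod presumably lets the mocked quant method pass a type check in the base FusedMoE constructor. A minimal, self-contained sketch of that idea, using local stand-in classes rather than the real vllm types (the isinstance check in upstream FusedMoE is an assumption, not confirmed by this diff):

# Illustrative stand-ins only; not the real FusedMoEMethodBase from vllm.
from abc import ABC, abstractmethod
from unittest.mock import MagicMock


class QuantMethodBase(ABC):
    """Stand-in for FusedMoEMethodBase: two abstract hooks."""

    @abstractmethod
    def create_weights(self, layer):
        ...

    @abstractmethod
    def apply(self, hidden_states):
        ...


class StubQuantMethod(QuantMethodBase):
    """Stand-in for MockFusedMoEMethod: concrete no-op overrides."""

    def create_weights(self, layer):
        pass

    def apply(self, hidden_states):
        pass


# A concrete subclass satisfies an isinstance() check; a bare MagicMock does not.
assert isinstance(StubQuantMethod(), QuantMethodBase)
assert not isinstance(MagicMock(), QuantMethodBase)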
23 changes: 21 additions & 2 deletions vllm_ascend/ops/fused_moe.py
@@ -1181,8 +1181,27 @@ def __init__(
):
# TODO: This could not initialize FusedMoE baseclass,
# fixme and make __init__() of AscendFusedMoE more clear
super(FusedMoE, self).__init__()

super().__init__(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
reduce_results=reduce_results,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
topk_group=topk_group,
quant_config=quant_config,
tp_size=tp_size,
ep_size=ep_size,
dp_size=dp_size,
prefix=prefix,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
activation=activation,
Collaborator:

missing apply_router_weight_on_input?

Collaborator Author:

Thanks for this reminder, I'll fix it soon

)
AscendFusedMoE.moe_counter += 1
self.moe_instance_id = AscendFusedMoE.moe_counter
