55 changes: 55 additions & 0 deletions tests/e2e/multicard/test_qwen3_moe.py
@@ -0,0 +1,55 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
#
"""Compare the short outputs of HF and vLLM when using greedy sampling.

Run `pytest tests/test_offline_inference.py`.
"""

from tests.e2e.conftest import VllmRunner


def test_models_distributed_Qwen3_MOE_TP2():
example_prompts = [
"Hello, my name is",
]
dtype = "half"
max_tokens = 5
with VllmRunner(
"Qwen/Qwen3-30B-A3B",
dtype=dtype,
tensor_parallel_size=2,
distributed_executor_backend="mp",
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)


def test_models_distributed_Qwen3_MOE_TP2_WITH_EP():
example_prompts = [
"Hello, my name is",
]
dtype = "half"
max_tokens = 5
with VllmRunner(
"Qwen/Qwen3-30B-A3B",
dtype=dtype,
tensor_parallel_size=2,
enable_expert_parallel=True,
distributed_executor_backend="mp",
) as vllm_model:
vllm_model.generate_greedy(example_prompts, max_tokens)
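
A note on running these new cases (the hardware requirement is an assumption inferred from tensor_parallel_size=2, not stated in the diff): each test launches Qwen/Qwen3-30B-A3B across two devices with the multiprocessing executor, so a host with at least two Ascend NPUs and vllm-ascend installed is needed. A typical invocation:

pytest -s tests/e2e/multicard/test_qwen3_moe.py
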
103 changes: 96 additions & 7 deletions vllm_ascend/models/qwen3_moe.py
@@ -17,25 +17,102 @@
# This file is a part of the vllm-ascend project.
from typing import Optional

import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig
from vllm.config import CacheConfig, CompilationLevel, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
get_tp_group)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
Qwen3MoeDecoderLayer,
Qwen3MoeForCausalLM,
Qwen3MoeMLP, Qwen3MoeModel)
Qwen3MoeMLP, Qwen3MoeModel,
Qwen3MoeSparseMoeBlock)
from vllm.model_executor.models.utils import (
extract_layer_index, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)

from vllm_ascend.ops.fused_moe import AscendSparseMoeBlock
from vllm_ascend.platform import VllmConfig
from vllm_ascend.ops.fused_moe import AscendFusedMoE


class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):

def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
nn.Module.__init__(self)
self.tp_size = get_tensor_model_parallel_world_size()
if self.tp_size > config.num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.num_experts}.")

self.gate = ReplicatedLinear(
config.hidden_size,
config.num_experts,
bias=False,
quant_config=None,
prefix=f"{prefix}.gate",
)

self.experts = AscendFusedMoE(
num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
)

self.top_k = config.num_experts_per_tok

self.dp_size = get_dp_group().world_size

self.tp_group = get_tp_group().device_group
self.tp_rank = get_tp_group().rank_in_group
self.ep_group = get_ep_group()

self.params_dtype = torch.get_default_dtype()

def forward(
self,
hidden_states,
attn_metadata=None,
):
if attn_metadata is None:
attn_metadata = get_forward_context().attn_metadata
# when profile runs, force experts to load balanced tokens
# to avoid high memory consumption on a single rank.
enable_force_load_balance = get_forward_context().in_profile_run
is_prefill = get_forward_context().with_prefill

# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)

hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
is_prefill=is_prefill,
top_k=self.top_k,
enable_force_load_balance=enable_force_load_balance,
shared_experts=None,
)

return hidden_states


class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
@@ -45,6 +122,7 @@ def __init__(
config: PretrainedConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
vllm_config: Optional[VllmConfig] = None,
prefix: str = "",
) -> None:

@@ -73,12 +151,22 @@ def __init__(
layer_idx = extract_layer_index(prefix)
mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
config.mlp_only_layers)
use_aclgraph = (vllm_config is not None
and vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE
and not vllm_config.model_config.enforce_eager)
if (layer_idx not in mlp_only_layers) and (
config.num_experts > 0 and
(layer_idx + 1) % config.decoder_sparse_step == 0):
self.mlp = AscendSparseMoeBlock(config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
if not use_aclgraph:
# FIXME: custom sparse moe block doesn't work with aclgraph.
self.mlp = CustomSparseMoeBlock(config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
else:
self.mlp = Qwen3MoeSparseMoeBlock(config=config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
else:
self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
@@ -115,6 +203,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
config=config,
cache_config=cache_config,
quant_config=quant_config,
vllm_config=vllm_config,
prefix=prefix),
prefix=f"{prefix}.layers",
)
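
For readers tracing the new selection logic above, here is a minimal standalone sketch of the same ACL-graph check; the helper name is hypothetical and only mirrors the condition used in CustomQwen3MoeDecoderLayer:

from typing import Optional

from vllm.config import CompilationLevel, VllmConfig


def aclgraph_enabled(vllm_config: Optional[VllmConfig]) -> bool:
    # Piecewise compilation with enforce_eager disabled means ACL graph
    # capture is active, so the decoder layer falls back to the upstream
    # Qwen3MoeSparseMoeBlock instead of CustomSparseMoeBlock.
    return (vllm_config is not None
            and vllm_config.compilation_config.level == CompilationLevel.PIECEWISE
            and not vllm_config.model_config.enforce_eager)
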
79 changes: 0 additions & 79 deletions vllm_ascend/ops/fused_moe.py
@@ -22,8 +22,6 @@
import torch.distributed as dist
import torch_npu
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import get_current_vllm_config
from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
@@ -37,7 +35,6 @@
FusedMoEParallelConfig # isort: skip
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization.base_config import \
QuantizationConfig

@@ -1546,79 +1543,3 @@ def _forward_ms_fused_moe_comp(
)

return hidden_states


class AscendSparseMoeBlock(nn.Module):

def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
if self.tp_size > config.num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.num_experts}.")

ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_multistream_moe = (
ascend_config.torchair_graph_config.enable_multistream_moe)

self.gate = ReplicatedLinear(
config.hidden_size,
config.num_experts,
bias=False,
quant_config=None,
prefix=f"{prefix}.gate",
)

self.experts = AscendFusedMoE(
num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
prefix=f"{prefix}.experts",
)

self.top_k = config.num_experts_per_tok

self.dp_size = get_dp_group().world_size

self.tp_group = get_tp_group().device_group
self.tp_rank = get_tp_group().rank_in_group
self.ep_group = get_ep_group()

self.params_dtype = torch.get_default_dtype()

def forward(
self,
hidden_states: torch.Tensor,
attn_metadata: Optional[AttentionMetadata] = None,
) -> torch.Tensor:
if attn_metadata is None:
attn_metadata = get_forward_context().attn_metadata
# when profile runs, force experts to load balanced tokens
# to avoid high memory consumption on a single rank.
enable_force_load_balance = get_forward_context().in_profile_run
is_prefill = get_forward_context().with_prefill

# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)

hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
is_prefill=is_prefill,
top_k=self.top_k,
enable_force_load_balance=enable_force_load_balance,
shared_experts=None,
)

return hidden_states