diff --git a/tests/e2e/multicard/test_qwen3_moe.py b/tests/e2e/multicard/test_qwen3_moe.py
new file mode 100644
index 0000000000..ccc31d4c1d
--- /dev/null
+++ b/tests/e2e/multicard/test_qwen3_moe.py
@@ -0,0 +1,55 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# This file is a part of the vllm-ascend project.
+# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
+#
+"""Compare the short outputs of HF and vLLM when using greedy sampling.
+
+Run `pytest tests/e2e/multicard/test_qwen3_moe.py`.
+"""
+
+from tests.e2e.conftest import VllmRunner
+
+
+def test_models_distributed_Qwen3_MOE_TP2():
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    dtype = "half"
+    max_tokens = 5
+    with VllmRunner(
+            "Qwen/Qwen3-30B-A3B",
+            dtype=dtype,
+            tensor_parallel_size=2,
+            distributed_executor_backend="mp",
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
+
+
+def test_models_distributed_Qwen3_MOE_TP2_WITH_EP():
+    example_prompts = [
+        "Hello, my name is",
+    ]
+    dtype = "half"
+    max_tokens = 5
+    with VllmRunner(
+            "Qwen/Qwen3-30B-A3B",
+            dtype=dtype,
+            tensor_parallel_size=2,
+            enable_expert_parallel=True,
+            distributed_executor_backend="mp",
+    ) as vllm_model:
+        vllm_model.generate_greedy(example_prompts, max_tokens)
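The two tests above run Qwen3-30B-A3B through `VllmRunner`, a thin test wrapper over vLLM's offline `LLM` API, once with plain tensor parallelism and once with expert parallelism enabled. For reference, a minimal sketch of the equivalent direct invocation is shown below; it is not part of this patch, and it assumes a host with at least two Ascend NPUs and a vLLM build that accepts these constructor arguments.

```python
# Minimal sketch (not part of this patch): the TP=2 configuration the tests
# exercise, expressed with vLLM's offline LLM API. Assumes >= 2 NPUs and that
# the installed vLLM accepts these arguments.
from vllm import LLM, SamplingParams

llm = LLM(
    model="Qwen/Qwen3-30B-A3B",
    dtype="half",
    tensor_parallel_size=2,
    # enable_expert_parallel=True,  # uncomment to mirror the *_WITH_EP test
    distributed_executor_backend="mp",
)
sampling = SamplingParams(temperature=0.0, max_tokens=5)  # greedy, 5 new tokens
print(llm.generate(["Hello, my name is"], sampling)[0].outputs[0].text)
```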
diff --git a/vllm_ascend/models/qwen3_moe.py b/vllm_ascend/models/qwen3_moe.py
index c133acc66a..f29c2a5157
--- a/vllm_ascend/models/qwen3_moe.py
+++ b/vllm_ascend/models/qwen3_moe.py
@@ -17,11 +17,17 @@ # This file is a part of the vllm-ascend project.
 
 from typing import Optional
 
+import torch
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig
+from vllm.config import CacheConfig, CompilationLevel, VllmConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
+                                             get_tp_group)
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -29,13 +35,84 @@
 from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
                                                   Qwen3MoeDecoderLayer,
                                                   Qwen3MoeForCausalLM,
-                                                  Qwen3MoeMLP, Qwen3MoeModel)
+                                                  Qwen3MoeMLP, Qwen3MoeModel,
+                                                  Qwen3MoeSparseMoeBlock)
 from vllm.model_executor.models.utils import (
     extract_layer_index, make_empty_intermediate_tensors_factory, make_layers,
     maybe_prefix)
 
-from vllm_ascend.ops.fused_moe import AscendSparseMoeBlock
-from vllm_ascend.platform import VllmConfig
+from vllm_ascend.ops.fused_moe import AscendFusedMoE
+
+
+class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        nn.Module.__init__(self)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        if self.tp_size > config.num_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {config.num_experts}.")
+
+        self.gate = ReplicatedLinear(
+            config.hidden_size,
+            config.num_experts,
+            bias=False,
+            quant_config=None,
+            prefix=f"{prefix}.gate",
+        )
+
+        self.experts = AscendFusedMoE(
+            num_experts=config.num_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+            prefix=f"{prefix}.experts",
+        )
+
+        self.top_k = config.num_experts_per_tok
+
+        self.dp_size = get_dp_group().world_size
+
+        self.tp_group = get_tp_group().device_group
+        self.tp_rank = get_tp_group().rank_in_group
+        self.ep_group = get_ep_group()
+
+        self.params_dtype = torch.get_default_dtype()
+
+    def forward(
+        self,
+        hidden_states,
+        attn_metadata=None,
+    ):
+        if attn_metadata is None:
+            attn_metadata = get_forward_context().attn_metadata
+        # when profile runs, force experts to load balanced tokens
+        # to avoid high memory consumption on a single rank.
+        enable_force_load_balance = get_forward_context().in_profile_run
+        is_prefill = get_forward_context().with_prefill
+
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+
+        hidden_states = self.experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            is_prefill=is_prefill,
+            top_k=self.top_k,
+            enable_force_load_balance=enable_force_load_balance,
+            shared_experts=None,
+        )
+
+        return hidden_states
 
 
 class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
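For orientation, `CustomSparseMoeBlock.forward` is standard top-k expert routing: the replicated gate projects every token onto one logit per expert, and `AscendFusedMoE` selects `num_experts_per_tok` experts per token, renormalizes their weights when `norm_topk_prob` is set, and combines the expert outputs. The snippet below is only a toy restatement of that routing arithmetic in plain `torch`; the shapes and the softmax-then-top-k ordering are illustrative assumptions, not the AscendFusedMoE kernel.

```python
# Toy illustration (not part of this patch) of the routing math the block
# delegates to AscendFusedMoE: gate -> router logits -> top-k -> renormalize.
import torch

num_tokens, hidden_size, num_experts, top_k = 4, 8, 16, 2
hidden_states = torch.randn(num_tokens, hidden_size)
gate_weight = torch.randn(num_experts, hidden_size)  # stand-in for self.gate

router_logits = hidden_states @ gate_weight.t()             # (num_tokens, num_experts)
routing_probs = torch.softmax(router_logits, dim=-1)
topk_probs, topk_ids = routing_probs.topk(top_k, dim=-1)    # experts chosen per token
topk_weights = topk_probs / topk_probs.sum(-1, keepdim=True)  # renormalize (norm_topk_prob)

print(topk_ids.shape, topk_weights.shape)  # torch.Size([4, 2]) torch.Size([4, 2])
```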
@@ -45,6 +122,7 @@ def __init__(
         config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        vllm_config: Optional[VllmConfig] = None,
         prefix: str = "",
     ) -> None:
 
@@ -73,12 +151,22 @@ def __init__(
         layer_idx = extract_layer_index(prefix)
         mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
                            config.mlp_only_layers)
+        use_aclgraph = (vllm_config is not None
+                        and vllm_config.compilation_config.level
+                        == CompilationLevel.PIECEWISE
+                        and not vllm_config.model_config.enforce_eager)
         if (layer_idx not in mlp_only_layers) and (
                 config.num_experts > 0 and
                 (layer_idx + 1) % config.decoder_sparse_step == 0):
-            self.mlp = AscendSparseMoeBlock(config=config,
-                                            quant_config=quant_config,
-                                            prefix=f"{prefix}.mlp")
+            if not use_aclgraph:
+                # FIXME: custom sparse moe block doesn't work with aclgraph.
+                self.mlp = CustomSparseMoeBlock(config=config,
+                                                quant_config=quant_config,
+                                                prefix=f"{prefix}.mlp")
+            else:
+                self.mlp = Qwen3MoeSparseMoeBlock(config=config,
+                                                  quant_config=quant_config,
+                                                  prefix=f"{prefix}.mlp")
         else:
             self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
                                    intermediate_size=config.intermediate_size,
@@ -115,6 +203,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                 config=config,
                 cache_config=cache_config,
                 quant_config=quant_config,
+                vllm_config=vllm_config,
                 prefix=prefix),
             prefix=f"{prefix}.layers",
         )
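The decoder layer now selects its MoE implementation from the compilation mode: when ACL graph capture is in effect (piecewise compilation with `enforce_eager` unset) it keeps the upstream `Qwen3MoeSparseMoeBlock`, otherwise it installs the Ascend-specific `CustomSparseMoeBlock`. The helper below merely restates that predicate in isolation; the function name is hypothetical and not part of the patch.

```python
# Hypothetical helper (not in this patch) restating the `use_aclgraph` check
# that CustomQwen3MoeDecoderLayer performs inline.
from typing import Optional

from vllm.config import CompilationLevel, VllmConfig


def uses_aclgraph(vllm_config: Optional[VllmConfig]) -> bool:
    """True -> keep upstream Qwen3MoeSparseMoeBlock; False -> CustomSparseMoeBlock."""
    return (vllm_config is not None
            and vllm_config.compilation_config.level == CompilationLevel.PIECEWISE
            and not vllm_config.model_config.enforce_eager)
```

In practice this means that running with `enforce_eager=True`, or with a non-piecewise compilation level, routes the MoE layers through the `AscendFusedMoE` path.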
diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py
index d4784d4feb..04d288b063
--- a/vllm_ascend/ops/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe.py
@@ -22,8 +22,6 @@
 import torch.distributed as dist
 import torch_npu
 from torch import nn
-from transformers import PretrainedConfig
-from vllm.attention import AttentionMetadata
 from vllm.config import get_current_vllm_config
 from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
@@ -37,7 +35,6 @@
     FusedMoEParallelConfig  # isort: skip
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
-from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.quantization.base_config import \
     QuantizationConfig
 
@@ -1546,79 +1543,3 @@ def _forward_ms_fused_moe_comp(
         )
 
         return hidden_states
-
-
-class AscendSparseMoeBlock(nn.Module):
-
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ):
-        super().__init__()
-        self.tp_size = get_tensor_model_parallel_world_size()
-        if self.tp_size > config.num_experts:
-            raise ValueError(
-                f"Tensor parallel size {self.tp_size} is greater than "
-                f"the number of experts {config.num_experts}.")
-
-        ascend_config = get_ascend_config()
-        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
-        self.enable_multistream_moe = (
-            ascend_config.torchair_graph_config.enable_multistream_moe)
-
-        self.gate = ReplicatedLinear(
-            config.hidden_size,
-            config.num_experts,
-            bias=False,
-            quant_config=None,
-            prefix=f"{prefix}.gate",
-        )
-
-        self.experts = AscendFusedMoE(
-            num_experts=config.num_experts,
-            top_k=config.num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
-            renormalize=config.norm_topk_prob,
-            quant_config=quant_config,
-            prefix=f"{prefix}.experts",
-        )
-
-        self.top_k = config.num_experts_per_tok
-
-        self.dp_size = get_dp_group().world_size
-
-        self.tp_group = get_tp_group().device_group
-        self.tp_rank = get_tp_group().rank_in_group
-        self.ep_group = get_ep_group()
-
-        self.params_dtype = torch.get_default_dtype()
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attn_metadata: Optional[AttentionMetadata] = None,
-    ) -> torch.Tensor:
-        if attn_metadata is None:
-            attn_metadata = get_forward_context().attn_metadata
-        # when profile runs, force experts to load balanced tokens
-        # to avoid high memory consumption on a single rank.
-        enable_force_load_balance = get_forward_context().in_profile_run
-        is_prefill = get_forward_context().with_prefill
-
-        # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
-
-        hidden_states = self.experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits,
-            is_prefill=is_prefill,
-            top_k=self.top_k,
-            enable_force_load_balance=enable_force_load_balance,
-            shared_experts=None,
-        )
-
-        return hidden_states
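Finally, note that the hunk above deletes `AscendSparseMoeBlock` from `vllm_ascend/ops/fused_moe.py`; its logic now lives in `vllm_ascend/models/qwen3_moe.py` as `CustomSparseMoeBlock` (minus the torchair-specific attributes). Any out-of-tree code that imported the removed class would need an update along these lines, assuming no further renames:

```python
# Sketch of the import migration implied by this patch (not part of the diff).
# Old location, removed by the hunk above:
#     from vllm_ascend.ops.fused_moe import AscendSparseMoeBlock
# New location and name introduced in vllm_ascend/models/qwen3_moe.py:
from vllm_ascend.models.qwen3_moe import CustomSparseMoeBlock
```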