Skip to content

Conversation

@shen-shanshan
Copy link
Collaborator

@shen-shanshan shen-shanshan commented May 21, 2025

What this PR does / why we need it?

Add support for Qwen3-MoE model.

Does this PR introduce any user-facing change?

no.

How was this patch tested?

TODO:

  • vllm-ascend offline inference test
  • vllm-ascend online inference test
  • Accuracy test @hfadzxy

@shen-shanshan
Copy link
Collaborator Author

Comparing with main branch:

diff --git a/vllm_ascend/models/qwen3_moe.py b/vllm_ascend/models/qwen3_moe.py
index aae5401..b95c3d9 100644
--- a/vllm_ascend/models/qwen3_moe.py
+++ b/vllm_ascend/models/qwen3_moe.py
@@ -1,13 +1,9 @@
-# SPDX-License-Identifier: Apache-2.0
-
-# Copyright 2024 The Qwen team.
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
 # Copyright 2023 The vLLM team.
-# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+# Copyright 2024 The Qwen team.
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team.
 #
-# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
-# and OPT implementations in this library. It has been modified from its
-# original forms to accommodate minor architectural differences compared
-# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,18 +16,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Inference-only Qwen3MoE model compatible with HuggingFace weights."""
+# Adapted from vllm/model_executor/models/qwen3_moe.py
+# This file is a part of the vllm-ascend project.
+
 from collections.abc import Iterable
-from typing import Any, Optional, Union
+from typing import Any, List, Optional, Union
 
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-
-from vllm.attention import Attention
+from vllm.attention import Attention, AttentionMetadata
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import (get_pp_group,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_reduce)
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -43,18 +42,17 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.models.interfaces import SupportsPP
+from vllm.model_executor.models.utils import (
+    extract_layer_index, is_pp_missing_parameter,
+    make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import SupportsPP
-from .utils import (AutoWeightsLoader, extract_layer_index,
-                    is_pp_missing_parameter,
-                    make_empty_intermediate_tensors_factory, make_layers,
-                    maybe_prefix)
-
 logger = init_logger(__name__)
 
 
@@ -99,7 +97,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
     ):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
@@ -115,14 +112,12 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                                 intermediate_size=config.moe_intermediate_size,
                                 reduce_results=False,
                                 renormalize=config.norm_topk_prob,
-                                quant_config=quant_config,
-                                prefix=f"{prefix}.experts")
+                                quant_config=quant_config)
 
         self.gate = ReplicatedLinear(config.hidden_size,
                                      config.num_experts,
                                      bias=False,
-                                     quant_config=None,
-                                     prefix=f"{prefix}.gate")
+                                     quant_config=None)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # NOTE: hidden_states can have either 1D or 2D shape.
@@ -136,7 +131,7 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
                                            router_logits=router_logits)
         final_hidden_states = final_hidden_states
         if self.tp_size > 1:
-            final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel(  # noqa E501
+            final_hidden_states = tensor_model_parallel_all_reduce(
                 final_hidden_states)
 
         return final_hidden_states.view(orig_shape)
@@ -218,6 +213,8 @@ class Qwen3MoeAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -232,7 +229,7 @@ class Qwen3MoeAttention(nn.Module):
         k_by_head = self.k_norm(k_by_head)
         k = k_by_head.view(k.shape)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v)
+        attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
         output, _ = self.o_proj(attn_output)
         return output
 
@@ -275,8 +272,7 @@ class Qwen3MoeDecoderLayer(nn.Module):
                 config.num_experts > 0 and
             (layer_idx + 1) % config.decoder_sparse_step == 0):
             self.mlp = Qwen3MoeSparseMoeBlock(config=config,
-                                              quant_config=quant_config,
-                                              prefix=f"{prefix}.mlp")
+                                              quant_config=quant_config)
         else:
             self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
                                    intermediate_size=config.intermediate_size,
@@ -292,6 +288,8 @@ class Qwen3MoeDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
         # Self Attention
@@ -304,6 +302,8 @@ class Qwen3MoeDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
+            kv_cache=kv_cache,
+            attn_metadata=attn_metadata,
         )
 
         # Fully Connected
@@ -325,7 +325,7 @@ class Qwen3MoeModel(nn.Module):
 
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
-        self.config = config
+        # self.config = config
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
@@ -350,6 +350,8 @@ class Qwen3MoeModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
         intermediate_tensors: Optional[IntermediateTensors] = None,
         inputs_embeds: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, IntermediateTensors]:
@@ -365,7 +367,9 @@ class Qwen3MoeModel(nn.Module):
             residual = intermediate_tensors["residual"]
         for i in range(self.start_layer, self.end_layer):
             layer = self.layers[i]
-            hidden_states, residual = layer(positions, hidden_states, residual)
+            hidden_states, residual = layer(positions, hidden_states,
+                                            kv_caches[i - self.start_layer],
+                                            attn_metadata, residual)
         if not get_pp_group().is_last_rank:
             return IntermediateTensors({
                 "hidden_states": hidden_states,
@@ -374,6 +378,74 @@ class Qwen3MoeModel(nn.Module):
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
+
+class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    fall_back_to_pt_during_load = False
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        super().__init__()
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        self.config = config
+        self.quant_config = quant_config
+        self.model = Qwen3MoeModel(vllm_config=vllm_config,
+                                   prefix=maybe_prefix(prefix, "model"))
+        self.lm_head = ParallelLMHead(config.vocab_size,
+                                      config.hidden_size,
+                                      quant_config=quant_config)
+        if self.config.tie_word_embeddings:
+            self.lm_head.weight = self.model.embed_tokens.weight
+        self.logits_processor = LogitsProcessor(config.vocab_size)
+        self.sampler = get_sampler()
+        self.make_empty_intermediate_tensors = (
+            self.model.make_empty_intermediate_tensors)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
+        return hidden_states
+
+    def compute_logits(
+        self,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[torch.Tensor]:
+        logits = self.logits_processor(self.lm_head, hidden_states,
+                                       sampling_metadata)
+        return logits
+
+    def sample(
+        self,
+        logits: Optional[torch.Tensor],
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
+
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
         stacked_params_mapping = [
@@ -471,67 +543,3 @@ class Qwen3MoeModel(nn.Module):
                     weight_loader(param, loaded_weight)
             loaded_params.add(name)
         return loaded_params
-
-
-class Qwen3MoeForCausalLM(nn.Module, SupportsPP):
-    packed_modules_mapping = {
-        "qkv_proj": [
-            "q_proj",
-            "k_proj",
-            "v_proj",
-        ],
-        "gate_up_proj": [
-            "gate_proj",
-            "up_proj",
-        ],
-    }
-
-    fall_back_to_pt_during_load = False
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        super().__init__()
-        config = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
-        self.config = config
-        self.quant_config = quant_config
-        self.model = Qwen3MoeModel(vllm_config=vllm_config,
-                                   prefix=maybe_prefix(prefix, "model"))
-        self.lm_head = ParallelLMHead(config.vocab_size,
-                                      config.hidden_size,
-                                      quant_config=quant_config)
-        if self.config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
-        self.logits_processor = LogitsProcessor(config.vocab_size)
-        self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
-
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
-        return self.model.get_input_embeddings(input_ids)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, IntermediateTensors]:
-        hidden_states = self.model(input_ids, positions, intermediate_tensors,
-                                   inputs_embeds)
-        return hidden_states
-
-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata: SamplingMetadata,
-    ) -> Optional[torch.Tensor]:
-        logits = self.logits_processor(self.lm_head, hidden_states,
-                                       sampling_metadata)
-        return logits
-
-    def load_weights(self, weights: Iterable[tuple[str,
-                                                   torch.Tensor]]) -> set[str]:
-        loader = AutoWeightsLoader(
-            self,
-            skip_prefixes=(["rotary_emb.inv_freq"]),
-        )
-        return loader.load_weights(weights)

Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: shen-shanshan <467638484@qq.com>
Copy link
Collaborator

@wangxiyuan wangxiyuan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice work.

@wangxiyuan wangxiyuan added the ready read for review label May 22, 2025
@wangxiyuan
Copy link
Collaborator

Please doubel check that deepseek works as expect as well.

Signed-off-by: shen-shanshan <467638484@qq.com>
@zhangxinyuehfad
Copy link
Contributor

  • Accuracy test @hfadzxy

The accuracy of Qwen/Qwen3-30B-A3B with gsm8k datasets is low:

Tasks Version Filter n-shot Metric Value Stderr
gsm8k 3 flexible-extract 5 exact_match 0.2168 ± 0.0114
strict-match 5 exact_match 0.0144 ± 0.0033

Signed-off-by: shen-shanshan <467638484@qq.com>
Signed-off-by: shen-shanshan <467638484@qq.com>
@wangxiyuan wangxiyuan merged commit 3ce6de4 into vllm-project:v0.7.3-dev May 23, 2025
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

module:ops ready read for review

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants