from collections.abc import Iterable
from typing import Optional, Union

import torch
from torch import nn
from transformers import Qwen3Config

from vllm.attention import AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.qwen2 import Qwen2MLP as Qwen3MLP
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.models.qwen3 import Qwen3Attention
from vllm.model_executor.models.utils import (AutoWeightsLoader,
                                              PPMissingLayer, maybe_prefix)

from vllm_ascend.ops.layernorm import AddRMSNormQuant


class CustomQwen3DecoderLayer(nn.Module):
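    """Qwen3 decoder layer for Ascend NPUs.

    Mirrors the upstream vLLM Qwen3 decoder layer (reusing Qwen3Attention and
    Qwen3MLP), but replaces the two RMSNorm layers with AddRMSNormQuant when
    an Ascend quantization config is supplied (see __init__ below).
    """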

    def __init__(
        self,
        config: Qwen3Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 1000000)
        rope_scaling = getattr(config, "rope_scaling", None)

        # By default, Qwen3 uses causal attention as it is a decoder-only
        # model. You can override the HF config with `is_causal=False` to
        # enable bidirectional attention, which is used in some embedding
        # models (e.g. Alibaba-NLP/gte-Qwen3-7B-instruct).
        if getattr(config, "is_causal", True):
            attn_type = AttentionType.DECODER
        else:
            attn_type = AttentionType.ENCODER_ONLY

        self.self_attn = Qwen3Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rms_norm_eps=config.rms_norm_eps,
            qkv_bias=getattr(config, 'attention_bias', False),
            head_dim=getattr(config, 'head_dim', None),
            cache_config=cache_config,
            quant_config=quant_config,
            rope_scaling=rope_scaling,
            prefix=f"{prefix}.self_attn",
            attn_type=attn_type,
        )
        self.mlp = Qwen3MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
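        # NOTE: AddRMSNormQuant is assumed to take the downstream quantized
        # linear layer (qkv_proj / gate_up_proj) so that the residual add,
        # the RMSNorm and the input quantization for that layer can be fused
        # on the NPU; plain RMSNorm is kept for the unquantized path.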
        if quant_config is None:
            self.input_layernorm = RMSNorm(config.hidden_size,
                                           eps=config.rms_norm_eps)
            self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                    eps=config.rms_norm_eps)
        else:
            from vllm_ascend.quantization.quant_config import AscendQuantConfig
            assert isinstance(quant_config, AscendQuantConfig)
            self.input_layernorm = AddRMSNormQuant(
                config.hidden_size,
                layer=self.self_attn.qkv_proj,
                eps=config.rms_norm_eps)
            self.post_attention_layernorm = AddRMSNormQuant(
                config.hidden_size,
                layer=self.mlp.gate_up_proj,
                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


ALL_DECODER_LAYER_TYPES = {
    "attention": CustomQwen3DecoderLayer,
}


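# `support_torch_compile` hooks this module into vLLM's torch.compile
# integration; `dynamic_arg_dims` marks which dimension of each input is
# dynamic so the compiled graph can be reused across batch/sequence lengths.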
@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
        # otherwise (seq_len, ).
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    })
class CustomQwen3Model(Qwen2Model):
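    """Qwen2Model with the decoder layer type swapped to
    CustomQwen3DecoderLayer so the Ascend-specific layer is built."""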

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         decoder_layer_type=CustomQwen3DecoderLayer)


class CustomQwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
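    """Causal LM head on top of CustomQwen3Model, with LoRA and pipeline
    parallelism support via the SupportsLoRA / SupportsPP interfaces."""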
    # Builds self.model from `CustomQwen3Model` (see __init__ below).
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config
        self.model = CustomQwen3Model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))

        if get_pp_group().is_last_rank:
            if config.tie_word_embeddings:
                self.lm_head = self.model.embed_tokens
            else:
                self.lm_head = ParallelLMHead(config.vocab_size,
                                              config.hidden_size,
                                              quant_config=quant_config,
                                              prefix=maybe_prefix(
                                                  prefix, "lm_head"))
        else:
            self.lm_head = PPMissingLayer()

        self.logits_processor = LogitsProcessor(config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)
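

# Usage sketch (assumption: the module path below matches where this file
# lives in vllm-ascend). Out-of-tree model classes like this one are normally
# registered with vLLM's ModelRegistry so they override the built-in Qwen3
# implementation, e.g.:
#
#     from vllm import ModelRegistry
#
#     ModelRegistry.register_model(
#         "Qwen3ForCausalLM",
#         "vllm_ascend.models.qwen3:CustomQwen3ForCausalLM")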