from collections.abc import Iterable
from typing import Optional, Union

import torch
from torch import nn
from transformers import Qwen3Config

from vllm.attention import Attention, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from vllm.model_executor.models.interfaces import SupportsLoRA, SupportsPP
from vllm.model_executor.models.qwen2 import Qwen2MLP as Qwen3MLP
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.models.qwen3 import Qwen3Attention, Qwen3ForCausalLM
from vllm.model_executor.models.utils import (AutoWeightsLoader,
                                              PPMissingLayer, maybe_prefix)

from vllm_ascend.ops.layernorm import AddRMSNormQuant


class CustomQwen3DecoderLayer(nn.Module):
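    """Qwen3 decoder layer with Ascend-specific normalization.

    Mirrors the stock vLLM Qwen3 decoder layer, except that when an Ascend
    quantization config is present the plain RMSNorm layers are replaced by
    AddRMSNormQuant, which combines the residual add and RMSNorm with the
    input quantization of the projection that consumes the norm output
    (qkv_proj / gate_up_proj).
    """
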
    def __init__(
        self,
        config: Qwen3Config,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size
        # Requires transformers > 4.32.0
        rope_theta = getattr(config, "rope_theta", 1000000)
        rope_scaling = getattr(config, "rope_scaling", None)

        # By default, Qwen3 uses causal attention as it is a decoder-only
        # model. You can override the HF config with `is_causal=False` to
        # enable bidirectional attention, which is used in some embedding
        # models (e.g. Alibaba-NLP/gte-Qwen3-7B-instruct).
        if getattr(config, "is_causal", True):
            attn_type = AttentionType.DECODER
        else:
            attn_type = AttentionType.ENCODER_ONLY

        self.self_attn = Qwen3Attention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            max_position=config.max_position_embeddings,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rms_norm_eps=config.rms_norm_eps,
            qkv_bias=getattr(config, 'attention_bias', False),
            head_dim=getattr(config, 'head_dim', None),
            cache_config=cache_config,
            quant_config=quant_config,
            rope_scaling=rope_scaling,
            prefix=f"{prefix}.self_attn",
            attn_type=attn_type,
        )
        self.mlp = Qwen3MLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
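        # Without quantization, use the standard RMSNorm. With an Ascend
        # quantization config, use AddRMSNormQuant so the norm output is
        # quantized with the input scale/offset of the projection that
        # consumes it.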
        if quant_config is None:
            self.input_layernorm = RMSNorm(config.hidden_size,
                                           eps=config.rms_norm_eps)
            self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                    eps=config.rms_norm_eps)
        else:
            from vllm_ascend.quantization.quant_config import AscendQuantConfig
            assert isinstance(quant_config, AscendQuantConfig)
            self.input_layernorm = AddRMSNormQuant(
                config.hidden_size,
                self.self_attn.qkv_proj.aclnn_input_scale,
                self.self_attn.qkv_proj.aclnn_input_offset,
                eps=config.rms_norm_eps)
            self.post_attention_layernorm = AddRMSNormQuant(
                config.hidden_size,
                self.mlp.gate_up_proj.aclnn_input_scale,
                self.mlp.gate_up_proj.aclnn_input_offset,
                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )
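        # MLP block: fused residual add + post-attention norm, then the MLP.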
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


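# Maps the layer-type name to the decoder layer class defined in this module.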
ALL_DECODER_LAYER_TYPES = {
    "attention": CustomQwen3DecoderLayer,
}


@support_torch_compile(
    dynamic_arg_dims={
        "input_ids": 0,
        # positions is of shape (3, seq_len) if mrope is enabled for qwen2-vl,
        # otherwise (seq_len, ).
        "positions": -1,
        "intermediate_tensors": 0,
        "inputs_embeds": 0,
    })
class CustomQwen3Model(Qwen2Model):
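    """Qwen2Model variant that builds its layers from CustomQwen3DecoderLayer."""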

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         decoder_layer_type=CustomQwen3DecoderLayer)


class CustomQwen3ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    # Same as vLLM's Qwen3ForCausalLM, but `self.model` is a CustomQwen3Model
    # so that the Ascend-specific decoder layers are used.
    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.lora_config = lora_config

        self.quant_config = quant_config
        self.model = CustomQwen3Model(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "model"))

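        # Only the last pipeline-parallel rank holds the LM head; it is tied
        # to the input embeddings when the config requests weight tying.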
        if get_pp_group().is_last_rank:
            if config.tie_word_embeddings:
                self.lm_head = self.model.embed_tokens
            else:
                self.lm_head = ParallelLMHead(config.vocab_size,
                                              config.hidden_size,
                                              quant_config=quant_config,
                                              prefix=maybe_prefix(
                                                  prefix, "lm_head"))
        else:
            self.lm_head = PPMissingLayer()

        self.logits_processor = LogitsProcessor(config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(
            self,
            skip_prefixes=(["lm_head."]
                           if self.config.tie_word_embeddings else None),
        )
        return loader.load_weights(weights)
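

# Usage sketch (not part of this module): registering the class so vLLM
# resolves the "Qwen3ForCausalLM" architecture to this Ascend implementation.
# The module path below is an assumption for illustration; the actual plugin
# registration in vllm_ascend may differ.
#
#     from vllm import ModelRegistry
#     ModelRegistry.register_model(
#         "Qwen3ForCausalLM",
#         "vllm_ascend.models.qwen3:CustomQwen3ForCausalLM")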