 # limitations under the License.
 # Adapted from vllm/model_executor/models/qwen3_moe.py
 # This file is a part of the vllm-ascend project.
-from typing import Optional
 
-from torch import nn
-from transformers import PretrainedConfig
-from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig
-from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
-                                                  Qwen3MoeDecoderLayer,
-                                                  Qwen3MoeForCausalLM,
-                                                  Qwen3MoeMLP, Qwen3MoeModel)
-from vllm.model_executor.models.utils import (
-    extract_layer_index, make_empty_intermediate_tensors_factory, make_layers,
-    maybe_prefix)
-
-from vllm_ascend.ops.fused_moe import AscendSparseMoeBlock
-from vllm_ascend.platform import VllmConfig
-
-
-class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
-
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
-
-        nn.Module.__init__(self)
-        self.hidden_size = config.hidden_size
-        rope_theta = getattr(config, "rope_theta", 10000)
-        rope_scaling = getattr(config, "rope_scaling", None)
-        max_position_embeddings = getattr(config, "max_position_embeddings",
-                                          8192)
-        self.self_attn = Qwen3MoeAttention(
-            hidden_size=self.hidden_size,
-            num_heads=config.num_attention_heads,
-            num_kv_heads=config.num_key_value_heads,
-            rope_theta=rope_theta,
-            rope_scaling=rope_scaling,
-            max_position_embeddings=max_position_embeddings,
-            rms_norm_eps=config.rms_norm_eps,
-            qkv_bias=getattr(config, 'attention_bias', False),
-            head_dim=getattr(config, 'head_dim', None),
-            cache_config=cache_config,
-            quant_config=quant_config,
-            prefix=f"{prefix}.self_attn",
-        )
-
-        # `mlp_only_layers` in the config.
-        layer_idx = extract_layer_index(prefix)
-        mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
-                           config.mlp_only_layers)
-        if (layer_idx not in mlp_only_layers) and (
-                config.num_experts > 0 and
-            (layer_idx + 1) % config.decoder_sparse_step == 0):
-            self.mlp = AscendSparseMoeBlock(config=config,
-                                            quant_config=quant_config,
-                                            prefix=f"{prefix}.mlp")
-        else:
-            self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
-                                   intermediate_size=config.intermediate_size,
-                                   hidden_act=config.hidden_act,
-                                   quant_config=quant_config,
-                                   prefix=f"{prefix}.mlp")
-        self.input_layernorm = RMSNorm(config.hidden_size,
-                                       eps=config.rms_norm_eps)
-        self.post_attention_layernorm = RMSNorm(config.hidden_size,
-                                                eps=config.rms_norm_eps)
-
-
-@support_torch_compile
-class CustomQwen3MoeModel(Qwen3MoeModel):
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        nn.Module.__init__(self)
-        config = vllm_config.model_config.hf_config
-        cache_config = vllm_config.cache_config
-        quant_config = vllm_config.quant_config
-
-        self.padding_idx = config.pad_token_id
-        self.vocab_size = config.vocab_size
-        self.config = config
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-            prefix=f"{prefix}.embed_tokens")
-        self.start_layer, self.end_layer, self.layers = make_layers(
-            config.num_hidden_layers,
-            lambda prefix: CustomQwen3MoeDecoderLayer(
-                config=config,
-                cache_config=cache_config,
-                quant_config=quant_config,
-                prefix=prefix),
-            prefix=f"{prefix}.layers",
-        )
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.make_empty_intermediate_tensors = (
-            make_empty_intermediate_tensors_factory(
-                ["hidden_states", "residual"], config.hidden_size))
+from vllm.model_executor.models.qwen3_moe import Qwen3MoeForCausalLM
 
 
 class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
@@ -136,20 +33,3 @@ class CustomQwen3MoeForCausalLM(Qwen3MoeForCausalLM):
13633 "experts" :
13734 ["experts.0.gate_proj" , "experts.0.up_proj" , "experts.0.down_proj" ],
13835 }
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        nn.Module.__init__(self)
-        config = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
-        self.config = config
-        self.quant_config = quant_config
-        self.model = CustomQwen3MoeModel(vllm_config=vllm_config,
-                                         prefix=maybe_prefix(prefix, "model"))
-        self.lm_head = ParallelLMHead(config.vocab_size,
-                                      config.hidden_size,
-                                      quant_config=quant_config)
-        if self.config.tie_word_embeddings:
-            self.lm_head.weight = self.model.embed_tokens.weight
-        self.logits_processor = LogitsProcessor(config.vocab_size)
-        self.make_empty_intermediate_tensors = (
-            self.model.make_empty_intermediate_tensors)
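
With the custom decoder layer, model, and constructor removed, CustomQwen3MoeForCausalLM stays a thin subclass of the upstream Qwen3MoeForCausalLM that only overrides packed_modules_mapping. As a usage note, the sketch below shows how an out-of-tree class like this is typically wired into vLLM through its ModelRegistry; it is not part of this diff, and the module path is an assumption for illustration.

# Hypothetical sketch (not from this diff): re-register the HF architecture
# name so vLLM instantiates the Ascend-specific subclass instead of the
# upstream implementation. The module path below is assumed.
from vllm import ModelRegistry


def register_model():
    ModelRegistry.register_model(
        "Qwen3MoeForCausalLM",
        "vllm_ascend.models.qwen3_moe:CustomQwen3MoeForCausalLM")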