 # This file is a part of the vllm-ascend project.
 from typing import Optional
 
+import torch
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig
+from vllm.config import CacheConfig, CompilationLevel, VllmConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
+                                             get_tp_group)
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
                                                   Qwen3MoeDecoderLayer,
                                                   Qwen3MoeForCausalLM,
-                                                  Qwen3MoeMLP, Qwen3MoeModel)
+                                                  Qwen3MoeMLP, Qwen3MoeModel,
+                                                  Qwen3MoeSparseMoeBlock)
 from vllm.model_executor.models.utils import (
     extract_layer_index, make_empty_intermediate_tensors_factory, make_layers,
     maybe_prefix)
 
-from vllm_ascend.ops.fused_moe import AscendSparseMoeBlock
-from vllm_ascend.platform import VllmConfig
+from vllm_ascend.ops.fused_moe import AscendFusedMoE
+
+
+class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        nn.Module.__init__(self)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        if self.tp_size > config.num_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {config.num_experts}.")
+
+        self.gate = ReplicatedLinear(
+            config.hidden_size,
+            config.num_experts,
+            bias=False,
+            quant_config=None,
+            prefix=f"{prefix}.gate",
+        )
+
+        self.experts = AscendFusedMoE(
+            num_experts=config.num_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+            prefix=f"{prefix}.experts",
+        )
+
+        self.top_k = config.num_experts_per_tok
+
+        self.dp_size = get_dp_group().world_size
+
+        self.tp_group = get_tp_group().device_group
+        self.tp_rank = get_tp_group().rank_in_group
+        self.ep_group = get_ep_group()
+
+        self.params_dtype = torch.get_default_dtype()
+
+    def forward(
+        self,
+        hidden_states,
+        attn_metadata=None,
+    ):
+        if attn_metadata is None:
+            attn_metadata = get_forward_context().attn_metadata
+        # when profile runs, force experts to load balanced tokens
+        # to avoid high memory consumption on a single rank.
+        enable_force_load_balance = get_forward_context().in_profile_run
+        is_prefill = get_forward_context().with_prefill
+
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+
+        hidden_states = self.experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            is_prefill=is_prefill,
+            top_k=self.top_k,
+            enable_force_load_balance=enable_force_load_balance,
+            shared_experts=None,
+        )
+
+        return hidden_states
 
 
 class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
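For reference, the routing that `top_k=config.num_experts_per_tok` and `renormalize=config.norm_topk_prob` configure amounts to a top-k softmax over the gate logits. The sketch below is only an illustration of that math in plain PyTorch; it is not the AscendFusedMoE kernel used by this diff, and the function name and dense tensor layout are made up for the example.

# Illustrative sketch only: reference top-k routing with optional
# renormalization, NOT the fused Ascend implementation from this commit.
import torch


def reference_topk_routing(router_logits: torch.Tensor,
                           top_k: int,
                           renormalize: bool = True):
    # Softmax over the expert dimension, then keep the top_k experts per token.
    probs = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = probs.topk(top_k, dim=-1)
    if renormalize:
        # Re-scale the kept weights so they sum to 1 for each token.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids


if __name__ == "__main__":
    logits = torch.randn(4, 8)   # (num_tokens, n_experts)
    weights, ids = reference_topk_routing(logits, top_k=2)
    print(weights.sum(dim=-1))   # ~1.0 per token when renormalize=True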
@@ -45,6 +122,7 @@ def __init__(
         config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        vllm_config: Optional[VllmConfig] = None,
         prefix: str = "",
     ) -> None:
 
@@ -73,12 +151,22 @@ def __init__(
         layer_idx = extract_layer_index(prefix)
         mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
                            config.mlp_only_layers)
+        use_aclgraph = (vllm_config is not None
+                        and vllm_config.compilation_config.level
+                        == CompilationLevel.PIECEWISE
+                        and not vllm_config.model_config.enforce_eager)
         if (layer_idx not in mlp_only_layers) and (
                 config.num_experts > 0 and
             (layer_idx + 1) % config.decoder_sparse_step == 0):
-            self.mlp = AscendSparseMoeBlock(config=config,
-                                            quant_config=quant_config,
-                                            prefix=f"{prefix}.mlp")
+            if not use_aclgraph:
+                # FIXME: custom sparse moe block doesn't work with aclgraph.
+                self.mlp = CustomSparseMoeBlock(config=config,
+                                                quant_config=quant_config,
+                                                prefix=f"{prefix}.mlp")
+            else:
+                self.mlp = Qwen3MoeSparseMoeBlock(config=config,
+                                                  quant_config=quant_config,
+                                                  prefix=f"{prefix}.mlp")
         else:
             self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
                                    intermediate_size=config.intermediate_size,
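The new `use_aclgraph` check above can be read in isolation as a small predicate. The helper below is a hypothetical restatement of the same condition, not part of this commit, using only types this file already imports.

# Hypothetical helper (not in the diff): mirrors the constructor's check for
# deciding whether the upstream, aclgraph-compatible MoE block must be used.
from typing import Optional

from vllm.config import CompilationLevel, VllmConfig


def aclgraph_enabled(vllm_config: Optional[VllmConfig]) -> bool:
    # ACL graph capture applies only under piecewise compilation and when
    # eager execution is not forced; otherwise CustomSparseMoeBlock is safe.
    return (vllm_config is not None
            and vllm_config.compilation_config.level == CompilationLevel.PIECEWISE
            and not vllm_config.model_config.enforce_eager)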
@@ -115,6 +203,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                 config=config,
                 cache_config=cache_config,
                 quant_config=quant_config,
+                vllm_config=vllm_config,
                 prefix=prefix),
             prefix=f"{prefix}.layers",
         )
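A minimal usage sketch, assuming an Ascend NPU with vllm-ascend installed: forcing eager mode keeps the compilation level away from PIECEWISE, so the decoder layers take the CustomSparseMoeBlock path. The checkpoint name is only an example of a Qwen3 MoE model.

# Usage sketch (assumptions noted above); standard vLLM offline API.
from vllm import LLM, SamplingParams

# enforce_eager=True disables the aclgraph/PIECEWISE path checked in the
# decoder layer constructor above.
llm = LLM(model="Qwen/Qwen3-30B-A3B", enforce_eager=True)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)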