 # This file is a part of the vllm-ascend project.
 from typing import Optional
 
+import torch
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import CacheConfig
+from vllm.config import CacheConfig, CompilationLevel, VllmConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
+                                             get_tp_group)
+from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.models.qwen3_moe import (Qwen3MoeAttention,
                                                   Qwen3MoeDecoderLayer,
                                                   Qwen3MoeForCausalLM,
-                                                  Qwen3MoeMLP, Qwen3MoeModel)
+                                                  Qwen3MoeMLP, Qwen3MoeModel,
+                                                  Qwen3MoeSparseMoeBlock)
 from vllm.model_executor.models.utils import (
     extract_layer_index, make_empty_intermediate_tensors_factory, make_layers,
     maybe_prefix)
 
-from vllm_ascend.ops.fused_moe import AscendSparseMoeBlock
-from vllm_ascend.platform import VllmConfig
+from vllm_ascend.ops.fused_moe import AscendFusedMoE
+
+
+class CustomSparseMoeBlock(Qwen3MoeSparseMoeBlock):
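+    """Qwen3 MoE sparse block that routes tokens through AscendFusedMoE
+    rather than the upstream Qwen3MoeSparseMoeBlock experts implementation."""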
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        nn.Module.__init__(self)
+        self.tp_size = get_tensor_model_parallel_world_size()
+        if self.tp_size > config.num_experts:
+            raise ValueError(
+                f"Tensor parallel size {self.tp_size} is greater than "
+                f"the number of experts {config.num_experts}.")
+
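+        # Router gate: a replicated (non-tensor-parallel) linear projection
+        # that yields one logit per expert for every token.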
+        self.gate = ReplicatedLinear(
+            config.hidden_size,
+            config.num_experts,
+            bias=False,
+            quant_config=None,
+            prefix=f"{prefix}.gate",
+        )
+
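+        # Expert FFNs run through the Ascend fused MoE op; with
+        # reduce_results=False no all-reduce is performed inside this block.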
+        self.experts = AscendFusedMoE(
+            num_experts=config.num_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.moe_intermediate_size,
+            reduce_results=False,
+            renormalize=config.norm_topk_prob,
+            quant_config=quant_config,
+            prefix=f"{prefix}.experts",
+        )
+
+        self.top_k = config.num_experts_per_tok
+
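+        # Record the data-, tensor- and expert-parallel group handles.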
+        self.dp_size = get_dp_group().world_size
+
+        self.tp_group = get_tp_group().device_group
+        self.tp_rank = get_tp_group().rank_in_group
+        self.ep_group = get_ep_group()
+
+        self.params_dtype = torch.get_default_dtype()
+
+    def forward(
+        self,
+        hidden_states,
+        attn_metadata=None,
+    ):
+        if attn_metadata is None:
+            attn_metadata = get_forward_context().attn_metadata
+        # During the profile run, force a load-balanced token distribution
+        # across experts to avoid high memory consumption on a single rank.
+        enable_force_load_balance = get_forward_context().in_profile_run
+        is_prefill = get_forward_context().with_prefill
+
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+
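+        # Dispatch tokens to their top-k experts via the fused MoE op.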
+        hidden_states = self.experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            is_prefill=is_prefill,
+            top_k=self.top_k,
+            enable_force_load_balance=enable_force_load_balance,
+            shared_experts=None,
+        )
+
+        return hidden_states
 
 
 class CustomQwen3MoeDecoderLayer(Qwen3MoeDecoderLayer):
@@ -45,6 +122,7 @@ def __init__(
         config: PretrainedConfig,
         cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
+        vllm_config: Optional[VllmConfig] = None,
         prefix: str = "",
     ) -> None:
 
@@ -73,12 +151,22 @@ def __init__(
         layer_idx = extract_layer_index(prefix)
         mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
                            config.mlp_only_layers)
+        use_aclgraph = (vllm_config is not None
+                        and vllm_config.compilation_config.level
+                        == CompilationLevel.PIECEWISE
+                        and not vllm_config.model_config.enforce_eager)
         if (layer_idx not in mlp_only_layers) and (
                 config.num_experts > 0 and
                 (layer_idx + 1) % config.decoder_sparse_step == 0):
-            self.mlp = AscendSparseMoeBlock(config=config,
-                                            quant_config=quant_config,
-                                            prefix=f"{prefix}.mlp")
+            if not use_aclgraph:
+                # FIXME: custom sparse moe block doesn't work with aclgraph.
+                self.mlp = CustomSparseMoeBlock(config=config,
+                                                quant_config=quant_config,
+                                                prefix=f"{prefix}.mlp")
+            else:
+                self.mlp = Qwen3MoeSparseMoeBlock(config=config,
+                                                  quant_config=quant_config,
+                                                  prefix=f"{prefix}.mlp")
         else:
             self.mlp = Qwen3MoeMLP(hidden_size=config.hidden_size,
                                    intermediate_size=config.intermediate_size,
@@ -115,6 +203,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
                 config=config,
                 cache_config=cache_config,
                 quant_config=quant_config,
+                vllm_config=vllm_config,
                 prefix=prefix),
             prefix=f"{prefix}.layers",
         )