     is_rocm_aiter_fusion_shared_expert_enabled,
     is_rocm_aiter_moe_enabled,
 )
+from vllm.model_executor.layers.fused_moe.utils import (
+    collect_expert_usage_histogram)
 from vllm.model_executor.layers.fused_moe.routing_simulator import RoutingSimulator
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig,
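
Note: the newly imported collect_expert_usage_histogram is defined in vllm/model_executor/layers/fused_moe/utils.py and its body is not part of this diff. A minimal sketch of what such a helper could look like, assuming it simply accumulates per-expert selection counts into a preallocated counter tensor (the bincount-based version below is an assumption, not the actual implementation):

# Hypothetical sketch only; the real helper is not shown in this diff.
import torch


def collect_expert_usage_histogram(
    topk_ids: torch.Tensor,         # (num_tokens, top_k) selected expert ids
    usage_histogram: torch.Tensor,  # (num_experts,) running per-expert counters
) -> None:
    # Count how often each expert id was selected in this batch and
    # accumulate into the per-layer histogram slice in place.
    counts = torch.bincount(
        topk_ids.flatten().to(torch.int64),
        minlength=usage_histogram.numel(),
    )
    usage_histogram += counts.to(usage_histogram.dtype)
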
@@ -298,6 +300,7 @@ def apply(
         router_logits: torch.Tensor,
         top_k: int,
         renormalize: bool,
+        layer_index: int,
         use_grouped_topk: bool = False,
         topk_group: int | None = None,
         num_expert_group: int | None = None,
@@ -534,6 +537,7 @@ def apply(
         router_logits: torch.Tensor,
         top_k: int,
         renormalize: bool,
+        layer_index: int,
         use_grouped_topk: bool = False,
         topk_group: int | None = None,
         num_expert_group: int | None = None,
@@ -598,6 +602,7 @@ def forward_cuda(
         top_k: int,
         router_logits: torch.Tensor,
         renormalize: bool,
+        layer_index: int,
         topk_group: int | None = None,
         num_expert_group: int | None = None,
         global_num_experts: int = -1,
@@ -709,6 +714,7 @@ def forward_cpu(
         top_k: int,
         router_logits: torch.Tensor,
         renormalize: bool,
+        layer_index: int,
         topk_group: int | None = None,
         num_expert_group: int | None = None,
         global_num_experts: int = -1,
@@ -758,6 +764,7 @@ def forward_xpu(
         top_k: int,
         router_logits: torch.Tensor,
         renormalize: bool,
+        layer_index: int,
         topk_group: int | None = None,
         num_expert_group: int | None = None,
         global_num_experts: int = -1,
@@ -799,6 +806,7 @@ def forward_tpu(
         top_k: int,
         router_logits: torch.Tensor,
         renormalize: bool,
+        layer_index: int,
         topk_group: int | None = None,
         num_expert_group: int | None = None,
         global_num_experts: int = -1,
@@ -1132,6 +1140,11 @@ def __init__(
         self.logical_to_physical_map: torch.Tensor | None = None
         self.logical_replica_count: torch.Tensor | None = None

+        from vllm.model_executor.models.utils import extract_layer_index
+        self.layer_index = extract_layer_index(
+            prefix) - vllm_config.model_config.get_total_num_dense_moe_layers(
+        )
+
         # ROCm aiter shared experts fusion
         self.num_fused_shared_experts = (
             n_shared_experts
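
Note: the constructor turns the global decoder-layer index (parsed from the module prefix) into a MoE-only index by subtracting the number of leading dense layers, so expert_usage_histogram only needs one row per MoE layer. A sketch of the indexing scheme, assuming extract_layer_index behaves as described (its real definition lives in vllm/model_executor/models/utils.py and is not re-implemented by this change):

# Hypothetical illustration of the index arithmetic; names below are assumptions.
def extract_layer_index_sketch(prefix: str) -> int:
    # e.g. "model.layers.5.mlp.experts" -> 5: take the single integer component.
    ints = [int(p) for p in prefix.split(".") if p.isdigit()]
    assert len(ints) == 1, f"expected exactly one layer index in {prefix!r}"
    return ints[0]

# With e.g. one leading dense layer (DeepSeek-style), global layer 5 maps to
# MoE-layer 4, which is the row used for this layer's expert-usage counters:
# layer_index = extract_layer_index_sketch(prefix) - num_dense_moe_layers
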
@@ -1936,6 +1949,7 @@ def select_experts(
         top_k: int,
         use_grouped_topk: bool,
         renormalize: bool,
+        layer_index: int,
         topk_group: int | None = None,
         num_expert_group: int | None = None,
         custom_routing_function: Callable | None = None,
@@ -2067,6 +2081,13 @@ def select_experts(
             )
         else:
             zero_expert_result = None
+
+        expert_usage_histogram = get_forward_context().expert_usage_histogram
+
+        if expert_usage_histogram is not None:
+            collect_expert_usage_histogram(topk_ids,
+                                           expert_usage_histogram[layer_index])
+
         return topk_weights, topk_ids, zero_expert_result

     def must_reduce_shared_expert_outputs(self) -> bool:
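
Note: this diff only shows the collection side; how expert_usage_histogram is allocated and attached to the forward context is not part of these hunks. A sketch of a plausible buffer layout, assuming one int32 counter row per MoE layer (the shapes and counts below are illustrative assumptions):

# Hypothetical allocation sketch; not taken from this diff.
import torch

num_moe_layers = 58   # assumed model-dependent value
num_experts = 256     # assumed global expert count

expert_usage_histogram = torch.zeros(
    (num_moe_layers, num_experts), dtype=torch.int32)

# select_experts() would then increment expert_usage_histogram[layer_index]
# with every batch's routing decisions, giving a per-layer view of expert
# load that can be dumped for load-balancing analysis.
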
@@ -2115,23 +2136,25 @@ def forward_native(
             if current_platform.is_tpu():
                 # TODO: Once the OOM issue for the TPU backend is resolved, we
                 # will switch to using the moe_forward custom op.
-                fused_output = self.forward_impl(hidden_states, router_logits)
+                fused_output = self.forward_impl(hidden_states, router_logits,
+                                                 self.layer_index)
                 assert not isinstance(fused_output, tuple)
             else:
                 fused_output = torch.ops.vllm.moe_forward(
-                    hidden_states, router_logits, self.layer_name
+                    hidden_states, router_logits, self.layer_name,
+                    self.layer_index
                 )
             return fused_output[..., :og_hidden_states]
         else:
             if current_platform.is_tpu():
                 # TODO: Once the OOM issue for the TPU backend is resolved, we
                 # will switch to using the moe_forward custom op.
                 shared_output, fused_output = self.forward_impl(
-                    hidden_states, router_logits
+                    hidden_states, router_logits, self.layer_index
                 )
             else:
                 shared_output, fused_output = torch.ops.vllm.moe_forward_shared(
-                    hidden_states, router_logits, self.layer_name
+                    hidden_states, router_logits, self.layer_name, self.layer_index
                 )
             return (
                 shared_output[..., :og_hidden_states],
@@ -2212,6 +2235,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
                 router_logits=staged_router_logits,
                 top_k=self.top_k,
                 renormalize=self.renormalize,
+                layer_index=self.layer_index,
                 use_grouped_topk=self.use_grouped_topk,
                 global_num_experts=self.global_num_experts,
                 expert_map=self.expert_map
@@ -2297,6 +2321,7 @@ def forward_impl(
         self,
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
+        layer_index: int,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         assert self.quant_method is not None

@@ -2339,6 +2364,7 @@ def forward_impl(
             router_logits=router_logits,
             top_k=self.top_k,
             renormalize=self.renormalize,
+            layer_index=layer_index,
             use_grouped_topk=self.use_grouped_topk,
             global_num_experts=self.global_num_experts,
             expert_map=self.expert_map
@@ -2459,17 +2485,19 @@ def moe_forward(
     hidden_states: torch.Tensor,
     router_logits: torch.Tensor,
     layer_name: str,
+    layer_index: int,
 ) -> torch.Tensor:
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
     assert self.shared_experts is None
-    return self.forward_impl(hidden_states, router_logits)
+    return self.forward_impl(hidden_states, router_logits, layer_index)


 def moe_forward_fake(
     hidden_states: torch.Tensor,
     router_logits: torch.Tensor,
     layer_name: str,
+    layer_index: int,
 ) -> torch.Tensor:
     return torch.empty_like(hidden_states)

@@ -2487,17 +2515,19 @@ def moe_forward_shared(
     hidden_states: torch.Tensor,
     router_logits: torch.Tensor,
     layer_name: str,
+    layer_index: int,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     forward_context: ForwardContext = get_forward_context()
     self = forward_context.no_compile_layers[layer_name]
     assert self.shared_experts is not None
-    return self.forward_impl(hidden_states, router_logits)
+    return self.forward_impl(hidden_states, router_logits, layer_index)


 def moe_forward_shared_fake(
     hidden_states: torch.Tensor,
     router_logits: torch.Tensor,
     layer_name: str,
+    layer_index: int,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     shared_out = torch.empty_like(hidden_states)
     fused_out = torch.empty_like(hidden_states)