2525from vllm .model_executor .custom_op import CustomOp
2626from vllm .model_executor .layers .fused_moe .rocm_aiter_fused_moe import (
2727 is_rocm_aiter_moe_enabled )
28+ from vllm .model_executor .layers .fused_moe .utils import (
29+ collect_expert_usage_histogram )
2830from vllm .model_executor .layers .quantization .base_config import (
2931 QuantizationConfig , QuantizeMethodBase )
32+ from vllm .model_executor .models .utils import extract_layer_index
3033from vllm .model_executor .utils import set_weight_attrs
3134from vllm .platforms import current_platform
3235from vllm .platforms .interface import CpuArchEnum
@@ -415,6 +418,7 @@ def apply(
415418 router_logits : torch .Tensor ,
416419 top_k : int ,
417420 renormalize : bool ,
421+ layer_index : int ,
418422 use_grouped_topk : bool = False ,
419423 topk_group : Optional [int ] = None ,
420424 num_expert_group : Optional [int ] = None ,
@@ -554,6 +558,7 @@ def apply(
554558 router_logits : torch .Tensor ,
555559 top_k : int ,
556560 renormalize : bool ,
561+ layer_index : int ,
557562 use_grouped_topk : bool = False ,
558563 topk_group : Optional [int ] = None ,
559564 num_expert_group : Optional [int ] = None ,
@@ -571,6 +576,7 @@ def apply(
571576 router_logits = router_logits ,
572577 top_k = top_k ,
573578 renormalize = renormalize ,
579+ layer_index = layer_index ,
574580 use_grouped_topk = use_grouped_topk ,
575581 topk_group = topk_group ,
576582 num_expert_group = num_expert_group ,
@@ -590,6 +596,7 @@ def forward_cuda(
590596 top_k : int ,
591597 router_logits : torch .Tensor ,
592598 renormalize : bool ,
599+ layer_index : int ,
593600 topk_group : Optional [int ] = None ,
594601 num_expert_group : Optional [int ] = None ,
595602 global_num_experts : int = - 1 ,
@@ -607,6 +614,7 @@ def forward_cuda(
607614 use_grouped_topk = use_grouped_topk ,
608615 top_k = top_k ,
609616 renormalize = renormalize ,
617+ layer_index = layer_index ,
610618 topk_group = topk_group ,
611619 num_expert_group = num_expert_group ,
612620 custom_routing_function = custom_routing_function ,
@@ -646,6 +654,7 @@ def forward_cpu(
646654 top_k : int ,
647655 router_logits : torch .Tensor ,
648656 renormalize : bool ,
657+ layer_index : int ,
649658 topk_group : Optional [int ] = None ,
650659 num_expert_group : Optional [int ] = None ,
651660 global_num_experts : int = - 1 ,
@@ -680,6 +689,7 @@ def forward_hpu(
680689 top_k : int ,
681690 router_logits : torch .Tensor ,
682691 renormalize : bool ,
692+ layer_index : int ,
683693 topk_group : Optional [int ] = None ,
684694 num_expert_group : Optional [int ] = None ,
685695 global_num_experts : int = - 1 ,
@@ -713,6 +723,7 @@ def forward_tpu(
713723 top_k : int ,
714724 router_logits : torch .Tensor ,
715725 renormalize : bool ,
726+ layer_index : int ,
716727 topk_group : Optional [int ] = None ,
717728 num_expert_group : Optional [int ] = None ,
718729 global_num_experts : int = - 1 ,
@@ -861,6 +872,8 @@ def __init__(
861872 compilation_config .static_forward_context [prefix ] = self
862873 self .layer_name = prefix
863874
875+ self .layer_index = extract_layer_index (prefix )
876+
864877 # Determine expert maps
865878 if self .use_ep :
866879 self .local_num_experts , self .expert_map = determine_expert_map (
@@ -1282,6 +1295,7 @@ def select_experts(hidden_states: torch.Tensor,
12821295 top_k : int ,
12831296 use_grouped_topk : bool ,
12841297 renormalize : bool ,
1298+ layer_index : int ,
12851299 topk_group : Optional [int ] = None ,
12861300 num_expert_group : Optional [int ] = None ,
12871301 custom_routing_function : Optional [Callable ] = None ,
@@ -1322,6 +1336,12 @@ def select_experts(hidden_states: torch.Tensor,
13221336 if indices_type is not None :
13231337 topk_ids = topk_ids .to (dtype = indices_type )
13241338
1339+ expert_usage_histogram = get_forward_context ().expert_usage_histogram
1340+
1341+ if expert_usage_histogram is not None :
1342+ collect_expert_usage_histogram (topk_ids ,
1343+ expert_usage_histogram [layer_index ])
1344+
13251345 return topk_weights , topk_ids
13261346
13271347 def must_reduce_shared_expert_outputs (self ) -> bool :
@@ -1354,10 +1374,12 @@ def maybe_all_reduce_tensor_model_parallel(
def forward(self, hidden_states: torch.Tensor,
            router_logits: torch.Tensor):
    """Run the MoE layer on ``hidden_states`` using ``router_logits``.

    When ``self.use_direct_call`` is truthy the layer invokes
    ``self.forward_impl`` itself; otherwise it goes through the
    ``torch.ops.vllm.moe_forward`` op, which receives ``self.layer_name``
    in addition to the tensors. In both cases ``self.layer_index`` is
    forwarded so the callee knows which layer it is running for.

    Returns:
        Whatever the selected implementation returns.
    """
    if self.use_direct_call:
        return self.forward_impl(hidden_states, router_logits,
                                 self.layer_index)

    # Indirect path: the op is keyed by layer name.
    # NOTE(review): presumably the op resolves the layer object from the
    # forward context by that name; confirm at the op's definition.
    return torch.ops.vllm.moe_forward(hidden_states, router_logits,
                                      self.layer_name, self.layer_index)
13611383
13621384 def forward_impl_chunked (self , full_hidden_states : torch .Tensor ,
13631385 full_router_logits : torch .Tensor ):
@@ -1396,6 +1418,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
13961418 router_logits = staged_router_logits ,
13971419 top_k = self .top_k ,
13981420 renormalize = self .renormalize ,
1421+ layer_index = self .layer_index ,
13991422 use_grouped_topk = self .use_grouped_topk ,
14001423 global_num_experts = self .global_num_experts ,
14011424 expert_map = self .expert_map ,
@@ -1432,7 +1455,7 @@ def process_chunk(chunk_start, chunk_end, skip_result_store=False):
14321455 return full_final_hidden_states
14331456
14341457 def forward_impl (self , hidden_states : torch .Tensor ,
1435- router_logits : torch .Tensor ):
1458+ router_logits : torch .Tensor , layer_index : int ):
14361459 assert self .quant_method is not None
14371460 if (self .moe_parallel_config .use_pplx_kernels
14381461 or self .moe_parallel_config .use_deepep_ll_kernels ):
@@ -1452,6 +1475,7 @@ def forward_impl(self, hidden_states: torch.Tensor,
14521475 router_logits = router_logits ,
14531476 top_k = self .top_k ,
14541477 renormalize = self .renormalize ,
1478+ layer_index = layer_index ,
14551479 use_grouped_topk = self .use_grouped_topk ,
14561480 global_num_experts = self .global_num_experts ,
14571481 expert_map = self .expert_map ,
@@ -1514,16 +1538,16 @@ def extra_repr(self) -> str:
15141538
15151539
def moe_forward(hidden_states: torch.Tensor, router_logits: torch.Tensor,
                layer_name: str, layer_index: int) -> torch.Tensor:
    """Run the MoE layer registered under ``layer_name``.

    The layer object is looked up in the current forward context's
    ``no_compile_layers`` mapping and its ``forward_impl`` is called with
    the given tensors and ``layer_index``.

    Args:
        hidden_states: Input activations handed to the layer.
        router_logits: Router scores handed to the layer.
        layer_name: Key of the layer in ``no_compile_layers``.
        layer_index: Index forwarded unchanged to ``forward_impl``.

    Returns:
        The result of the layer's ``forward_impl``.
    """
    forward_context: ForwardContext = get_forward_context()
    # Named ``self`` because this function stands in for a method call on
    # the layer object it retrieves.
    self = forward_context.no_compile_layers[layer_name]
    assert self.quant_method is not None

    return self.forward_impl(hidden_states, router_logits, layer_index)
15231547
15241548
def moe_forward_fake(hidden_states: torch.Tensor, router_logits: torch.Tensor,
                     layer_name: str, layer_index: int) -> torch.Tensor:
    """Return an uninitialized tensor shaped like ``hidden_states``.

    Takes the same parameters as ``moe_forward`` but performs no MoE
    computation: ``router_logits``, ``layer_name`` and ``layer_index`` are
    ignored, and the result only mirrors the metadata of
    ``hidden_states``.
    NOTE(review): presumably registered as the fake (shape-only)
    implementation of the ``moe_forward`` op; confirm at the registration
    site.
    """
    return torch.empty_like(hidden_states)
15281552
15291553
0 commit comments