from typing import Optional, Union

import torch
-from torch.distributed import ProcessGroup, all_gather, all_reduce
+from torch.distributed import ProcessGroup, all_reduce

from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import (get_ep_group, get_node_count,
@@ -112,13 +112,21 @@ class EplbState:
    Expert load during this forward pass.
    We use the token count each expert processes as the load.

-    Shape: (num_moe_layers, num_local_physical_experts)
+    Shape: (num_moe_layers, num_physical_experts)
    """
    expert_load_window: torch.Tensor
    """
    A sliding window of expert load.

-    Shape: (window_size, num_moe_layers, num_local_physical_experts)
+    Shape: (window_size, num_moe_layers, num_physical_experts)
+
+    NOTE: `expert_load_view` now records the load of all physical experts
+    rather than only the local ones. This keeps load statistics consistent
+    across the different dispatch methods (naive all-to-all, DeepEP,
+    pplx-kernels). With naive all-to-all, the recorded load is multiplied
+    by dp_size, since every DP rank contributes the same token set to the
+    calculation. See:
+    https://github.com/vllm-project/vllm/pull/22167#pullrequestreview-3086143856
    """
    expert_load_window_step: int = 0
    """
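
To make the NOTE above concrete, here is a minimal sketch of the dp_size
inflation under naive all-to-all; the sizes and `per_rank_load` values are
illustrative assumptions, not vLLM's actual call sites:

    import torch

    dp_size, num_physical_experts = 2, 4

    # Under naive all-to-all, every DP rank sees the same routed token set,
    # so each rank records the global per-expert token counts.
    per_rank_load = torch.tensor([3, 1, 0, 2], dtype=torch.int32)

    # Summing across ranks (e.g. via all_reduce) therefore inflates the load
    # by a factor of dp_size; the relative balance between experts is
    # unchanged, which is what EPLB actually cares about.
    total_load = dp_size * per_rank_load
    assert torch.equal(total_load, torch.tensor([6, 2, 0, 4], dtype=torch.int32))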
@@ -232,14 +240,14 @@ def build(
        ).contiguous()

        expert_load_pass = torch.zeros(
-            (model.num_moe_layers, model.num_local_physical_experts),
+            (model.num_moe_layers, model.num_physical_experts),
            dtype=torch.int32,
            device=device,
        )
        expert_load_window_size = parallel_config.eplb_window_size
        expert_load_window = torch.zeros(
            (expert_load_window_size, model.num_moe_layers,
-             model.num_local_physical_experts),
+             model.num_physical_experts),
            dtype=torch.int32,
            device=device,
        )
@@ -353,18 +361,18 @@ def step(self,
        self.expert_load_pass.zero_()

        if log_stats:
-            # `num_tokens`: (num_moe_layers,)
-            num_tokens = self.expert_load_pass.sum(dim=-1)
+            # total_expert_load_pass: (num_moe_layers, num_physical_experts)
+            total_expert_load_pass = self.expert_load_pass.clone()

            # Collect load metrics from all ranks
            ep_group = get_ep_group().device_group
            assert ep_group is not None
-            num_tokens_list = [
-                torch.empty_like(num_tokens) for _ in range(ep_group.size())
-            ]
-            all_gather(num_tokens_list, num_tokens, group=ep_group)
-            # Stack to get (num_ranks, num_moe_layers)
-            num_tokens_per_rank = torch.stack(num_tokens_list).float()
+            all_reduce(total_expert_load_pass, group=ep_group)
+
+            # num_tokens_per_rank: (num_moe_layers, num_ranks)
+            num_tokens_per_rank = total_expert_load_pass.reshape(
+                total_expert_load_pass.shape[0], ep_group.size(),
+                -1).sum(dim=-1).float()

            # Compute balancedness ratio:
            # for each layer:
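
The reshape in the new code relies on physical experts being laid out
contiguously by rank: columns [r * num_local, (r + 1) * num_local) of the
all-reduced tensor belong to rank r. A standalone sketch of just that
reshape-and-sum step, with made-up sizes:

    import torch

    num_moe_layers, num_ranks, num_local = 2, 4, 2
    num_physical_experts = num_ranks * num_local

    # Stand-in for the all-reduced load:
    # (num_moe_layers, num_physical_experts)
    total_expert_load_pass = torch.arange(
        num_moe_layers * num_physical_experts,
        dtype=torch.int32).reshape(num_moe_layers, num_physical_experts)

    # Group the expert axis by owning rank, then sum within each group:
    # (num_moe_layers, num_ranks, num_local) -> (num_moe_layers, num_ranks)
    num_tokens_per_rank = total_expert_load_pass.reshape(
        num_moe_layers, num_ranks, -1).sum(dim=-1).float()
    assert num_tokens_per_rank.shape == (num_moe_layers, num_ranks)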
@@ -426,17 +434,7 @@ def rearrange(self,
            "(profile)" if is_profile else "")

        if global_expert_load is None:
-            # This mapping is only used here, so we do not store it in the state
-            physical_expert_start = ep_rank * model.num_local_physical_experts
-            physical_expert_end = (physical_expert_start +
-                                   model.num_local_physical_experts)
-            # (num_moe_layers, num_local_physical_experts)
-            local_physical_to_logical_map = self.physical_to_logical_map[
-                :,
-                physical_expert_start:physical_expert_end,
-            ]
-
-            # Map the local physical expert load to global logical experts
+            # Map the physical expert load to global logical experts
            logical_expert_load_window = torch.zeros(
                self.expert_load_window_size,
                model.num_moe_layers,
@@ -446,7 +444,7 @@ def rearrange(self,
            )
            logical_expert_load_window.scatter_add_(
                dim=-1,
-                index=local_physical_to_logical_map.unsqueeze(0).expand_as(
+                index=self.physical_to_logical_map.unsqueeze(0).expand_as(
                    self.expert_load_window).long(),
                src=self.expert_load_window,
            )
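
Because the load window now spans all physical experts, the full
physical_to_logical_map can be used directly to fold replicated physical
experts back onto their logical expert. A self-contained sketch of that
scatter-add, with a hypothetical map and load values chosen for
illustration:

    import torch

    window_size, num_layers = 1, 1
    num_physical, num_logical = 4, 3

    # Hypothetical map: physical experts 0 and 3 replicate logical expert 0.
    physical_to_logical_map = torch.tensor([[0, 1, 2, 0]])

    # (window_size, num_layers, num_physical)
    expert_load_window = torch.tensor([[[5., 2., 3., 7.]]])

    logical_load = torch.zeros(window_size, num_layers, num_logical)
    logical_load.scatter_add_(
        dim=-1,
        index=physical_to_logical_map.unsqueeze(0).expand_as(
            expert_load_window).long(),
        src=expert_load_window,
    )
    # Replicas' loads are summed per logical expert: 5 + 7 -> 12.
    assert torch.equal(logical_load, torch.tensor([[[12., 2., 3.]]]))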
@@ -618,4 +616,4 @@ def _node_count_with_rank_mapping(
            if is_same_node and node_assignment[other_rank] == 0:
                node_assignment[other_rank] = next_node_id

-    return next_node_id
+    return next_node_id