@@ -109,13 +109,16 @@ def rebalance_experts_hierarchical(
109109 num_physical_experts: number of physical experts after replication
110110 num_groups: number of expert groups
111111 num_nodes: number of server nodes, where the intra-node network
112- (e.g, NVLink) is faster
112+ (e.g. , NVLink) is faster
113113 num_gpus: number of GPUs, must be a multiple of `num_nodes`
114114
115115 Returns:
116- physical_to_logical_map: [num_moe_layers, num_physical_experts]
117- logical_to_physical_map: [num_moe_layers, num_logical_experts, X]
118- logical_count: [num_moe_layers, num_logical_experts]
116+ physical_to_logical_map (torch.Tensor):
117+ [num_moe_layers, num_physical_experts]
118+ logical_to_physical_map (torch.Tensor):
119+ [num_moe_layers, num_logical_experts, X]
120+ logical_count (torch.Tensor):
121+ [num_moe_layers, num_logical_experts]
119122 """
120123 num_layers , num_logical_experts = weight .shape
121124 assert num_logical_experts % num_groups == 0
@@ -197,11 +200,13 @@ def rebalance_experts(
197200 num_gpus: number of GPUs, must be a multiple of `num_nodes`
198201
199202 Returns:
200- physical_to_logical_map: [layers, num_replicas], the expert index of
201- each replica
202- logical_to_physical_map: [layers, num_logical_experts, X], the replica
203- indices for each expert
204- expert_count: [layers, num_logical_experts], number of physical
203+ physical_to_logical_map:
204+ [layers, num_replicas], the expert index of each replica
205+ logical_to_physical_map:
206+ [layers, num_logical_experts, X], the replica indices for each
207+ expert
208+ expert_count:
209+ [layers, num_logical_experts], number of physical
205210 replicas for each logical expert
206211 """
207212 num_layers , num_logical_experts = weight .shape
0 commit comments