Merge branch 'node_limited_routing' into 'main'
Support Node-Limited Routing for DeepSeek-V3

See merge request ADLR/megatron-lm!2521
ko3n1g committed Feb 9, 2025
2 parents 5d7575d + b1022a3 commit cd4a391
Showing 8 changed files with 175 additions and 77 deletions.
3 changes: 2 additions & 1 deletion examples/gpt3/gpt_config.yaml
@@ -63,7 +63,8 @@ language_model:
# MoE related
moe_router_load_balancing_type: "aux_loss"
moe_router_topk: 2
moe_router_topk_limited_devices: null
moe_router_group_topk: null
moe_router_num_groups: null
moe_grouped_gemm: False
moe_aux_loss_coeff: 0 # 1e-2 would be a good start value for load balance loss.
moe_z_loss_coeff: null # 1e-3 would be a good start value for z-loss
3 changes: 3 additions & 0 deletions megatron/core/distributed/finalize_model_grads.py
@@ -222,6 +222,9 @@ def _update_router_expert_bias(model: List[torch.nn.Module], config: Transformer
if hasattr(module, 'expert_bias'):
tokens_per_expert_list.append(module.local_tokens_per_expert)
expert_bias_list.append(module.expert_bias)
# For hybrid models with both MoE and Dense layers, this list can be empty.
if len(expert_bias_list) == 0:
return
stacked_tokens_per_expert = torch.stack(tokens_per_expert_list, dim=0)
stacked_expert_bias = torch.stack(expert_bias_list, dim=0)
stacked_updated_expert_bias = get_updated_expert_bias(
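For context, the stacked counts above feed get_updated_expert_bias. A minimal sketch of the balancing rule it is described as implementing (see the --moe-router-bias-update-rate entry in the MoE README below: the bias rises for under-loaded experts and falls for over-loaded ones by a fixed step). This is an illustration only, not the actual Megatron function:

import torch

def updated_expert_bias_sketch(tokens_per_expert, expert_bias, update_rate=1e-3):
    """Hedged sketch of the aux-loss-free balancing rule: raise the bias of
    under-loaded experts and lower it for over-loaded ones by a fixed step."""
    load = tokens_per_expert.float()
    mean_load = load.mean(dim=-1, keepdim=True)
    # +update_rate for experts below the mean load, -update_rate for experts above it.
    return expert_bias + update_rate * torch.sign(mean_load - load)

# Example: expert 0 is overloaded, expert 3 is starved -> bias shifts accordingly.
print(updated_expert_bias_sketch(torch.tensor([900, 300, 300, 100]), torch.zeros(4)))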
3 changes: 2 additions & 1 deletion megatron/core/transformer/moe/README.md
@@ -62,7 +62,8 @@ Megatron-Core offers rich parallelism mappings, combining Expert Parallelism wit
| --moe-router-topk | Number of experts to route to for each token. The default is 2. |
| --moe-router-score-function | Score function for MoE routing. Can be "softmax" or "sigmoid". Default is "softmax". |
| --moe-router-pre-softmax | Enable pre-softmax routing for MoE, which means softmax is before the top-k selection. By default, softmax is done after top-k. |
| --moe-router-topk-limited-devices | Number of expert parallel ranks to consider for each token during routing. Perform top-k routing on a subset of expert parallel ranks by first selecting N ranks for each token, then conducting top-k selection among experts on these devices. None means no device limitation. Default is None, which means no limited devices. |
| --moe-router-num-groups | Number of groups to divide experts into for group-limited routing. When using group-limited routing: 1) Experts are divided into equal-sized groups, 2) For each token, a subset of groups is selected based on routing scores (sum of top-2 expert scores within each group), 3) From these selected groups, moe_router_topk experts are chosen. Two common use cases: 1) Device-limited routing: Set equal to expert parallel size (EP) to limit each token to experts on a subset of devices (See DeepSeek-V2: https://arxiv.org/pdf/2405.04434); 2) Node-limited routing: Set equal to number of nodes in EP group to limit each token to experts on a subset of nodes (See DeepSeek-V3: https://arxiv.org/pdf/2412.19437). |
| --moe-router-group-topk | Number of selected groups for group-limited routing. |
| --moe-router-topk-scaling-factor | Scaling factor for routing score in top-k selection, only works when --moe-router-pre-softmax enabled. Defaults to None, which means no scaling. |
| --moe-router-enable-expert-bias | TopK routing with dynamic per-expert bias in the aux-loss-free load balancing strategy. The routing decision is based on the sum of the routing scores and the expert bias. See https://arxiv.org/abs/2408.15664 for details. |
| --moe-router-bias-update-rate | The expert bias is updated based on the number of assigned tokens to each expert in a global batch, where the bias is increased for experts with less assigned tokens and decreased for experts with more assigned tokens. Default is 1e-3 same as that used in DeepSeekV3. |
81 changes: 50 additions & 31 deletions megatron/core/transformer/moe/moe_utils.py
@@ -327,47 +327,56 @@ def sort_chunks_by_idxs(input: torch.Tensor, split_sizes: torch.Tensor, sorted_i
return output


def device_limited_topk(
def group_limited_topk(
scores: torch.Tensor,
topk: int,
num_tokens: int,
num_experts: int,
moe_router_topk_limited_devices: int,
num_groups: int,
group_topk: int,
):
"""Perform top-k routing on a subset of expert parallel ranks.
"""Perform top-k routing on a subset of expert groups.
Selects N ranks for each token, then conducts top-k selection among experts on these devices.
See DeepSeek-V2 technical report (https://arxiv.org/pdf/2405.04434) for details.
When using group-limited routing:
1. Experts are divided into 'moe_router_num_groups' equal-sized groups
2. For each token, 'moe_router_group_topk' groups are selected based on routing scores
(specifically, the sum of top-2 expert scores within each group)
3. From these selected groups, 'moe_router_topk' individual experts are chosen
Two common use cases:
- Device-limited routing: Set 'moe_router_num_groups' equal to expert parallel size (EP)
to limit each token to experts on a subset of devices
(See DeepSeek-V2: https://arxiv.org/pdf/2405.04434)
- Node-limited routing: Set 'moe_router_num_groups' equal to number of nodes in EP group
to limit each token to experts on a subset of nodes
(See DeepSeek-V3: https://arxiv.org/pdf/2412.19437)
Args:
scores (torch.Tensor): Softmax scores from the router.
scores (torch.Tensor): Softmax scores generated by the router.
topk (int): The number of experts to select for each token.
num_tokens (int): The number of tokens.
num_experts (int): The number of experts.
moe_router_topk_limited_devices (int): Number of expert parallel ranks to consider for
each token during routing. None means no device limitation.
num_groups (int): Number of groups for routed experts.
group_topk (int): Number of groups selected for each token.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Probs and indices tensor.
"""

# Organize the experts into groups
num_group = (
parallel_state.get_expert_model_parallel_world_size()
) # num_group equals to expert parallel size
group_scores = scores.view(num_tokens, num_group, -1).max(dim=-1).values
group_idx = torch.topk(group_scores, k=moe_router_topk_limited_devices, dim=-1, sorted=False)[1]
group_scores = scores.view(num_tokens, num_groups, -1).topk(2, dim=-1)[0].sum(dim=-1)
group_idx = torch.topk(group_scores, k=group_topk, dim=-1, sorted=False)[1]
group_mask = torch.zeros_like(group_scores)
group_mask.scatter_(1, group_idx, 1)

# Mask the experts based on selection groups
score_mask = (
group_mask.unsqueeze(-1)
.expand(num_tokens, num_group, num_experts // num_group)
.expand(num_tokens, num_groups, num_experts // num_groups)
.reshape(num_tokens, -1)
)

masked_scores = scores.masked_fill(~score_mask.bool(), 0.0)
masked_scores = scores.masked_fill(~score_mask.bool(), float('-inf'))
probs, top_indices = torch.topk(masked_scores, k=topk, dim=-1)

return probs, top_indices
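A toy usage sketch of the new helper (tensor values are made up; the import path matches the file changed in this commit and assumes megatron-core is installed):

import torch

from megatron.core.transformer.moe.moe_utils import group_limited_topk

# Toy setup: 2 tokens, 8 experts split into 4 groups of 2; pick the best 2 groups,
# then choose the top-2 experts among the 4 experts those groups expose.
num_tokens, num_experts = 2, 8
scores = torch.softmax(torch.randn(num_tokens, num_experts), dim=-1)

probs, top_indices = group_limited_topk(
    scores=scores,
    topk=2,
    num_tokens=num_tokens,
    num_experts=num_experts,
    num_groups=4,
    group_topk=2,
)
# top_indices only ever contains experts from each token's 2 selected groups.
print(probs.shape, top_indices.shape)  # torch.Size([2, 2]) torch.Size([2, 2])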
@@ -380,8 +389,9 @@ def topk_softmax_with_capacity(
pad_to_capacity: bool = False,
drop_policy: str = "probs",
use_pre_softmax: bool = False,
moe_router_topk_limited_devices: int = None,
moe_router_topk_scaling_factor: Optional[float] = None,
num_groups: Optional[int] = None,
group_topk: Optional[int] = None,
scaling_factor: Optional[float] = None,
deterministic_mode: bool = False,
score_function: str = "softmax",
expert_bias: Optional[torch.Tensor] = None,
@@ -397,11 +407,13 @@ def topk_softmax_with_capacity(
If "prob", the tokens with the lowest probabilities will be dropped.
If "position", tokens at the end of each batch will be dropped.
use_pre_softmax (bool): Whether to apply softmax before top-k selection.
moe_router_topk_limited_devices (int): Number of expert parallel ranks to consider for
each token during routing. None means no device limitation.
moe_router_topk_scaling_factor (float): Scaling factor for routing score in top-k
selection, only works when use_pre_softmax enabled.
num_groups (int): Number of groups for routed experts.
group_topk (int): Number of selected groups for each token.
scaling_factor (float): Scaling factor of routing score in top-k selection.
deterministic_mode (bool): Deprecated.
score_function (str): The score function to use. Can be either "softmax" or "sigmoid".
expert_bias (torch.Tensor): The bias added to logits for expert routing.
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- routing_probs (torch.Tensor): A tensor of shape [num_tokens, num_experts] containing
@@ -415,33 +427,40 @@ def topk_softmax_with_capacity(
assert logits.dim() == 2, f"Expected 2D logits [num_tokens, num_experts], got {logits.dim()}."
num_tokens, num_experts = logits.shape

def compute_topk(scores, topk, limited_devices=None):
if limited_devices:
return device_limited_topk(scores, topk, num_tokens, num_experts, limited_devices)
def compute_topk(scores, topk, num_groups=None, group_topk=None):
if group_topk:
return group_limited_topk(
scores=scores,
topk=topk,
num_tokens=num_tokens,
num_experts=num_experts,
num_groups=num_groups,
group_topk=group_topk,
)
else:
return torch.topk(scores, k=topk, dim=1)

if score_function == "softmax":
if use_pre_softmax:
scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
probs, top_indices = compute_topk(scores, topk, moe_router_topk_limited_devices)
probs, top_indices = compute_topk(scores, topk, num_groups, group_topk)
else:
scores, top_indices = compute_topk(logits, topk, moe_router_topk_limited_devices)
scores, top_indices = compute_topk(logits, topk, num_groups, group_topk)
probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
elif score_function == "sigmoid":
scores = torch.sigmoid(logits)
if expert_bias is not None:
scores_for_routing = scores + expert_bias
_, top_indices = compute_topk(scores_for_routing, topk, moe_router_topk_limited_devices)
_, top_indices = compute_topk(scores_for_routing, topk, num_groups, group_topk)
scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits)
else:
scores, top_indices = compute_topk(scores, topk, moe_router_topk_limited_devices)
scores, top_indices = compute_topk(scores, topk, num_groups, group_topk)
probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores
else:
raise ValueError(f"Invalid score_function: {score_function}")

if moe_router_topk_scaling_factor:
probs = probs * moe_router_topk_scaling_factor
if scaling_factor:
probs = probs * scaling_factor

# TODO Try using element-wise operations instead of scatter?
topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs)
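One detail of the sigmoid branch above is easy to miss: expert_bias influences only which experts are selected, while the returned probabilities are built from the unbiased sigmoid scores. A condensed standalone sketch of that behaviour, with group-limited routing and capacity handling omitted (an illustration, not the full topk_softmax_with_capacity):

import torch

def sigmoid_routing_with_bias(logits, expert_bias, topk):
    """Simplified sketch: the bias steers expert selection, but the probabilities
    handed to the MoE layer come from the unbiased sigmoid scores."""
    scores = torch.sigmoid(logits)
    # Expert selection uses the biased scores (aux-loss-free load balancing).
    _, top_indices = torch.topk(scores + expert_bias, k=topk, dim=-1)
    # Probabilities are gathered from the unbiased scores, then normalized.
    selected = torch.gather(scores, dim=1, index=top_indices)
    if topk > 1:
        selected = selected / (selected.sum(dim=-1, keepdim=True) + 1e-20)
    return selected, top_indices

probs, idx = sigmoid_routing_with_bias(torch.randn(4, 8), torch.zeros(8), topk=2)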
14 changes: 9 additions & 5 deletions megatron/core/transformer/moe/router.py
@@ -169,8 +169,9 @@ def aux_loss_load_balancing(self, logits: torch.Tensor):
pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
drop_policy=self.config.moe_token_drop_policy,
use_pre_softmax=self.config.moe_router_pre_softmax,
moe_router_topk_limited_devices=self.config.moe_router_topk_limited_devices,
moe_router_topk_scaling_factor=self.config.moe_router_topk_scaling_factor,
num_groups=self.config.moe_router_num_groups,
group_topk=self.config.moe_router_group_topk,
scaling_factor=self.config.moe_router_topk_scaling_factor,
deterministic_mode=self.config.deterministic_mode,
score_function=self.score_function,
expert_bias=self.expert_bias,
@@ -200,8 +201,9 @@ def seq_aux_loss_load_balancing(self, logits: torch.Tensor, bsz: int, seq_length
pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
drop_policy=self.config.moe_token_drop_policy,
use_pre_softmax=self.config.moe_router_pre_softmax,
moe_router_topk_limited_devices=self.config.moe_router_topk_limited_devices,
moe_router_topk_scaling_factor=self.config.moe_router_topk_scaling_factor,
num_groups=self.config.moe_router_num_groups,
group_topk=self.config.moe_router_group_topk,
scaling_factor=self.config.moe_router_topk_scaling_factor,
deterministic_mode=self.config.deterministic_mode,
score_function=self.score_function,
expert_bias=self.expert_bias,
@@ -329,7 +331,9 @@ def routing(self, logits: torch.Tensor):
pad_to_capacity=self.config.moe_pad_expert_input_to_capacity,
drop_policy=self.config.moe_token_drop_policy,
use_pre_softmax=self.config.moe_router_pre_softmax,
moe_router_topk_scaling_factor=self.config.moe_router_topk_scaling_factor,
num_groups=self.config.moe_router_num_groups,
group_topk=self.config.moe_router_group_topk,
scaling_factor=self.config.moe_router_topk_scaling_factor,
deterministic_mode=self.config.deterministic_mode,
score_function=self.score_function,
expert_bias=self.expert_bias,
57 changes: 48 additions & 9 deletions megatron/core/transformer/transformer_config.py
@@ -1,5 +1,6 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

import warnings
from dataclasses import dataclass
from typing import Callable, List, Optional, Tuple, Union

@@ -286,9 +287,28 @@ class TransformerConfig(ModelParallelConfig):
"""Number of experts to route to for each token."""

moe_router_topk_limited_devices: int = None
"""Number of expert parallel ranks to consider for each token during routing. Perform top-k
routing on a subset of expert parallel ranks by first selecting N ranks for each token, then
conducting top-k selection among experts on these devices. None means no device limitation."""
"""Number of EP ranks to consider for each token in group-limited routing,
DEPRECATED and replaced by moe_router_num_groups and moe_router_group_topk.
"""

moe_router_num_groups: int = None
"""Number of groups to divide experts into for group-limited routing.
When using group-limited routing:
1. Experts are divided into 'moe_router_num_groups' equal-sized groups
2. For each token, 'moe_router_group_topk' groups are selected based on routing scores
(specifically, the sum of top-2 expert scores within each group)
3. From these selected groups, 'moe_router_topk' individual experts are chosen
Two common use cases:
- Device-limited routing: Set 'moe_router_num_groups' equal to expert parallel size (EP)
to limit each token to experts on a subset of devices
(See DeepSeek-V2: https://arxiv.org/pdf/2405.04434)
- Node-limited routing: Set 'moe_router_num_groups' equal to number of nodes in EP group
to limit each token to experts on a subset of nodes
(See DeepSeek-V3: https://arxiv.org/pdf/2412.19437)
"""

moe_router_group_topk: int = None
"""Number of selected groups for group-limited routing."""

moe_router_pre_softmax: bool = False
"""Enable pre-softmax routing for MoE, which means softmax is before the top-k selection.
@@ -598,7 +618,7 @@ def __post_init__(self):

if self.num_layers_in_last_pipeline_stage is not None:
if self.num_layers_in_last_pipeline_stage <= 0:
raise ValueError('num_layers_in_first_pipeline_stage must be larger than 0')
raise ValueError('num_layers_in_last_pipeline_stage must be larger than 0')

if self.virtual_pipeline_model_parallel_size is not None:
if (
@@ -765,13 +785,32 @@ def __post_init__(self):
# since softmax on a [num_tokens, 1] would yield a zero gradient.
raise ValueError("Please use --moe-router-pre-softmax when topk is 1.")

if self.moe_router_topk_limited_devices:
if self.moe_router_topk_limited_devices > self.expert_model_parallel_size:
if self.moe_router_group_topk:
if self.moe_router_topk_limited_devices:
raise ValueError(
f"moe_router_topk_limited_devices: {self.moe_router_topk_limited_devices} "
f"must be smaller than expert_model_parallel_size "
f"{self.expert_model_parallel_size}"
"moe_router_topk_limited_devices is deprecated and replaced by "
"moe_router_group_topk and moe_router_num_groups."
)
if not self.moe_router_num_groups:
raise ValueError(
"When using group limited routing, moe_router_num_groups must be specified."
)
else:
assert self.num_moe_experts % self.moe_router_num_groups == 0, (
f"num_moe_experts ({self.num_moe_experts}) should be divisible by "
f"moe_router_num_groups ({self.moe_router_num_groups})."
)
assert self.moe_router_group_topk <= self.moe_router_num_groups, (
f"moe_router_group_topk ({self.moe_router_group_topk}) should be smaller than "
f"moe_router_num_groups ({self.moe_router_num_groups})."
)
elif self.moe_router_topk_limited_devices:
warnings.warn(
"moe_router_topk_limited_devices is deprecated. Use moe_router_group_topk and "
"moe_router_num_groups instead."
)
self.moe_router_group_topk = self.moe_router_topk_limited_devices
self.moe_router_num_groups = self.expert_model_parallel_size

if self.flash_decode and self.fp8:
raise ValueError("FP8 inference is currently not support with flash decoding.")
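The precedence between the new options and the deprecated flag can be summarized as a standalone sketch (an illustration of the rules added to __post_init__ above, not the actual TransformerConfig code):

import warnings

def resolve_group_routing(group_topk, num_groups, topk_limited_devices,
                          expert_model_parallel_size, num_moe_experts):
    """Sketch of the precedence and validation rules added in this commit."""
    if group_topk:
        # New-style options win; mixing them with the deprecated flag is an error.
        if topk_limited_devices:
            raise ValueError("moe_router_topk_limited_devices is deprecated and replaced "
                             "by moe_router_group_topk and moe_router_num_groups.")
        if not num_groups:
            raise ValueError("moe_router_num_groups must be set for group-limited routing.")
        assert num_moe_experts % num_groups == 0
        assert group_topk <= num_groups
    elif topk_limited_devices:
        # Legacy flag maps onto device-limited routing: one group per EP rank.
        warnings.warn("moe_router_topk_limited_devices is deprecated. Use "
                      "moe_router_group_topk and moe_router_num_groups instead.")
        group_topk = topk_limited_devices
        num_groups = expert_model_parallel_size
    return group_topk, num_groups

# A legacy config with moe_router_topk_limited_devices=2 and EP=8 resolves to
# group_topk=2, num_groups=8 (device-limited routing as in DeepSeek-V2).
print(resolve_group_routing(None, None, 2, 8, 64))  # (2, 8)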
7 changes: 5 additions & 2 deletions megatron/training/arguments.py
@@ -2197,8 +2197,11 @@ def _add_moe_args(parser):
help='Number of experts to route to for each token. The default is 2.')
group.add_argument('--moe-router-pre-softmax', action='store_true',
help='Enable pre-softmax routing for MoE, which means softmax is before the top-k selection. By default, softmax is done after top-k.')
group.add_argument('--moe-router-topk-limited-devices', type=int, default=None,
help='Number of expert parallel ranks to consider for each token during routing. Perform top-k routing on a subset of expert parallel ranks by first selecting N ranks for each token, then conducting top-k selection among experts on these devices. Default is None, which means no limited devices.')
group.add_argument('--moe-router-num-groups', type=int, default=None,
help='Number of groups to divide experts into for group-limited routing. When using group-limited routing: 1) Experts are divided into equal-sized groups, 2) For each token, a subset of groups are selected based on routing scores (sum of top-2 expert scores within each group), 3) From these selected groups, moe_router_topk experts are chosen.'
' Two common use cases: 1) Device-limited routing: Set equal to expert parallel size (EP) to limit each token to experts on a subset of devices (See DeepSeek-V2: https://arxiv.org/pdf/2405.04434); 2) Node-limited routing: Set equal to number of nodes in EP group to limit each token to experts on a subset of nodes (See DeepSeek-V3: https://arxiv.org/pdf/2412.19437).')
group.add_argument('--moe-router-group-topk', type=int, default=None,
help='Number of selected groups for group-limited routing.')
group.add_argument('--moe-router-topk-scaling-factor', type=float, default=None,
help='Scaling factor for routing score in top-k selection, only works when --moe-router-pre-softmax enabled. Defaults to None, which means no scaling.')
group.add_argument('--moe-router-enable-expert-bias', action='store_true',
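A minimal, hedged reproduction of how the new flags parse (defaults follow the diff; the routing values below are hypothetical, e.g. an expert-parallel group spanning 8 nodes with each token limited to 4 of them):

import argparse

# Minimal reproduction of the routing flags registered above; help strings omitted.
parser = argparse.ArgumentParser()
parser.add_argument('--moe-router-topk', type=int, default=2)
parser.add_argument('--moe-router-num-groups', type=int, default=None)
parser.add_argument('--moe-router-group-topk', type=int, default=None)

# Illustrative node-limited setup (values are hypothetical, not from any reference config).
args = parser.parse_args(
    '--moe-router-topk 8 --moe-router-num-groups 8 --moe-router-group-topk 4'.split()
)
print(args.moe_router_num_groups, args.moe_router_group_topk)  # 8 4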