
Commit 53332fe

[MoE/EP] apply dim-1 FSDP sharding for routed experts and rewrite shared experts with FFN
1 parent 21416c4 commit 53332fe

8 files changed, +161 / -176 lines

torchtitan/distributed/expert_parallel.py

Lines changed: 23 additions & 27 deletions
@@ -320,45 +320,41 @@ def wrapper(
         w2: torch.Tensor,
         w3: torch.Tensor,
         x: torch.Tensor,
-        num_tokens_per_expert: torch.Tensor | None = None,
+        num_tokens_per_expert: torch.Tensor,
     ) -> torch.Tensor:
         global TOKEN_GROUP_ALIGN_SIZE_M
         if isinstance(w1, DTensor):
             w1 = w1.to_local()
             w2 = w2.to_local()
             w3 = w3.to_local()
 
-        if num_tokens_per_expert is not None:
-            from torchtitan.experiments.kernels.moe.indices import (
-                generate_permute_indices,
+        from torchtitan.experiments.kernels.moe.indices import generate_permute_indices
+
+        experts_per_ep_rank = w1.shape[0]
+        num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank
+
+        with torch.no_grad():
+            (
+                permuted_indices,
+                num_tokens_per_expert,
+                _,  # offsets,
+            ) = generate_permute_indices(
+                num_tokens_per_expert,
+                experts_per_ep_rank,
+                num_ep_ranks,
+                x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M,
+                TOKEN_GROUP_ALIGN_SIZE_M,
             )
 
-            experts_per_ep_rank = w1.shape[0]
-            num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank
-
-            with torch.no_grad():
-                (
-                    permuted_indices,
-                    num_tokens_per_expert,
-                    _,  # offsets,
-                ) = generate_permute_indices(
-                    num_tokens_per_expert,
-                    experts_per_ep_rank,
-                    num_ep_ranks,
-                    x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M,
-                    TOKEN_GROUP_ALIGN_SIZE_M,
-                )
-
-            x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
-            input_shape = x.shape
-            x = x[permuted_indices, :]
+        x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
+        input_shape = x.shape
+        x = x[permuted_indices, :]
 
         out = func(w1, w2, w3, x, num_tokens_per_expert)
 
-        if num_tokens_per_expert is not None:
-            out_unpermuted = out.new_empty(input_shape)
-            out_unpermuted[permuted_indices, :] = out
-            out = out_unpermuted[:-1]
+        out_unpermuted = out.new_empty(input_shape)
+        out_unpermuted[permuted_indices, :] = out
+        out = out_unpermuted[:-1]
 
         return out
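For intuition, here is a pure-PyTorch sketch of what the permutation above accomplishes, assuming a single EP rank and tokens already grouped by expert (the real generate_permute_indices Triton kernel additionally regroups tokens arriving from different EP ranks): each expert's token group is padded up to a multiple of TOKEN_GROUP_ALIGN_SIZE_M, and the padding slots point at the extra zero row that torch.vstack appends to x. The numbers are hypothetical.

import torch

TOKEN_GROUP_ALIGN_SIZE_M = 8
num_tokens_per_expert = torch.tensor([3, 5])   # tokens already grouped per expert
num_tokens = int(num_tokens_per_expert.sum())  # 8 real rows in x
pad_row = num_tokens                           # index of the appended zero row

# round each expert's group size up to the alignment boundary
aligned = (
    (num_tokens_per_expert + TOKEN_GROUP_ALIGN_SIZE_M - 1)
    // TOKEN_GROUP_ALIGN_SIZE_M
    * TOKEN_GROUP_ALIGN_SIZE_M
)

permuted_indices, start = [], 0
for n_real, n_aligned in zip(num_tokens_per_expert.tolist(), aligned.tolist()):
    permuted_indices += list(range(start, start + n_real))  # real token rows
    permuted_indices += [pad_row] * (n_aligned - n_real)    # padding rows
    start += n_real

print(aligned.tolist())  # [8, 8]
print(permuted_indices)  # [0, 1, 2, 8, 8, 8, 8, 8, 3, 4, 5, 6, 7, 8, 8, 8]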

torchtitan/experiments/llama4/infra/parallelize.py

Lines changed: 47 additions & 19 deletions
@@ -136,7 +136,7 @@ def parallelize_llama(
             reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward,
             dp_mod_ep_mesh=(
                 world_mesh[tuple(dp_mod_ep_mesh_dim_names)]
-                if dp_mod_ep_mesh_dim_names
+                if parallel_dims.ep_enabled
                 else None
             ),
             gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
@@ -295,34 +295,43 @@ def apply_fsdp(
     if cpu_offload:
         fsdp_config["offload_policy"] = CPUOffloadPolicy()
 
-    for layer_id, transformer_block in model.layers.items():
-        if reshard_after_forward_policy == "always":
+    match reshard_after_forward_policy:
+        case "always":
             reshard_after_forward = True
-        elif reshard_after_forward_policy == "never":
+        case "never":
             reshard_after_forward = False
-        elif reshard_after_forward_policy == "default":
-            if pp_enabled:
-                # For PP, do not reshard after forward to avoid per-microbatch
-                # all-gathers, which can be expensive and non-overlapped
-                reshard_after_forward = False
-            else:
-                # As an optimization, do not reshard after forward for the last
-                # transformer block since FSDP would prefetch it immediately
-                reshard_after_forward = int(layer_id) < len(model.layers) - 1
-        else:
+        case "default":
+            # For PP, by default do not reshard after forward to avoid per-microbatch
+            # all-gathers, which can be expensive and non-overlapped
+            reshard_after_forward = not pp_enabled
+        case _:
             raise ValueError(
                 f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}."
             )
 
-        # NOTE: in an MoE layer, the router and the shared experts
-        # are sharded together with the TransformerBlock
+    if model.tok_embeddings is not None:
+        fully_shard(
+            model.tok_embeddings,
+            **fsdp_config,
+            reshard_after_forward=reshard_after_forward,
+        )
+
+    for layer_id, transformer_block in model.layers.items():
+        # NOTE: When EP is enabled, in an MoE layer we use the following FSDP wrapping:
+        # - the router and the shared experts are sharded together with the TransformerBlock
+        # - the routed experts are sharded with the remaining dp_mod_ep_mesh
         if transformer_block.moe_enabled and dp_mod_ep_mesh:
             fsdp_mod_ep_config = fsdp_config.copy()
             fsdp_mod_ep_config["mesh"] = dp_mod_ep_mesh
             fully_shard(
                 transformer_block.moe.experts,
                 **fsdp_mod_ep_config,
                 reshard_after_forward=reshard_after_forward,
+                # NOTE: When dp_mod_ep * ep > num_experts, FSDP's default dim-0 sharding
+                # causes inefficiency, so we choose to do FSDP sharding on dim-1.
+                # TODO: Even when EP is not used, we may still want to
+                # shard the experts on a non-0 dim.
+                shard_placement_fn=lambda param: Shard(1),
             )
             # NOTE: Although the FSDP sharding of experts is done on a mesh of
             # a different size than other parameters, the gradient division
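A back-of-the-envelope illustration of the NOTE above (numbers are hypothetical, not a real torchtitan config): with default dim-0 sharding, a routed-experts weight of shape (num_experts, dim, hidden_dim) runs out of rows to split once the dp_mod_ep mesh is larger than num_experts, while dim-1 sharding keeps every rank busy.

num_experts, dim = 8, 2048   # routed-experts weight, e.g. shape (num_experts, dim, hidden_dim)
dp_mod_ep = 16               # size of the dp_mod_ep mesh that FSDP shards over

def rows_per_rank(size: int, ranks: int) -> list[int]:
    # DTensor-style even chunking: ceil-sized chunks, later ranks may get less or nothing
    chunk = -(-size // ranks)
    return [max(0, min(chunk, size - r * chunk)) for r in range(ranks)]

print(rows_per_rank(num_experts, dp_mod_ep))  # dim-0: [1]*8 + [0]*8 -> half the ranks hold nothing
print(rows_per_rank(dim, dp_mod_ep))          # dim-1: [128]*16 -> every rank holds an equal slice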
@@ -336,7 +345,17 @@ def apply_fsdp(
             **fsdp_config,
             reshard_after_forward=reshard_after_forward,
         )
-    fully_shard(model, **fsdp_config, reshard_after_forward=not pp_enabled)
+
+    # As an optimization, do not reshard_after_forward the last layers by default
+    # since FSDP would prefetch them immediately after the forward pass
+    if model.norm is not None and model.output is not None:
+        fully_shard(
+            [model.norm, model.output],
+            **fsdp_config,
+            reshard_after_forward=reshard_after_forward_policy == "always",
+        )
+
+    fully_shard(model, **fsdp_config)
 
 
 def apply_moe_ep_tp(
@@ -362,9 +381,18 @@ def apply_moe_ep_tp(
                 ),
                 # replicate computation for the router
                 "moe.router.gate": NoParallel(),
-                # input Replicate, output Partial
-                "moe.shared_expert": TensorParallel(),
             }
+            if transformer_block.moe.shared_experts is not None:
+                # input Replicate, output Partial
+                moe_layer_plan.update(
+                    {
+                        "moe.shared_experts.w1": ColwiseParallel(),
+                        "moe.shared_experts.w2": RowwiseParallel(
+                            output_layouts=Partial()
+                        ),
+                        "moe.shared_experts.w3": ColwiseParallel(),
+                    }
+                )
             parallelize_module(
                 module=transformer_block,
                 device_mesh=tp_mesh,
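To make the plan above concrete, here is a minimal single-rank sketch (a toy module with assumed sizes, not torchtitan's shared_experts class; requires a recent PyTorch with the public torch.distributed.tensor API) of the same SwiGLU TP layout: w1/w3 column-parallel, w2 row-parallel with a Partial output left for the caller to reduce.

import os

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Partial
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class SharedExpertsFFN(nn.Module):
    # Toy SwiGLU FFN mirroring the w1/w2/w3 naming used in the plan above.
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x))


os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)
tp_mesh = init_device_mesh("cpu", (1,))

ffn = SharedExpertsFFN(dim=16, hidden_dim=32)
parallelize_module(
    ffn,
    tp_mesh,
    {
        # w1/w3 shard the hidden dim (columns), w2 shards its input dim (rows);
        # leaving w2's output as Partial defers the reduction to the caller,
        # matching output_layouts=Partial() in the plan above.
        "w1": ColwiseParallel(),
        "w2": RowwiseParallel(output_layouts=Partial()),
        "w3": ColwiseParallel(),
    },
)
out = ffn(torch.randn(4, 16))  # on a real TP mesh this is a partial sum per rank
dist.destroy_process_group()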

torchtitan/experiments/llama4/model/args.py

Lines changed: 5 additions & 5 deletions
@@ -85,28 +85,28 @@ def get_nparams_and_flops(
     ) -> tuple[int, float]:
         nparams_embedding = 0
         nparams_moe_router = 0
-        nparams_shared_expert = 0
+        nparams_shared_experts = 0
         nparams_experts = 0
         nparams_dense = 0
 
         for name, p in model.named_parameters():
             if "embedding" in name:
                 nparams_embedding += p.numel()
                 nparams_dense += p.numel()
-            elif "moe.shared_expert" in name:
-                nparams_shared_expert += p.numel()
+            elif "moe.shared_experts" in name:
+                nparams_shared_experts += p.numel()
             elif "moe.router" in name:
                 nparams_moe_router += p.numel()
             elif "moe.experts" in name:
                 nparams_experts += p.numel()
             else:
                 nparams_dense += p.numel()
 
-        nparams_sparse = nparams_moe_router + nparams_shared_expert + nparams_experts
+        nparams_sparse = nparams_moe_router + nparams_shared_experts + nparams_experts
         nparams = nparams_dense + nparams_sparse
         nparams_sparse_active = (
             nparams_moe_router
-            + nparams_shared_expert
+            + nparams_shared_experts
             + nparams_experts * self.moe_args.top_k // self.moe_args.num_experts
         )
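A quick worked example of the active-parameter accounting above (the numbers are made up, not a real llama4 configuration): only top_k / num_experts of the routed-expert weights count as active, while the router and shared experts always do.

nparams_moe_router = 1_048_576
nparams_shared_experts = 33_554_432
nparams_experts = 536_870_912
top_k, num_experts = 1, 16

nparams_sparse = nparams_moe_router + nparams_shared_experts + nparams_experts
nparams_sparse_active = (
    nparams_moe_router + nparams_shared_experts + nparams_experts * top_k // num_experts
)
print(nparams_sparse, nparams_sparse_active)  # 571473920 68157440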

torchtitan/experiments/llama4/scripts/convert_hf_to_dcp_with_gpus.py

Lines changed: 5 additions & 5 deletions
@@ -57,11 +57,11 @@ def convert_to_titan_fqns(fqn: str) -> list[str]:
     elif "feed_forward.router.weight" in fqn:
         return [f"layers.{layer}.moe.router.gate.weight"]
     elif "feed_forward.shared_expert.down_proj.weight" in fqn:
-        return [f"layers.{layer}.moe.shared_expert.w2"]
+        return [f"layers.{layer}.moe.shared_experts.w2"]
     elif "feed_forward.shared_expert.gate_proj.weight" in fqn:
-        return [f"layers.{layer}.moe.shared_expert.w3"]
+        return [f"layers.{layer}.moe.shared_experts.w3"]
     elif "feed_forward.shared_expert.up_proj.weight" in fqn:
-        return [f"layers.{layer}.moe.shared_expert.w1"]
+        return [f"layers.{layer}.moe.shared_experts.w1"]
     elif "post_attention_layernorm.weight" in fqn:
         return [f"layers.{layer}.ffn_norm.weight"]
     elif "self_attn.k_proj" in fqn:
@@ -86,7 +86,7 @@ def convert_to_hf_shape(fqn: str, titan_fqns: list[str], dtensor: DTensor) -> li
     elif "shared_expert" in fqn:
         s = dtensor.shape
         # TODO: this is not right but I have to do this to load the checkpoint.
-        return torch.Size((s[2], s[1]))
+        return torch.Size((s[1], s[0]))
     return dtensor.shape
 
 
@@ -96,7 +96,7 @@ def convert_to_titan_tensors(fqn: str, full_tensor: torch.Tensor) -> torch.Tenso
     elif "shared_expert" in fqn:
         # TODO: this is not right but I have to do this to load the checkpoint.
         full_tensor = full_tensor.transpose(1, 0)
-        full_tensors = [full_tensor.unsqueeze(0)]
+        full_tensors = [full_tensor]
     else:
         full_tensors = [full_tensor]
     return full_tensors

torchtitan/models/deepseek_v3/model/args.py

Lines changed: 5 additions & 5 deletions
@@ -126,28 +126,28 @@ def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, in
         """
         nparams_embedding = 0
         nparams_moe_router = 0
-        nparams_shared_expert = 0
+        nparams_shared_experts = 0
         nparams_experts = 0
         nparams_dense = 0
 
         for name, p in model.named_parameters():
             if "embedding" in name:
                 nparams_embedding += p.numel()
                 nparams_dense += p.numel()
-            elif "moe.shared_expert" in name:
-                nparams_shared_expert += p.numel()
+            elif "moe.shared_experts" in name:
+                nparams_shared_experts += p.numel()
             elif "moe.router" in name:
                 nparams_moe_router += p.numel()
             elif "moe.experts" in name:
                 nparams_experts += p.numel()
             else:
                 nparams_dense += p.numel()
 
-        nparams_sparse = nparams_moe_router + nparams_shared_expert + nparams_experts
+        nparams_sparse = nparams_moe_router + nparams_shared_experts + nparams_experts
         nparams = nparams_dense + nparams_sparse
         nparams_sparse_active = (
             nparams_moe_router
-            + nparams_shared_expert
+            + nparams_shared_experts
             + nparams_experts * self.moe_args.top_k // self.moe_args.num_experts
         )

torchtitan/models/deepseek_v3/model/model.py

Lines changed: 1 addition & 38 deletions
@@ -8,52 +8,15 @@
 from typing import Tuple
 
 import torch
-import torch.nn.functional as F
 from torch import nn
 
 from torchtitan.models.attention import build_attention, init_attention_mask
-from torchtitan.models.moe import MoE
+from torchtitan.models.moe import FeedForward, MoE
 from torchtitan.protocols.train_spec import ModelProtocol
 
 from .args import DeepSeekV3ModelArgs
 
 
-class FeedForward(nn.Module):
-    """
-    FeedForward module
-
-    Args:
-        dim (int): Input dimension.
-        hidden_dim (int): Hidden dimension of the feedforward layer.
-        multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
-        ffn_dim_multiplier (float | None): Custom multiplier for hidden dimension. Defaults to None.
-
-    Attributes:
-        w1 (Linear): Linear transformation for the first layer.
-        w2 (Linear): Linear transformation for the second layer.
-        w3 (Linear): Linear transformation for the third layer.
-
-    """
-
-    def __init__(
-        self,
-        dim: int,
-        hidden_dim: int,
-    ):
-        super().__init__()
-        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
-        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
-        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.w2(F.silu(self.w1(x)) * self.w3(x))
-
-    def init_weights(self, init_std: float = 0.02):
-        nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02)
-        for linear in (self.w2, self.w3):
-            nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std)
-
-
 # Adapted from https://github.com/DeepSeek-ai/DeepSeek-V3/blob/main/inference/model.py#L294
 def precompute_freqs_cis(args: DeepSeekV3ModelArgs) -> torch.Tensor:
     """

torchtitan/models/deepseek_v3/model/state_dict_adapter.py

Lines changed: 3 additions & 13 deletions
@@ -44,9 +44,9 @@ def __init__(self, model_args: DeepSeekV3ModelArgs, hf_assets_path: str | None):
             "model.layers.{}.mlp.experts.{}.up_proj.weight": "layers.{}.moe.experts.w3",
             "model.layers.{}.mlp.experts.{}.down_proj.weight": "layers.{}.moe.experts.w2",
             "model.layers.{}.mlp.gate.weight": "layers.{}.moe.router.gate.weight",
-            "model.layers.{}.mlp.shared_experts.gate_proj.weight": "layers.{}.moe.shared_expert.w1",
-            "model.layers.{}.mlp.shared_experts.up_proj.weight": "layers.{}.moe.shared_expert.w3",
-            "model.layers.{}.mlp.shared_experts.down_proj.weight": "layers.{}.moe.shared_expert.w2",
+            "model.layers.{}.mlp.shared_experts.gate_proj.weight": "layers.{}.moe.shared_experts.w1",
+            "model.layers.{}.mlp.shared_experts.up_proj.weight": "layers.{}.moe.shared_experts.w3",
+            "model.layers.{}.mlp.shared_experts.down_proj.weight": "layers.{}.moe.shared_experts.w2",
             "model.layers.{}.mlp.gate.e_score_correction_bias": "layers.{}.moe.expert_bias",
             "model.norm.weight": "norm.weight",
             "lm_head.weight": "output.weight",
@@ -163,11 +163,6 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]:
                 layer_num = re.search(r"\d+", key).group(0)
                 new_key = to_hf_map[abstract_key]
                 new_key = new_key.format(layer_num)
-
-                # torchtitan shape: (1, s[1], s[2]) -> HF shape: (s[1], s[2])
-                if "shared_expert" in key:
-                    value = value.squeeze(0)
-
                 hf_state_dict[new_key] = value
 
             else:
@@ -217,11 +212,6 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
                 layer_num = re.search(r"\d+", key).group(0)
                 new_key = self.from_hf_map[abstract_key]
                 new_key = new_key.format(layer_num)
-
-                # HF shape: (s[1], s[2]) -> torchtitan shape: (1, s[1], s[2])
-                if "shared_experts" in key:
-                    value = value.unsqueeze(0)
-
                 state_dict[new_key] = value
 
             else:
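Illustrative only, using a hypothetical key: with shared experts now stored as plain 2D FeedForward weights, an HF shared-experts key maps one-to-one onto the torchtitan key, so the squeeze/unsqueeze of the old 3D layout is no longer needed.

import re

hf_key = "model.layers.3.mlp.shared_experts.gate_proj.weight"
abstract_key = re.sub(r"(\d+)", "{}", hf_key, count=1)  # model.layers.{}.mlp.shared_experts.gate_proj.weight
layer_num = re.search(r"\d+", hf_key).group(0)          # "3"
titan_key = "layers.{}.moe.shared_experts.w1".format(layer_num)
print(titan_key)  # layers.3.moe.shared_experts.w1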
