Commit 5e6863d

fix for loop impl
1 parent 429362d commit 5e6863d

File tree

2 files changed: +16 -12 lines changed
torchtitan/experiments/llama4/infra/expert_parallel.py

Lines changed: 10 additions & 6 deletions
@@ -57,17 +57,19 @@ def _partition_fn(self, name, module, device_mesh):
         # w1 shape = (experts, out_dim, in_dim)
         module.register_parameter(
             "w1", nn.Parameter(distribute_tensor(module.w1, device_mesh, [Shard(1)]))
-        )
+        )  # Rowwise sharding
+
         # w2 shape = (experts, in_dim, out_dim)
         module.register_parameter(
             "w2",
             nn.Parameter(distribute_tensor(module.w2, device_mesh, [Shard(2)])),
-        )
+        )  # Columnwise sharding
+
         # w3 shape = (experts, out_dim, in_dim)
         module.register_parameter(
             "w3",
             nn.Parameter(distribute_tensor(module.w3, device_mesh, [Shard(1)])),
-        )
+        )  # Columnwise sharding
 
     def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
         return distribute_module(
@@ -230,17 +232,19 @@ def _partition_fn_2d(self, name, mod, ep_tp_mesh):
         mod.register_parameter(
             "w1",
             nn.Parameter(distribute_tensor(mod.w1, ep_tp_mesh, [Shard(0), Shard(1)])),
-        )
+        )  # Rowwise sharding
+
         # w2 shape = (experts, in_dim, out_dim)
         mod.register_parameter(
             "w2",
             nn.Parameter(distribute_tensor(mod.w2, ep_tp_mesh, [Shard(0), Shard(2)])),
-        )
+        )  # Columnwise sharding
+
         # w3 shape = (experts, out_dim, in_dim)
         mod.register_parameter(
             "w3",
             nn.Parameter(distribute_tensor(mod.w3, ep_tp_mesh, [Shard(0), Shard(1)])),
-        )
+        )  # Rowwise sharding
 
     def _token_combine(self, mod, routed_output, device_mesh):
         # token combine happens on the EP mesh, whereas device_mesh is [ep, tp] mesh
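For context on the new comments: Shard(k) is a DTensor placement that splits tensor dimension k across a device mesh, so here Shard(1) on w1/w3 and Shard(2) on w2 both split the expert hidden dimension, while Shard(0) in the 2D case splits the experts dimension across the EP mesh. A minimal sketch of this placement behavior, with illustrative shapes and a 1-D 2-rank mesh (an assumption, not torchtitan's actual mesh setup); it would need to be launched with torchrun so the process group can initialize:

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor

# Assumed launch: `torchrun --nproc-per-node 2 sketch.py`, CPU/gloo backend.
mesh = init_device_mesh("cpu", (2,))

# Stacked expert weight in (experts, out_dim, in_dim) layout; sizes are illustrative.
w1 = nn.Parameter(torch.randn(8, 256, 128))

# Shard(1) splits dim 1 across the 2 ranks: each rank holds a (8, 128, 128) local shard.
w1_dtensor = distribute_tensor(w1, mesh, [Shard(1)])
print(w1_dtensor.placements, w1_dtensor.to_local().shape)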

torchtitan/experiments/llama4/model/moe.py

Lines changed: 6 additions & 6 deletions
@@ -69,9 +69,9 @@ def _run_experts_for_loop(
         )
         out_experts_splits = []
         for expert_idx, x_expert in enumerate(x):
-            h = F.silu(torch.matmul(x_expert, w1[expert_idx]))
-            h = h * torch.matmul(x_expert, w3[expert_idx])
-            h = torch.matmul(h, w2[expert_idx])
+            h = F.silu(torch.matmul(x_expert, w1[expert_idx].transpose(-2, -1)))
+            h = h * torch.matmul(x_expert, w3[expert_idx].transpose(-2, -1))
+            h = torch.matmul(h, w2[expert_idx].transpose(-2, -1))
             # h shape (tokens_per_expert(varying), dim)
             out_experts_splits.append(h)
         out = torch.cat(out_experts_splits, dim=0)
@@ -80,10 +80,10 @@ def _run_experts_for_loop(
         out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
     else:
         # x shape (num_experts, tokens_per_expert, dim)
-        h = F.silu(torch.bmm(x, w1))
-        h = h * torch.bmm(x, w3)
+        h = F.silu(torch.bmm(x, w1.transpose(-2, -1)))
+        h = h * torch.bmm(x, w3.transpose(-2, -1))
         # out shape (num_experts, tokens_per_expert, dim)
-        out = torch.bmm(h, w2)
+        out = torch.bmm(h, w2.transpose(-2, -1))
 
     return out
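The substance of the fix: per the comments in expert_parallel.py above, w1/w3 are stored as (experts, out_dim, in_dim) and w2 as (experts, in_dim, out_dim), i.e. the nn.Linear (out_features, in_features) layout per expert, so every matmul against them needs the last two dimensions transposed. A minimal shape check of the repaired grouped path, reading in_dim as the model dim and out_dim as the hidden dim, with illustrative sizes (not torchtitan's defaults):

import torch
import torch.nn.functional as F

num_experts, tokens_per_expert, dim, hidden_dim = 4, 16, 32, 64

x = torch.randn(num_experts, tokens_per_expert, dim)
w1 = torch.randn(num_experts, hidden_dim, dim)  # (experts, out_dim, in_dim)
w2 = torch.randn(num_experts, dim, hidden_dim)  # (experts, in_dim, out_dim)
w3 = torch.randn(num_experts, hidden_dim, dim)  # (experts, out_dim, in_dim)

# Without the transpose, torch.bmm(x, w1) would fail:
# (tokens, dim) cannot multiply (hidden_dim, dim).
h = F.silu(torch.bmm(x, w1.transpose(-2, -1)))  # (experts, tokens, hidden_dim)
h = h * torch.bmm(x, w3.transpose(-2, -1))      # (experts, tokens, hidden_dim)
out = torch.bmm(h, w2.transpose(-2, -1))        # (experts, tokens, dim)

assert out.shape == (num_experts, tokens_per_expert, dim)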
