Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions torchtitan/experiments/llama4/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,16 @@

llama4_configs = {
"debugmodel": TransformerModelArgs(
dim=256,
n_layers=6,
n_heads=16,
vocab_size=2000,
dim=5120,
n_layers=2,
n_heads=40,
n_kv_heads=8,
ffn_dim_multiplier=1.2,
multiple_of=2048,
rope_theta=500000,
max_seq_len=10485760,
num_experts=8,
interleave_moe_layer_step=1,
),
"17bx16e": TransformerModelArgs(
dim=5120,
Expand Down
32 changes: 20 additions & 12 deletions torchtitan/models/deepseek_v3/model/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ def __init__(
):
super().__init__()
self.num_experts = num_experts
self.w1 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
self.w1 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
self.w2 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
self.w3 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
self.use_grouped_mm = use_grouped_mm

def forward(
Expand Down Expand Up @@ -104,9 +104,9 @@ def _run_experts_for_loop(
)
out_experts_splits = []
for expert_idx, x_expert in enumerate(x):
h = F.silu(torch.matmul(x_expert, w1[expert_idx]))
h = h * torch.matmul(x_expert, w3[expert_idx])
h = torch.matmul(h, w2[expert_idx])
h = F.silu(torch.matmul(x_expert, w1[expert_idx].transpose(-2, -1)))
h = h * torch.matmul(x_expert, w3[expert_idx].transpose(-2, -1))
h = torch.matmul(h, w2[expert_idx].transpose(-2, -1))
# h shape (tokens_per_expert(varying), dim)
out_experts_splits.append(h)
out = torch.cat(out_experts_splits, dim=0)
Expand All @@ -115,10 +115,10 @@ def _run_experts_for_loop(
out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
else:
# x shape (num_experts, tokens_per_expert, dim)
h = F.silu(torch.bmm(x, w1))
h = h * torch.bmm(x, w3)
h = F.silu(torch.bmm(x, w1.transpose(-2, -1)))
h = h * torch.bmm(x, w3.transpose(-2, -1))
# out shape (num_experts, tokens_per_expert, dim)
out = torch.bmm(h, w2)
out = torch.bmm(h, w2.transpose(-2, -1))

return out

Expand All @@ -140,9 +140,17 @@ def _run_experts_grouped_mm(
# fall back to regular bmm between 3D tensors
assert x.dim() == 3

h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets))
h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets)
out = torch._grouped_mm(h, w2.bfloat16(), offs=offsets).type_as(x)
h = F.silu(
torch._grouped_mm(
x.bfloat16(), w1.bfloat16().transpose(-2, -1), offs=offsets
)
)
h = h * torch._grouped_mm(
x.bfloat16(), w3.bfloat16().transpose(-2, -1), offs=offsets
)
out = torch._grouped_mm(
h, w2.bfloat16().transpose(-2, -1), offs=offsets
).type_as(x)

return out

Expand Down
Loading