
Commit de335b3

wwwjn and Rohan Pandey authored
gpt-oss model enablement (#1754)
Keep developing on top of #1559. Thanks @KhoomeiK for the initial contribution!

All runs were initialized from the same seed checkpoint, with seed=0 and deterministic=True:

- Run 1: dp_shard = 2 [screenshot]
- Run 2: dp_shard = 2, TP degree = 2 (NGPU=4) [screenshot]
- Run 3: dp_shard = 2, TP degree = 2, EP degree = 2 (NGPU=4) [screenshot]
- Run 4: dp_shard = 2, TP degree = 2, EP degree = 2, ETP degree = 2 (NGPU=4) [screenshot]
- Run 5: dp_shard = 2, EP degree = 2 (NGPU=2) [screenshot]

Co-authored-by: Rohan Pandey <rohan@periodiclabs.ai>
1 parent: e5ef99a
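As a hedged illustration, a run like Run 4 above could be reproduced roughly as follows, assuming this experiment honors torchtitan's usual `--parallelism.*` CLI overrides (the exact flags are not shown in this commit):

```bash
# Sketch only: the Run 4 layout (dp_shard=2, TP=2, EP=2, ETP=2) on 4 GPUs,
# assuming torchtitan's standard parallelism overrides apply to this experiment.
NGPU=4 CONFIG_FILE="./torchtitan/experiments/gpt_oss/train_configs/debug_model.toml" \
./run_train.sh \
    --parallelism.data_parallel_shard_degree 2 \
    --parallelism.tensor_parallel_degree 2 \
    --parallelism.expert_parallel_degree 2 \
    --parallelism.expert_tensor_parallel_degree 2
```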

File tree

13 files changed: +1375 −7 lines


torchtitan/experiments/README.md

Lines changed: 2 additions & 1 deletion

```diff
@@ -28,4 +28,5 @@ We provide this `experiments/` folder to host experiments that add significant v
 | [vlm](./vlm/) | [![VLM 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_vlm.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_vlm.yaml?query=branch%3Amain) | [@lkhphuc](https://github.com/lkhphuc) |
 | [forge](./forge/) | TBA | [@allenwang28](https://github.com/allenwang28) [@ebsmothers](https://github.com/ebsmothers) [@joecummings](https://github.com/joecummings) [@pbontrager](https://github.com/pbontrager) |
 | [torchcomms](./torchcomms/) | TBA | [@d4l3k](https://https://github.com/d4l3k) [@fduwjj](https://github.com/fduwjj) [@mori360 ](https://github.com/mori360) |
-| [moe_symm_mem_kernels](./moe_symm_mem_kernels/) | TBA | [@kwen2501](https://github.com/pytorch/torchtitan/pulls/kwen2501) |
+| [moe_symm_mem_kernels](./moe_symm_mem_kernels/) | TBA | [@kwen2501](https://github.com/kwen2501) |
+| [gpt_oss](./gpt_oss/) | TBA | [@jianiw](https://github.com/jianiw) |
```

torchtitan/experiments/__init__.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -5,5 +5,5 @@
 # LICENSE file in the root directory of this source tree.
 
 _supported_experiments = frozenset(
-    ["flux", "simple_fsdp.llama3", "simple_fsdp.deepseek_v3", "vlm"]
+    ["flux", "gpt_oss", "simple_fsdp.llama3", "simple_fsdp.deepseek_v3", "vlm"]
 )
```
torchtitan/experiments/gpt_oss/README.md (new file)

Lines changed: 17 additions & 0 deletions

# gpt-oss Model in torchtitan

## Quick Start
```bash
CONFIG_FILE="./torchtitan/experiments/gpt_oss/train_configs/debug_model.toml" ./run_train.sh
```

## Supported Features
- FSDP/HSDP, TP, EP, ETP
- Grouped matrix multiplication for efficient computation

## TODO
1. More parallelism support: CP, PP
2. Conversion between HF weights (StateDictAdapter)
3. Forward parity verification
4. CI support
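The "grouped matrix multiplication" bullet refers to batching every expert's token slice into a single kernel launch instead of looping over experts. A minimal concept sketch in plain PyTorch (toy shapes; the experiment's actual grouped kernel is not shown in this hunk):

```python
# Toy illustration of grouped matmul for MoE experts (concept only).
import torch

num_experts, dim, hidden = 4, 8, 16
tokens = torch.randn(32, dim)
expert_ids = torch.randint(0, num_experts, (32,))
w = torch.randn(num_experts, dim, hidden)  # one up-projection per expert

# Naive routing: one separate matmul per expert over its assigned tokens.
out = torch.empty(32, hidden)
for e in range(num_experts):
    mask = expert_ids == e
    out[mask] = tokens[mask] @ w[e]

# Grouped-mm layout: sort tokens so each expert's slice is contiguous;
# a grouped kernel then consumes (sorted_tokens, w, offsets) in one launch.
order = torch.argsort(expert_ids)
sorted_tokens = tokens[order]
offsets = torch.cumsum(torch.bincount(expert_ids, minlength=num_experts), 0)
```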
torchtitan/experiments/gpt_oss/__init__.py (new file)

Lines changed: 87 additions & 0 deletions

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.models.moe import MoEArgs

from torchtitan.protocols.train_spec import TrainSpec

from .infra.parallelize import parallelize_gptoss
from .model.args import GptOssModelArgs
from .model.model import GptOssModel

__all__ = [
    "parallelize_gptoss",
    "GptOssModelArgs",
    "GptOssModel",
    "gptoss_configs",
]


gptoss_configs = {
    "debugmodel": GptOssModelArgs(
        dim=256,
        n_layers=4,
        moe_args=MoEArgs(
            num_experts=8,
            num_shared_experts=0,
            score_func="softmax",
            route_norm=False,
            route_scale=1.0,
            score_before_experts=False,
            top_k=4,
            use_grouped_mm=True,
            load_balance_coeff=1e-3,
        ),
        attn_mask_type="causal",
    ),
    "20b": GptOssModelArgs(
        n_layers=24,
        moe_args=MoEArgs(
            num_experts=32,
            num_shared_experts=0,
            score_func="softmax",
            route_norm=False,
            route_scale=1.0,
            score_before_experts=False,
            top_k=4,
            use_grouped_mm=True,
            load_balance_coeff=1e-3,
        ),
    ),
    "120b": GptOssModelArgs(
        n_layers=36,
        moe_args=MoEArgs(
            num_experts=128,
            num_shared_experts=0,
            score_func="softmax",
            route_norm=False,
            route_scale=1.0,
            score_before_experts=False,
            top_k=4,
            use_grouped_mm=True,
            load_balance_coeff=1e-3,
        ),
    ),
}


def get_train_spec() -> TrainSpec:
    return TrainSpec(
        model_cls=GptOssModel,
        model_args=gptoss_configs,
        parallelize_fn=parallelize_gptoss,
        pipelining_fn=None,
        build_optimizers_fn=build_optimizers_with_moe_load_balancing,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_hf_dataloader,
        build_tokenizer_fn=build_hf_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
    )
```
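As a usage sketch (hypothetical driver code, not part of this commit), the spec ties a named config to model construction roughly like this, assuming torchtitan's usual pattern of instantiating `model_cls` with an entry from `model_args`:

```python
# Hypothetical usage sketch; torchtitan's trainer does the real wiring.
from torchtitan.experiments.gpt_oss import get_train_spec

spec = get_train_spec()
args = spec.model_args["debugmodel"]  # pick a config by flavor name
model = spec.model_cls(args)          # assumes model_cls(model_args) construction
# spec.parallelize_fn would then apply FSDP/TP/EP/ETP to `model`.
```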
Lines changed: 65 additions & 0 deletions (new file)

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch.nn as nn
from torch.distributed.tensor import distribute_tensor, Replicate, Shard

from torchtitan.distributed.expert_parallel import ExpertTensorParallel, TensorParallel


# Implementation of Tensor Parallel for the GroupedExperts in MoE
class GptossTensorParallel(TensorParallel):
    def _partition_fn(self, name, module, device_mesh):
        module.register_parameter(
            "mlp1_weight",
            nn.Parameter(
                distribute_tensor(module.mlp1_weight, device_mesh, [Shard(1)])
            ),
        )  # Column-wise sharding
        module.register_parameter(
            "mlp1_bias",
            nn.Parameter(distribute_tensor(module.mlp1_bias, device_mesh, [Shard(1)])),
        )  # Column-wise sharding
        module.register_parameter(
            "mlp2_weight",
            nn.Parameter(
                distribute_tensor(module.mlp2_weight, device_mesh, [Shard(2)])
            ),
        )  # Row-wise sharding
        module.register_parameter(
            "mlp2_bias",
            nn.Parameter(
                distribute_tensor(module.mlp2_bias, device_mesh, [Replicate()])
            ),
        )  # Replicate


# This class is for dp2ep with TP (without TP we can just use GptossExpertParallel)
class GptossExpertTensorParallel(ExpertTensorParallel):
    def _partition_fn_2d(self, name, mod, ep_tp_mesh):
        mod.register_parameter(
            "mlp1_weight",
            nn.Parameter(
                distribute_tensor(mod.mlp1_weight, ep_tp_mesh, [Shard(0), Shard(1)])
            ),
        )  # Column-wise sharding
        mod.register_parameter(
            "mlp1_bias",
            nn.Parameter(
                distribute_tensor(mod.mlp1_bias, ep_tp_mesh, [Shard(0), Shard(1)])
            ),
        )  # Column-wise sharding
        mod.register_parameter(
            "mlp2_weight",
            nn.Parameter(
                distribute_tensor(mod.mlp2_weight, ep_tp_mesh, [Shard(0), Shard(2)])
            ),
        )  # Row-wise sharding
        mod.register_parameter(
            "mlp2_bias",
            nn.Parameter(
                distribute_tensor(mod.mlp2_bias, ep_tp_mesh, [Shard(0), Replicate()])
            ),
        )  # Replicate
```
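To make the `Shard` placements above concrete, here is a minimal, self-contained DTensor sketch of the column-/row-parallel pattern. The shapes are toy assumptions, not the real GptOss expert layout (the diff above shards mlp1 on dim 1 and mlp2 on dim 2; the expert dim gains an extra `Shard(0)` on the EP mesh axis):

```python
# Toy sketch of column-/row-parallel sharding with DTensor.
# Run with: torchrun --nproc-per-node 2 tp_sketch.py
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate, Shard

mesh = init_device_mesh("cpu", (2,))

# Assumed grouped-expert layout: (num_experts, in_features, out_features).
w_up = torch.randn(8, 16, 32)    # analogous to mlp1_weight
w_down = torch.randn(8, 32, 16)  # analogous to mlp2_weight

# Column-parallel: shard the output-feature dim of the up projection,
# so each rank holds a slice of every expert's output features.
w_up_tp = distribute_tensor(w_up, mesh, [Shard(2)])
# Row-parallel: shard the input-feature dim of the down projection;
# partial outputs are reduced across ranks afterwards.
w_down_tp = distribute_tensor(w_down, mesh, [Shard(1)])
# The down-projection bias stays replicated so it is added exactly once
# after the reduction.
b_down = distribute_tensor(torch.zeros(16), mesh, [Replicate()])

print(w_up_tp.placements, w_up_tp.to_local().shape)      # (Shard(dim=2),) (8, 16, 16)
print(w_down_tp.placements, w_down_tp.to_local().shape)  # (Shard(dim=1),) (8, 16, 16)
```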
