
Commit 5fa7db2

test updates
1 parent e47baf7 commit 5fa7db2

4 files changed: 118 additions & 59 deletions


test/prototype/moe_training/test_fsdp.py

Lines changed: 31 additions & 13 deletions
@@ -26,6 +26,7 @@
 from torch import distributed as dist
 from torch import nn
 from torch.distributed._composable.fsdp import fully_shard
+from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
 from torch.nn import functional as F
 
 # this feature requires CUDA and SM89+
@@ -53,6 +54,26 @@
 )
 
 
+@pytest.fixture(scope="module")
+def device_mesh_1d() -> DeviceMesh:
+    """
+    Fixture for setting up and tearing down the distributed environment
+    for the entire test module.
+    """
+    rank = int(os.environ["RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+    if not dist.is_initialized():
+        dist.init_process_group("nccl", rank=rank, world_size=world_size)
+
+    device_mesh = init_device_mesh("cuda", (world_size,))
+    torch.manual_seed(1)
+    torch.cuda.set_device(rank)
+
+    yield device_mesh
+
+    dist.destroy_process_group()
+
+
 @pytest.mark.parametrize(
     "target_fqns",
     [
@@ -80,7 +101,12 @@
         },
     ],
 )
-def test_moe_training_fsdp(target_fqns: list[str], compile: bool, recipe_config: dict):
+def test_moe_training_fsdp(
+    target_fqns: list[str],
+    compile: bool,
+    recipe_config: dict,
+    device_mesh_1d: DeviceMesh,
+):
     (
         recipe,
         group_alignment_size,
@@ -111,9 +137,6 @@ def test_moe_training_fsdp(target_fqns: list[str], compile: bool, recipe_config:
             f"Skipping MXFP8 benchmarks, only supported on compute capability 10.0 and found {torch.cuda.get_device_capability()}"
         )
 
-    # setup distributed for fsdp
-    setup_distributed()
-
     # set token group alignment size needed for GEMM (contraction dim stride must be 16 byte aligned)
     # or quantization ops (mxfp8 scaling groups are size 1x32)
     set_token_group_alignment_size_m(group_alignment_size)
@@ -154,6 +177,10 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         model,
         target_fqns=target_fqns,
     )
+    if compile:
+        # TODO: compile with fullgraph=True when torchtitan llama4 moe supports it
+        model = torch.compile(model, fullgraph=False)
+        ref_model = torch.compile(ref_model, fullgraph=False)
 
     # FSDP2
     fully_shard(model)
@@ -197,12 +224,3 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
     assert param_grad_sqnr.item() >= min_param_grad_sqnr, (
         f"SQNR must be >= {min_param_grad_sqnr}, got {param_grad_sqnr.item()}."
    )
-
-    dist.destroy_process_group()
-
-
-def setup_distributed():
-    rank = int(os.environ["RANK"])
-    world_size = int(os.environ["WORLD_SIZE"])
-    dist.init_process_group("nccl", rank=rank, world_size=world_size)
-    torch.cuda.set_device(rank)
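
The key change in this file replaces the per-test setup_distributed()/destroy_process_group() pair with a module-scoped pytest fixture, so the NCCL process group and device mesh are created once and reused across every parametrized case in the module. A minimal sketch of that pattern, assuming the tests are launched with torchrun (which sets the RANK and WORLD_SIZE environment variables the fixture reads); the fixture and test names below are illustrative, not from the diff:

    # Sketch only: a module-scoped yield fixture owning process-group lifecycle.
    import os

    import pytest
    import torch
    import torch.distributed as dist
    from torch.distributed.device_mesh import DeviceMesh, init_device_mesh


    @pytest.fixture(scope="module")
    def mesh_1d() -> DeviceMesh:
        rank = int(os.environ["RANK"])              # set by torchrun
        world_size = int(os.environ["WORLD_SIZE"])  # set by torchrun
        if not dist.is_initialized():
            dist.init_process_group("nccl", rank=rank, world_size=world_size)
        device_mesh = init_device_mesh("cuda", (world_size,))
        torch.cuda.set_device(rank)
        yield device_mesh                # created once, shared by the module
        dist.destroy_process_group()     # torn down after the last test


    def test_mesh_size(mesh_1d: DeviceMesh):  # injected by parameter name
        assert mesh_1d.size() == int(os.environ["WORLD_SIZE"])

With scope="module", pytest caches the yielded mesh, so the process group is no longer destroyed and recreated between parametrized cases, which is what the removed per-test teardown used to do.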

test/prototype/moe_training/test_fsdp_tp.py

Lines changed: 38 additions & 27 deletions
@@ -74,6 +74,31 @@
 )
 
 
+@pytest.fixture(scope="module")
+def device_mesh_2d() -> DeviceMesh:
+    """
+    Fixture for setting up and tearing down the distributed environment
+    for the entire test module.
+    """
+    rank = int(os.environ["RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+    if not dist.is_initialized():
+        dist.init_process_group("nccl", rank=rank, world_size=world_size)
+
+    device_mesh = init_device_mesh(
+        "cuda",
+        (world_size // 2, 2),
+        mesh_dim_names=("dp", "tp"),
+    )
+
+    torch.manual_seed(1)
+    torch.cuda.set_device(rank)
+
+    yield device_mesh
+
+    dist.destroy_process_group()
+
+
 @pytest.mark.parametrize(
     "target_fqns",
     [
@@ -102,7 +127,10 @@
     ],
 )
 def test_moe_training_fsdp_tp(
-    target_fqns: list[str], compile: bool, recipe_config: dict
+    target_fqns: list[str],
+    compile: bool,
+    recipe_config: dict,
+    device_mesh_2d: DeviceMesh,
 ):
     (
         recipe,
@@ -138,9 +166,6 @@ def test_moe_training_fsdp_tp(
     # or quantization ops (mxfp8 scaling groups are size 1x32)
     set_token_group_alignment_size_m(group_alignment_size)
 
-    # setup device mesh for fsdp + tp
-    mesh = setup_distributed()
-
     # define model args
     model_args = MoEArgs(
         num_experts=8,
@@ -177,13 +202,19 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         model,
         target_fqns=target_fqns,
     )
+    if compile:
+        # TODO: compile with fullgraph=True when torchtitan llama4 moe supports it
+        model = torch.compile(model, fullgraph=False)
+        ref_model = torch.compile(ref_model, fullgraph=False)
 
     # apply TP
-    apply_moe_ep_tp(model, tp_mesh=mesh["tp"], ep_mesh=None, ep_tp_mesh=None)
-    apply_moe_ep_tp(ref_model, tp_mesh=mesh["tp"], ep_mesh=None, ep_tp_mesh=None)
+    apply_moe_ep_tp(model, tp_mesh=device_mesh_2d["tp"], ep_mesh=None, ep_tp_mesh=None)
+    apply_moe_ep_tp(
+        ref_model, tp_mesh=device_mesh_2d["tp"], ep_mesh=None, ep_tp_mesh=None
+    )
 
     # apply FSDP2
-    fsdp_config = {"mesh": mesh["dp"]}
+    fsdp_config = {"mesh": device_mesh_2d["dp"]}
     fully_shard(model, **fsdp_config)
     fully_shard(ref_model, **fsdp_config)
 
@@ -246,26 +277,6 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         f"SQNR must be >= {min_param_grad_sqnr}, got {param_grad_sqnr.item()}."
     )
 
-    dist.destroy_process_group()
-
-
-def setup_distributed():
-    rank = int(os.environ["RANK"])
-    world_size = int(os.environ["WORLD_SIZE"])
-    dist.init_process_group("nccl", rank=rank, world_size=world_size)
-
-    # https://pytorch.org/tutorials/recipes/distributed_device_mesh.html
-    device_mesh = init_device_mesh(
-        "cuda",
-        (world_size // 2, 2),
-        mesh_dim_names=("dp", "tp"),
-    )
-
-    # seed must be the same in all processes
-    torch.manual_seed(1)
-    torch.cuda.set_device(rank)
-    return device_mesh
-
 
 def apply_moe_ep_tp(
     model: nn.Module,
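
In this file the fixture builds a named 2D mesh, and the test slices it by dimension name: indexing a DeviceMesh with a mesh_dim_name returns a 1D submesh, which is what gets handed to FSDP ("dp") and tensor parallelism ("tp"). A short sketch of the slicing, assuming the default process group is already initialized (as in the fixture) and the job was launched with an even number of ranks:

    # Sketch only: slicing a named 2D DeviceMesh into 1D submeshes,
    # mirroring device_mesh_2d["dp"] / device_mesh_2d["tp"] in the test.
    import os

    from torch.distributed.device_mesh import init_device_mesh

    world_size = int(os.environ["WORLD_SIZE"])  # e.g. 4 ranks via torchrun
    mesh_2d = init_device_mesh(
        "cuda",
        (world_size // 2, 2),                   # (dp degree, tp degree)
        mesh_dim_names=("dp", "tp"),
    )
    dp_mesh = mesh_2d["dp"]                     # 1D submesh for fully_shard
    tp_mesh = mesh_2d["tp"]                     # 1D submesh for apply_moe_ep_tp
    assert dp_mesh.ndim == 1 and tp_mesh.ndim == 1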

test/prototype/moe_training/test_tp.py

Lines changed: 32 additions & 19 deletions
@@ -72,6 +72,26 @@
 )
 
 
+@pytest.fixture(scope="module")
+def device_mesh_1d() -> DeviceMesh:
+    """
+    Fixture for setting up and tearing down the distributed environment
+    for the entire test module.
+    """
+    rank = int(os.environ["RANK"])
+    world_size = int(os.environ["WORLD_SIZE"])
+    if not dist.is_initialized():
+        dist.init_process_group("nccl", rank=rank, world_size=world_size)
+
+    device_mesh = init_device_mesh("cuda", (world_size,))
+    torch.manual_seed(1)
+    torch.cuda.set_device(rank)
+
+    yield device_mesh
+
+    dist.destroy_process_group()
+
+
 @pytest.mark.parametrize(
     "target_fqns",
     [
@@ -99,7 +119,12 @@
         },
     ],
 )
-def test_moe_training_tp(target_fqns: list[str], compile: bool, recipe_config: dict):
+def test_moe_training_tp(
+    target_fqns: list[str],
+    compile: bool,
+    recipe_config: dict,
+    device_mesh_1d: DeviceMesh,
+):
     (
         recipe,
         group_alignment_size,
@@ -134,9 +159,6 @@ def test_moe_training_tp(target_fqns: list[str], compile: bool, recipe_config: d
     # or quantization ops (mxfp8 scaling groups are size 1x32)
     set_token_group_alignment_size_m(group_alignment_size)
 
-    # setup device mesh for fsdp + tp
-    mesh = setup_distributed()
-
     # define model args
     model_args = MoEArgs(
         num_experts=8,
@@ -178,10 +200,14 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         model,
         target_fqns=target_fqns,
     )
+    if compile:
+        # TODO: compile with fullgraph=True when torchtitan llama4 moe supports it
+        model = torch.compile(model, fullgraph=False)
+        ref_model = torch.compile(ref_model, fullgraph=False)
 
     # apply TP
-    apply_moe_ep_tp(model, tp_mesh=mesh, ep_mesh=None, ep_tp_mesh=None)
-    apply_moe_ep_tp(ref_model, tp_mesh=mesh, ep_mesh=None, ep_tp_mesh=None)
+    apply_moe_ep_tp(model, tp_mesh=device_mesh_1d, ep_mesh=None, ep_tp_mesh=None)
+    apply_moe_ep_tp(ref_model, tp_mesh=device_mesh_1d, ep_mesh=None, ep_tp_mesh=None)
 
     # Rough validation that parallelization was applied properly.
     assert isinstance(model.experts.w1.data, DTensor), (
@@ -242,19 +268,6 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         f"SQNR must be >= {min_param_grad_sqnr}, got {param_grad_sqnr.item()}."
     )
 
-    dist.destroy_process_group()
-
-
-def setup_distributed():
-    rank = int(os.environ["RANK"])
-    world_size = int(os.environ["WORLD_SIZE"])
-    dist.init_process_group("nccl", rank=rank, world_size=world_size)
-    device_mesh = init_device_mesh("cuda", (world_size,))
-    # seed must be the same in all processes
-    torch.manual_seed(1)
-    torch.cuda.set_device(rank)
-    return device_mesh
-
 
 def apply_moe_ep_tp(
     model: nn.Module,
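
All three distributed tests also gain the same compile branch: both the quantized model and the bf16 reference are wrapped in torch.compile with fullgraph=False, so Dynamo can fall back to eager execution at graph breaks instead of raising; the in-diff TODO notes this is a stopgap until the torchtitan llama4 MoE supports full-graph capture. A hedged sketch of the toggle, using stand-in modules rather than the MoE under test:

    # Sketch only: model/ref_model are stand-ins for the converted MoE and
    # its bf16 reference from the test.
    import torch
    from torch import nn

    model = nn.Linear(16, 16)
    ref_model = nn.Linear(16, 16)

    compile = True
    if compile:
        # fullgraph=False tolerates graph breaks; fullgraph=True would error on them
        model = torch.compile(model, fullgraph=False)
        ref_model = torch.compile(ref_model, fullgraph=False)

    out = model(torch.randn(2, 16))  # first call triggers compilation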

test/prototype/moe_training/test_training.py

Lines changed: 17 additions & 0 deletions
@@ -73,6 +73,23 @@ def test_moe_training(target_fqns: list[str], compile: bool, recipe_config: dict
         recipe_config["min_input_grad_sqnr"],
         recipe_config["min_param_grad_sqnr"],
     )
+    assert torch.cuda.is_available()
+    if recipe == MoEScalingType.FP8_ROWWISE and torch.cuda.get_device_capability() != (
+        9,
+        0,
+    ):
+        pytest.skip(
+            f"Skipping FP8 rowwise tests, only supported on compute capability 9.0 and found {torch.cuda.get_device_capability()}"
+        )
+
+    elif recipe == MoEScalingType.MXFP8 and torch.cuda.get_device_capability() != (
+        10,
+        0,
+    ):
+        pytest.skip(
+            f"Skipping MXFP8 benchmarks, only supported on compute capability 10.0 and found {torch.cuda.get_device_capability()}"
+        )
+
     # Set token group alignment size. This is required so that
     # each logically distinct gemm in the grouped gemm `grad_weight = grad_output_t @ input`
     # has the contraction dim be divisible by 16. 16 byte alignment is required
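
This file gains the runtime hardware gate the other tests already had: FP8 rowwise recipes are skipped unless the GPU reports compute capability 9.0, and MXFP8 unless it reports 10.0. The two branches could be folded into a single helper; a sketch of that refactor follows (require_capability is hypothetical, not part of this commit):

    # Sketch only: require_capability is a hypothetical helper, not in the diff.
    import pytest
    import torch


    def require_capability(required: tuple[int, int], what: str) -> None:
        assert torch.cuda.is_available()
        cap = torch.cuda.get_device_capability()  # (major, minor)
        if cap != required:
            pytest.skip(
                f"Skipping {what}, only supported on compute capability "
                f"{required[0]}.{required[1]} and found {cap}"
            )

    # usage inside the test body:
    # if recipe == MoEScalingType.FP8_ROWWISE:
    #     require_capability((9, 0), "FP8 rowwise tests")
    # elif recipe == MoEScalingType.MXFP8:
    #     require_capability((10, 0), "MXFP8 tests")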
