Commit 45fb201

[moe training] add test case for shared expert in distributed tests
stack-info: PR: #2856, branch: danielvegamyhre/stack/56
1 parent 253d65a commit 45fb201

7 files changed: +19 −15 lines

test/prototype/moe_training/test_everything.sh

Lines changed: 3 additions & 3 deletions
@@ -12,9 +12,9 @@ IS_ROCM=$(rocm-smi --version || true)
 # These tests do not work on ROCm yet
 if [ -z "$IS_ROCM" ]
 then
-./test/prototype/moe_training/test_fsdp.sh
-./test/prototype/moe_training/test_tp.sh
-./test/prototype/moe_training/test_fsdp_tp.sh
+./test_fsdp.sh
+./test_tp.sh
+./test_fsdp_tp.sh
 fi

 echo "all tests successful"
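
Note: the runner now invokes the per-test scripts via paths relative to the test directory, which presumably means test_everything.sh is intended to be run from test/prototype/moe_training/ rather than the repository root.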

test/prototype/moe_training/test_fsdp.py

Lines changed: 9 additions & 3 deletions
@@ -7,7 +7,7 @@
 #
 # To run these unit tests, use the following command:
 #
-# torchrun --nproc_per_node=${NUM_GPUS} -m pytest test_fsdp.py
+# torchrun --nproc_per_node=2 --local-ranks-filter=0 -m pytest test_fsdp.py
 #
 #######################################################################

@@ -45,7 +45,14 @@
 )


-def test_moe_float8_training_fsdp():
+@pytest.mark.parametrize(
+    "target_fqns",
+    [
+        ["experts"],
+        ["experts,shared_expert"],
+    ],
+)
+def test_moe_float8_training_fsdp(target_fqns: list[str]):
     assert torch.cuda.is_available()

     # setup distributed for fsdp
@@ -55,7 +62,6 @@ def test_moe_float8_training_fsdp():
     set_token_group_alignment_size_m(16)

     # define model args
-    target_fqns = ["experts"]
     model_args = MoEArgs(
         num_experts=8,
     )
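
The parametrization above runs the FSDP test once per `target_fqns` value; note that "experts,shared_expert" is a single comma-separated string, not two list entries. A minimal sketch of how a comma-separated FQN filter could work (`matches_target_fqns` is a hypothetical helper for illustration, not torchao's actual conversion API):

```python
# Hypothetical helper: decide whether a module's fully-qualified name (FQN)
# matches any of the requested target FQNs. Illustrative only -- not the
# actual filtering logic used by torchao's MoE training conversion.
def matches_target_fqns(fqn: str, target_fqns: list[str]) -> bool:
    # Entries like "experts,shared_expert" are comma-separated, so split
    # each entry into individual fragments before substring matching.
    fragments = [f for entry in target_fqns for f in entry.split(",")]
    return any(fragment in fqn for fragment in fragments)

assert matches_target_fqns("layers.0.moe.shared_expert", ["experts,shared_expert"])
assert not matches_target_fqns("layers.0.attention", ["experts"])
```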
test/prototype/moe_training/test_fsdp.sh

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-torchrun --nproc_per_node=2 --local-ranks-filter=0 -m pytest test/prototype/moe_training/test_fsdp.py -s
+torchrun --nproc_per_node=2 --local-ranks-filter=0 -m pytest test_fsdp.py -s

test/prototype/moe_training/test_fsdp_tp.py

Lines changed: 2 additions & 3 deletions
@@ -7,7 +7,7 @@
 #
 # To run these unit tests, use the following command:
 #
-# torchrun --nproc_per_node=${NUM_GPUS} -m pytest test_fsdp_tp.py
+# torchrun --nproc_per_node=2 --local-ranks-filter=0 -m pytest test_fsdp_tp.py
 #
 #######################################################################

@@ -67,8 +67,7 @@
     "target_fqns",
     [
         ["experts"],
-        # TODO: investigate hang when shared_expert is converted
-        # ["experts,shared_expert"],
+        ["experts,shared_expert"],
     ],
 )
 def test_moe_float8_training_fsdp_tp(target_fqns: list[str]):
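
The previously commented-out shared-expert case (with its TODO about a hang) is now enabled here. For orientation, a rough sketch of the 2D device-mesh setup such a composed FSDP+TP test typically needs before sharding (the (1, 2) mesh shape below is an assumption chosen to match a 2-GPU launch, not necessarily what this test uses):

```python
# Illustrative 2D (data-parallel x tensor-parallel) mesh setup for a
# 2-GPU torchrun launch. The (1, 2) shape is an assumption for this sketch.
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

def setup_fsdp_tp_mesh():
    dist.init_process_group("nccl")
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
    # One dp shard group and a 2-way tp group across the two ranks.
    return init_device_mesh("cuda", (1, 2), mesh_dim_names=("dp", "tp"))
```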
test/prototype/moe_training/test_fsdp_tp.sh

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-torchrun --nproc_per_node=4 --local-ranks-filter=0 -m pytest test/prototype/moe_training/test_fsdp_tp.py -s
+torchrun --nproc_per_node=2 --local-ranks-filter=0 -m pytest test_fsdp_tp.py -s

test/prototype/moe_training/test_tp.py

Lines changed: 2 additions & 3 deletions
@@ -7,7 +7,7 @@
 #
 # To run these unit tests, use the following command:
 #
-# torchrun --nproc_per_node=${NUM_GPUS} -m pytest test_tp.py
+# torchrun --nproc_per_node=2 --local-ranks-filter=0 -m pytest test_tp.py
 #
 #######################################################################

@@ -67,8 +67,7 @@
     "target_fqns",
     [
         ["experts"],
-        # TODO: investigate hang when shared_expert is converted
-        # ["experts,shared_expert"],
+        ["experts,shared_expert"],
     ],
 )
 def test_moe_float8_training_tp(target_fqns: list[str]):
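
As above, the shared-expert case is enabled for the TP-only test. Because these tests are launched with torchrun, every rank runs the same pytest process, and the parametrization expands into one test invocation per `target_fqns` value on each rank; a minimal standalone illustration:

```python
# Minimal illustration: pytest generates one test per parametrized value,
# so a single launch covers both the "experts" and "experts,shared_expert"
# cases. Standalone sketch, independent of the actual test body.
import pytest

@pytest.mark.parametrize(
    "target_fqns",
    [["experts"], ["experts,shared_expert"]],
)
def test_parametrize_expansion(target_fqns: list[str]):
    assert all(isinstance(fqn, str) for fqn in target_fqns)
```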
test/prototype/moe_training/test_tp.sh

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-torchrun --nproc_per_node=2 --local-ranks-filter=0 -m pytest test/prototype/moe_training/test_tp.py -s
+torchrun --nproc_per_node=2 --local-ranks-filter=0 -m pytest test_tp.py -s
