Skip to content

Commit 59ecee0

Browse files
committed
fix ep
1 parent da672b2 commit 59ecee0

File tree

1 file changed: 3 additions (+3), 3 deletions (-3)

torchtitan/experiments/gpt_oss/infra/parallelize.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from torchtitan.models.llama4.infra.parallelize import apply_fsdp
3030
from torchtitan.tools.logging import logger
3131

32-
from .expert_parallel import GptossExpertTensorParallel
32+
from .expert_parallel import GptossExpertTensorParallel, GptossTensorParallel
3333

3434

3535
# for selective op activation checkpointing
@@ -304,11 +304,11 @@ def apply_moe_ep_tp(
304304
if ep_mesh is None:
305305
experts_mesh = tp_mesh
306306
# input Replicate, output Partial
307-
experts_plan = TensorParallel()
307+
experts_plan = GptossTensorParallel()
308308
elif tp_mesh is None:
309309
experts_mesh = ep_mesh
310310
# input / output sharding on the batch / tokens dim
311-
experts_plan = GptossExpertParallel()
311+
experts_plan = ExpertParallel()
312312
elif etp_enabled:
313313
experts_mesh = ep_tp_mesh
314314
experts_plan = GptossExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh)

Comments: 0 commit comments