
Conversation

@tianyu-l
Contributor

This PR:

  • let ExpertParallel handle indices permute / unpermute when EP is used (see the sketch below)
  • move to_local into the model code to be more explicit
  • rename the expert_parallel wrapper that does permute / unpermute to indices_permutation_wrapper, to be more accurate
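For context, a minimal sketch (not the actual torchtitan code; tensor names and shapes are illustrative assumptions) of what the permute / unpermute owned by ExpertParallel amounts to: sort tokens so those routed to the same local expert are contiguous, and invert that order on the way back.

```python
import torch

def permute_by_expert(tokens: torch.Tensor, expert_indices: torch.Tensor):
    # Group tokens so that all tokens routed to the same expert are contiguous,
    # which is what per-expert / grouped GEMM compute expects.
    order = torch.argsort(expert_indices, stable=True)
    return tokens[order], order

def unpermute(tokens_sorted: torch.Tensor, order: torch.Tensor):
    # Invert the permutation so outputs line up with the original token order.
    inverse = torch.empty_like(order)
    inverse[order] = torch.arange(order.numel(), device=order.device)
    return tokens_sorted[inverse]
```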

@meta-cla bot added the CLA Signed label on Oct 11, 2025
Contributor

@wwwjn left a comment


Nice refactor!

1 and 2 are needed only when expert_parallel_degree > 1.
3 is needed even for single-device computation.
2 can be moved to ExpertParallel._token_dispatch if it is not coupled with 3.
In order to use torch._grouped_mm, we need to make sure the number of …
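For illustration, a hedged sketch of the alignment requirement referenced above, assuming each expert's token count must be rounded up to a multiple of some ALIGN_SIZE_M before the grouped GEMM (the constant name and value here are assumptions, not the kernel's actual requirement):

```python
import torch

ALIGN_SIZE_M = 16  # assumed alignment; the real value depends on the kernel

def padded_offsets(num_tokens_per_expert: torch.Tensor) -> torch.Tensor:
    # Round each expert's token count up to a multiple of ALIGN_SIZE_M and
    # return the cumulative offsets of the padded groups.
    padded = (num_tokens_per_expert + ALIGN_SIZE_M - 1) // ALIGN_SIZE_M * ALIGN_SIZE_M
    return torch.cumsum(padded, dim=0)
```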
Contributor


nit: This description only talks about padding; it doesn't mention the generate_permute_indices kernel that permutes the inputs to be ordered by local experts.

Contributor Author


This wrapper is now only responsible for padding, when EP is not used. I renamed it to make that clearer.
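As a rough sketch of that padding-only responsibility (the wrapper name, signature, and helpers below are assumptions for illustration, not the torchtitan API): pad each expert's slice of the flattened input up to the aligned size, call the wrapped forward, then drop the padding rows.

```python
import functools
import torch
import torch.nn.functional as F

ALIGN_SIZE_M = 16  # assumed alignment

def padding_only_wrapper(experts_forward):
    @functools.wraps(experts_forward)
    def wrapper(x: torch.Tensor, num_tokens_per_expert: torch.Tensor):
        counts = num_tokens_per_expert.tolist()
        padded = [-(-c // ALIGN_SIZE_M) * ALIGN_SIZE_M for c in counts]
        # Zero-pad each expert's chunk of tokens up to its aligned size.
        chunks = torch.split(x, counts, dim=0)
        x_padded = torch.cat(
            [F.pad(chunk, (0, 0, 0, p - c)) for chunk, c, p in zip(chunks, counts, padded)],
            dim=0,
        )
        out = experts_forward(x_padded, torch.tensor(padded, device=x.device))
        # Drop the padding rows so the caller sees the original layout.
        out_chunks = torch.split(out, padded, dim=0)
        return torch.cat([oc[:c] for oc, c in zip(out_chunks, counts)], dim=0)
    return wrapper
```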



class ExpertParallel(ParallelStyle):
Contributor


I have a question about when we apply _permute() and _unpermute(); they are now applied in 2 places:

  1. In ExpertParallel(), which is applied on transformer_block.moe.experts, so the input of the MoE module will be reordered by local experts.

  2. When use_grouped_mm is enabled, indices_permutation_wrapper will also try to permute the inputs of GroupedExperts into the order of local experts.

Why do we need to apply it twice?

Contributor Author


They won't be applied together.

When EP is used, EP will do _permute and _unpermute.
When EP is not used, indices_padding_wrapper will do them.
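A hedged sketch of that mutual exclusion at wiring time; the function and parameter names are illustrative stand-ins (not torchtitan's actual parallelize code), with lambdas in place of the real EP wiring and padding wrapper so the snippet stays self-contained:

```python
from typing import Callable, Optional

def wire_experts(
    experts,
    ep_mesh: Optional[object] = None,
    use_grouped_mm: bool = True,
    apply_ep: Callable = lambda experts, mesh: None,   # stand-in for the real EP wiring
    padding_wrapper: Callable = lambda fwd: fwd,       # stand-in for the padding wrapper
):
    if ep_mesh is not None:
        # EP path: ExpertParallel's token dispatch / combine owns the
        # permute / unpermute, so the padding wrapper is not applied.
        apply_ep(experts, ep_mesh)
    elif use_grouped_mm:
        # Non-EP path: only the padding wrapper touches the inputs, so the
        # permutation / padding never happens twice.
        experts.forward = padding_wrapper(experts.forward)
```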

@tianyu-l tianyu-l merged commit 9603872 into main Oct 12, 2025
8 checks passed
githubsgi pushed a commit to githubsgi/torchtitan that referenced this pull request Oct 13, 2025
githubsgi pushed a commit to githubsgi/torchtitan that referenced this pull request Oct 16, 2025
githubsgi pushed a commit to githubsgi/torchtitan that referenced this pull request Oct 16, 2025
githubsgi pushed a commit to githubsgi/torchtitan that referenced this pull request Oct 29, 2025
githubsgi pushed a commit to githubsgi/torchtitan that referenced this pull request Oct 29, 2025
Development

Successfully merging this pull request may close these issues.

ReordererSequenceParallel (ETP=1) doesn't work with torch.compile + AC