
Conversation

tianyu-l (Contributor) commented on Aug 13, 2025

apply dim-1 FSDP sharding for routed experts when dp_mod_ep * ep > num_experts
This is because our routed experts are defined with shape (num_experts, ..., ...) and EP already shards on dim-0, so FSDP's default dim-0 sharding on top of EP's sharding becomes inefficient when dp_mod_ep * ep > num_experts.
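
A minimal sketch of how such a conditional placement could look with FSDP2's `shard_placement_fn` hook (the module/mesh names below are hypothetical, not the exact torchtitan code):

```python
import torch.nn as nn
from torch.distributed.tensor import Shard


def make_shard_placement_fn(num_experts: int, dp_mod_ep_degree: int, ep_degree: int):
    # Routed-expert weights have shape (num_experts, dim_in, dim_out); EP already
    # shards dim-0, so fall back to dim-1 sharding for FSDP when dim-0 is too small.
    def shard_placement_fn(param: nn.Parameter):
        if param.ndim == 3 and dp_mod_ep_degree * ep_degree > num_experts:
            return Shard(1)
        return Shard(0)  # FSDP's default placement

    return shard_placement_fn


# Hypothetical usage, assuming torch.distributed is initialized and
# `experts_module` / `dp_mod_ep_mesh` exist:
#   from torch.distributed.fsdp import fully_shard
#   fully_shard(experts_module, mesh=dp_mod_ep_mesh,
#               shard_placement_fn=make_shard_placement_fn(8, 2, 4))
```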

Tested:
With 8 experts, FSDP 2, EP 4, we see the default dim-0 sharding:

[rank0]:w1 DTensor(local_tensor=tensor(..., device='meta', size=(1, 512, 256)), device_mesh=DeviceMesh((dp_shard_mod_ep=2, ep=4), device: 'cuda', stride: (4, 1)), placements=(_StridedShard(dim=0, sf=4), Shard(dim=0)))
[rank0]:w2 DTensor(local_tensor=tensor(..., device='meta', size=(1, 256, 512)), device_mesh=DeviceMesh((dp_shard_mod_ep=2, ep=4), device: 'cuda', stride: (4, 1)), placements=(_StridedShard(dim=0, sf=4), Shard(dim=0)))
[rank0]:w3 DTensor(local_tensor=tensor(..., device='meta', size=(1, 512, 256)), device_mesh=DeviceMesh((dp_shard_mod_ep=2, ep=4), device: 'cuda', stride: (4, 1)), placements=(_StridedShard(dim=0, sf=4), Shard(dim=0)))

With 4 experts, FSDP 2, EP 4, we see dim-1 sharding:

[rank0]:w1 DTensor(local_tensor=tensor(..., device='meta', size=(1, 256, 256)), device_mesh=DeviceMesh((dp_shard_mod_ep=2, ep=4), device: 'cuda', stride: (4, 1)), placements=(Shard(dim=1), Shard(dim=0)))
[rank0]:w2 DTensor(local_tensor=tensor(..., device='meta', size=(1, 128, 512)), device_mesh=DeviceMesh((dp_shard_mod_ep=2, ep=4), device: 'cuda', stride: (4, 1)), placements=(Shard(dim=1), Shard(dim=0)))
[rank0]:w3 DTensor(local_tensor=tensor(..., device='meta', size=(1, 256, 256)), device_mesh=DeviceMesh((dp_shard_mod_ep=2, ep=4), device: 'cuda', stride: (4, 1)), placements=(Shard(dim=1), Shard(dim=0)))

Also tested that integration works fine with FSDP 2, CP 2 (EP 2), TP 2 (ETP 2).

rewrite shared experts with FFN
This is because

  • Same reason as above, but using FFN is a simpler solution here, especially since shared experts are sharded together with the TransformerBlock, so there is no need to complicate its sharding_placement_fn.
  • It turns out that for multiple shared experts, we can just stack them on the hidden_dim dimension, and TP still works out fine (see the sketch after this list).
  • It also simplifies the GroupedExperts module, as it no longer needs to handle shared experts.
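
A minimal sketch of the stacking idea, assuming a SwiGLU-style FFN (the class and shapes below are illustrative, not the PR's exact code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FeedForward(nn.Module):
    """Plain SwiGLU FFN, the same structure TP already knows how to shard."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


# N shared experts of width hidden_dim collapse into one FFN of width
# N * hidden_dim: the stacking happens on the hidden dimension, so the usual
# colwise (w1/w3) / rowwise (w2) TP plan applies unchanged.
num_shared_experts, dim, hidden_dim = 2, 256, 512
shared_experts = FeedForward(dim=dim, hidden_dim=num_shared_experts * hidden_dim)
```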

other changes

  • rename shared_expert to shared_experts
  • merge the two tolist() d2h syncs for input_splits and output_splits in token_dispatch into one (see the sketch after this list)
  • state dict / checkpoint conversion changes (@wwwjn please help verify)
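
A minimal sketch of the d2h merge mentioned above (variable names are illustrative, not the exact token_dispatch code):

```python
import torch

# Illustrative split-size tensors; in token_dispatch these come from per-expert
# token counts before/after the all-to-all.
device = "cuda" if torch.cuda.is_available() else "cpu"
input_split_sizes = torch.tensor([3, 5, 2, 6], device=device)
output_split_sizes = torch.tensor([4, 4, 4, 4], device=device)

# Before: two blocking device-to-host copies.
#   input_splits = input_split_sizes.tolist()
#   output_splits = output_split_sizes.tolist()

# After: stack on device, then a single .tolist() (one d2h sync) for both.
input_splits, output_splits = torch.stack(
    [input_split_sizes, output_split_sizes]
).tolist()
```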

meta-cla bot added the CLA Signed label on Aug 13, 2025.
rakkit (Contributor) commented on Aug 13, 2025

Hi @tianyu-l, I am wondering if we really need shard(1), or whether we can at least have an option to choose between shard(0) and shard(1).
We have two reasons for keeping shard(0), even though we will face the problem of fsdp size > num_experts. The first reason is that people who need to train a big MoE model with EP naturally end up in an EP+PP regime to overlap the all-to-all. For the 64-expert or 32-expert case, even on 512 or 1024 GPUs (consider pp=4, for example), we can easily avoid the issue with dp=2 or dp=4. The state sharding of FSDP at this size makes no big difference, and with HSDP we can even reduce communication latency.

The other reason is that FSDP2 + shard(0) naturally works for second-order optimizers (for example Muon, which is already used in Kimi K2 and GLM 4.5). Because on each rank we have full expert matrices, we can directly run SVD or orthogonalization methods without extra communication, but with shard(1) we need an extra sync. (We can PR the distributed version of Muon/Scion into torchtitan if you are interested.)
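
A rough illustration with toy shapes (single process, plain tensors rather than DTensor):

```python
import torch

num_experts, d_in, d_out, fsdp_degree = 8, 512, 256, 2
w = torch.randn(num_experts, d_in, d_out)

# shard(0): each rank keeps whole expert matrices -> per-matrix SVD /
# orthogonalization (as in Muon/Scion) can run locally, no cross-rank sync.
local_shard0 = w.chunk(fsdp_degree, dim=0)[0]   # (4, 512, 256): 4 complete experts
u, s, vh = torch.linalg.svd(local_shard0[0], full_matrices=False)

# shard(1): every expert matrix is itself split across ranks -> an extra
# gather/sync is needed before the same per-matrix step.
local_shard1 = w.chunk(fsdp_degree, dim=1)[0]   # (8, 256, 256): 8 half experts
```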

[We could do a benchmark and sweep of training configurations on torchtitan, something like this.]

tianyu-l (Contributor, Author) commented:

@rakkit
Thanks a lot for the input!!

I think we can make it configurable and still default to Shard(0). I have thought about the alternative of adding a condition to do Shard(1) only when dp_mod_ep * ep > num_experts (not configurable). Do you think that's also OK?

The state sharding of FSDP at this size makes no big difference, and with HSDP we can even reduce communication latency.

I think if one wants HSDP behavior, it'd better be explicit, but that means the non-EP part also has to share the same replicate degree, unless we extend ParallelDims further.

Because on each rank we have full expert matrices, we can directly run SVD or orthogonalization methods without extra communication

Sure, but what about non-expert params, e.g. those in MLP and Attention -- sync sounds necessary for them?

[We could do a benchmark and sweep of training configurations on torchtitan, something like this.]

This would be appreciated, although we still have some optimizations ongoing, so maybe let's do that later.

rakkit (Contributor) commented on Aug 13, 2025

Hi @tianyu-l thanks a lot.

  1. I think your solution is perfect: make it configurable and force shard(1) if fsdp*ep is large (sorry for bringing you extra work).
  2. By HSDP here I mean we use DP+EP+FSDP globally; we already have this. I was trying to say that for people who really need to train with EP, at the limit of FSDP*EP = num_experts, increasing the sharding further (via shard(1)) may bring negligible memory savings while increasing communication latency due to the large world size.
  3. For distributed (Muon, Scion), yes, in general we need two extra all-to-alls for non-expert parameters, and maybe an extra gather when TP is enabled. But with shard(0) we do not need any sync for EP parameters. [P.S. I tried FSDP2's flag for not sharding gradients, to try to simulate ZeRO-1 as used in PP, but gradients are still sharded after backward. Anyhow, from what we observed, the cost of this all-to-all is acceptable.]
  4. I totally agree. I will soon start to check out PP/EP overlap, and maybe start benchmarking after compile is fixed for EP. (We are a research institution; all of our benchmarks 💯 are open to share.)

tianyu-l (Contributor, Author) commented:

@rakkit

I think your solution is perfect: make it configurable and force shard(1) if fsdp*ep is large (sorry for bringing you extra work).

If we "force shard(1) if fsdp*ep is large", then we probably don't need to make it configurable. Is that OK?

By HSDP here I mean we use DP+EP+FSDP globally; we already have this. I was trying to say that for people who really need to train with EP, at the limit of FSDP*EP = num_experts, increasing the sharding further (via shard(1)) may bring negligible memory savings while increasing communication latency due to the large world size.

I think this holds for DSV3, where there are 128 thin experts. For other archs, e.g. 8 wide experts, it no longer holds. But yeah, thanks for the input; with the updated solution we should be good for both.

[P.S. I tried FSDP2's flag for not sharding gradients, to try to simulate ZeRO-1 as used in PP, but gradients are still sharded after backward.

What flag did you use, reshard_after_forward? That doesn't mean we don't shard gradients. It means that after forward we don't reshard the params, so FSDP doesn't need to all-gather them again in backward.
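
For reference, a minimal sketch of that flag (toy model; assumes torch.distributed and a default device mesh are already initialized, e.g. via torchrun):

```python
import torch.nn as nn
from torch.distributed.fsdp import fully_shard

# reshard_after_forward=False keeps params unsharded between forward and
# backward, so backward skips the second all-gather; gradients are still
# reduce-scattered, i.e. they end up sharded after backward either way.
model = nn.Sequential(nn.Linear(256, 256), nn.Linear(256, 256))
fully_shard(model, reshard_after_forward=False)
```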

I totally agree. I will soon start to check out PP/EP overlap, and maybe start benchmarking after compile is fixed for EP.

I'm planning to add support for NVSHMEM-based a2a, trying to remove all the d2h syncs
https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/kernels/moe/dispatch.py

We can PR the distributed version of Muon/Scion into torchtitan if you are interested.

That would be super interesting. Happy to learn more. Maybe we can start with an RFC to bring people on board?

(We are a research institution; all of our benchmarks 💯 are open to share.)

Again, I appreciate your input and your work a lot!

Also, on our side, maybe it's worth syncing (are you on Slack?).

rakkit (Contributor) commented on Aug 13, 2025

Thanks a lot @tianyu-l

It's totally fine to force shard(1) without a config option (when fsdp*ep > num_experts).

The API I meant is set_reshard_after_backward, which is widely used in PP; with it I hoped to achieve something like ZeRO-1 after fwd/bwd, to allow the optimizer to get the full gradient. (I still doubt whether we really need this for dist-Muon/Scion.)

Thanks a lot to you and the torchtitan team for bringing these amazing features! I am not on Slack :-).

tianyu-l force-pushed the fsdp branch 2 times, most recently from d9ab874 to 4feca04, on August 13, 2025 at 21:49.
tianyu-l (Contributor, Author) commented:

@rakkit

It's totally fine to force shard(1) without a config option (when fsdp*ep > num_experts).

updated the code and PR summary

The API I meant is set_reshard_after_backward, which is widely used in PP; with it I hoped to achieve something like ZeRO-1 after fwd/bwd, to allow the optimizer to get the full gradient.

Sounds worth filing a separate GH issue and tagging our FSDP PoC @weifengpy.

Contributor (inline review comment):

Just to confirm my understanding: this is upper-bounding the max padding, right?

tianyu-l (Contributor, Author) replied:

yes

sanketpurandare (Contributor) left a review comment:

Looks like a neat change. It also makes this consistent with all of FSDP sharding, in the sense that all params are sharded on their first dimension, and for routed experts dim-1 is the real first dimension.

tianyu-l merged commit 7354848 into main on Aug 14, 2025 (7 checks passed).
tianyu-l deleted the fsdp branch on August 14, 2025 at 00:43.