[MoE/EP] apply dim-1 FSDP sharding for routed experts and rewrite shared experts with FFN #1561
Conversation
Hi @tianyu-l, I am wondering if we really need shard(1), or at least whether we can have an option to choose shard(0) or shard(1). Another reason is that FSDP2 + shard(0) naturally works for second-order optimizers (for example Muon, which is already used in Kimi K2 and GLM 4.5), because on each rank we have the full expert matrices, so we can directly run SVD or orthogonalization methods without extra communication; with shard(1) we need an extra sync. (We can PR the distributed version of Muon/Scion into torchtitan if you are interested.) [We could do a benchmark and sweep of training configurations on torchtitan, something like this.]
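For illustration only (example shapes, not from this thread): with Shard(0) each rank holds whole expert matrices, so per-expert SVD/orthogonalization needs no extra communication, while Shard(1) splits every expert across ranks.

```python
# Hypothetical helper, just to show the local shard shapes of a routed-experts
# weight of shape (num_experts, dim_in, dim_out) under Shard(0) vs Shard(1).
def local_shape(full_shape, shard_dim, num_shards):
    shape = list(full_shape)
    # ceil-divide the sharded dim; FSDP pads when it doesn't divide evenly
    shape[shard_dim] = -(-shape[shard_dim] // num_shards)
    return tuple(shape)

full = (128, 7168, 2048)         # e.g. 128 routed experts (made-up dims)
print(local_shape(full, 0, 64))  # Shard(0): (2, 7168, 2048), whole experts per rank
print(local_shape(full, 1, 64))  # Shard(1): (128, 112, 2048), every expert split
```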
@rakkit I think we can make it configurable and still default to
I think if one wants HSDP behavior it'd better be explicit, but that means the non-EP part also has to share the same replicate degree, unless we extend the
Sure, but what about the non-expert params, e.g. those in MLP and Attention -- sync sounds necessary for them?
This would be appreciated, although we still have some optimizations ongoing, so maybe let's do that later.
Hi @tianyu-l, thanks a lot.
If we "force shard(1) if fsdp*ep is large", then we probably don't need to make it configurable. Is that OK?
I think this holds for DSV3 where there are 128 thin experts. For other archs, e.g. 8 wide experts, it no longer holds. But yeah, thanks for the input; with the updated solution we should be good for both.
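To make the trade-off concrete (illustrative numbers, not from the thread), the condition only triggers for wide-expert configs:

```python
# Evaluate the rule "use dim-1 sharding when dp_mod_ep * ep > num_experts"
# for a thin-expert and a wide-expert architecture (hypothetical degrees).
def needs_dim1_sharding(num_experts: int, dp_mod_ep: int, ep: int) -> bool:
    return dp_mod_ep * ep > num_experts

print(needs_dim1_sharding(128, 8, 8))  # DSV3-style 128 thin experts -> False, dim-0 stays fine
print(needs_dim1_sharding(8, 8, 8))    # 8 wide experts -> True, switch to dim-1
```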
What flag did you use,
I'm planning to add support for NVSHMEM-based a2a, trying to remove all the d2h syncs
That would be super interesting. Happy to learn more. Maybe we can start with an RFC to bring people on board?
Again, I appreciate your input and your work a lot! Also, on our side, maybe it's worth syncing (are you on Slack?).
Thanks a lot @tianyu-l. It's totally fine to force shard(1) without a configurable option (fsdp*ep > experts). The API I mean
Thanks a lot to you and the torchtitan team for bringing these amazing features! I am not on Slack.
Force-pushed from d9ab874 to 4feca04.
updated the code and PR summary
Sounds worth filing a separate GH issue and tagging our FSDP PoC @weifengpy.
Just to confirm understanding, this is upper-bounding the max padding, right?
yes
sanketpurandare left a comment
Looks like a neat change. It also makes it consistent with all of FSDP sharding, in the sense that all params are sharded on their first dimension, and for routed experts dim-1 is the real first dimension.
apply dim-1 FSDP sharding for routed experts when `dp_mod_ep * ep > num_experts`

This is because our routed experts are defined with shape `(num_experts, ..., ...)`. EP already shards on dim-0. FSDP's default dim-0 sharding + EP sharding will be inefficient when `dp_mod_ep * ep > num_experts`. A sketch of this rule follows the tested configs below.

Tested:
- with 8 experts, FSDP2 EP4, we see default dim-0 sharding
- with 4 experts, FSDP2 EP4, we see dim-1 sharding
- also tested integration works fine with: FSDP 2, CP 2 (EP 2), TP 2 (ETP 2)
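A minimal sketch of that rule as a placement hook, assuming FSDP2's `fully_shard(..., shard_placement_fn=...)` API; this is not the PR's actual code, and the names `dp_mod_ep`/`ep` just follow the summary above:

```python
from typing import Optional

import torch.nn as nn
from torch.distributed.tensor import Shard


def make_shard_placement_fn(num_experts: int, dp_mod_ep: int, ep: int):
    def shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
        # Routed-expert weights have shape (num_experts, ..., ...) and EP already
        # shards dim-0. When dp_mod_ep * ep > num_experts, further dim-0 sharding
        # would leave ranks with empty or heavily padded shards, so use dim-1.
        if param.ndim == 3 and dp_mod_ep * ep > num_experts:
            return Shard(1)
        return None  # keep FSDP's default dim-0 sharding
    return shard_placement_fn
```

Wiring would be roughly `fully_shard(experts, mesh=dp_mod_ep_mesh, shard_placement_fn=make_shard_placement_fn(...))`, modulo the actual torchtitan plumbing.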
rewrite shared experts with FFN
This is because the shared experts, once rewritten as an FFN, no longer go through the routed experts' `sharding_placement_fn`; sharding falls on the `hidden_dim` dimension, and TP will just work out fine.
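For reference, a sketch of shared experts as a plain SwiGLU-style FFN (modeled on torchtitan's FeedForward, not copied from the PR), whose 2D Linear weights work with FSDP's default sharding and a standard TP plan:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SharedExperts(nn.Module):
    """Shared experts expressed as an ordinary FFN instead of GroupedExperts."""

    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.w3(x))
```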
other changes

- rename `shared_expert` to `shared_experts`
- combine the `tolist()` d2h syncs for `input_splits` and `output_splits` in `token_dispatch` into one (see the sketch below)
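A possible shape of that combined d2h sync (hypothetical helper, not the PR's code): stacking the two split tensors means a single `.tolist()` call, i.e. one device-to-host transfer instead of two.

```python
import torch


def splits_to_lists(input_splits: torch.Tensor, output_splits: torch.Tensor):
    # One stacked tensor -> one .tolist() (one d2h sync) instead of two.
    both = torch.stack([input_splits, output_splits]).tolist()
    return both[0], both[1]
```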