-
Notifications
You must be signed in to change notification settings - Fork 8
Pass to split all_gather prologue and reduce_scatter prologue from fsdp graph #201
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
stack-info: PR: #201, branch: IvanKobzarev/stack/9
0159d8a to
68e7956
Compare
stack-info: PR: #201, branch: IvanKobzarev/stack/9
68e7956 to
a6dd593
Compare
stack-info: PR: #201, branch: IvanKobzarev/stack/9
7eb06ad to
0571cef
Compare
stack-info: PR: #201, branch: IvanKobzarev/stack/9
stack-info: PR: #201, branch: IvanKobzarev/stack/9
0571cef to
415b736
Compare
| if len(n.users) != 1: | ||
| break |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This assumes that the partitioner has been changed to that we don't recompute the all-gather collectives in the backward pass.
@sanketpurandare you'll need to keep this in mind for your PR
…dp graph stack-info: PR: #201, branch: IvanKobzarev/stack/9
415b736 to
c098c92
Compare
…dp graph stack-info: PR: #201, branch: IvanKobzarev/stack/9
c098c92 to
18ccbec
Compare
| g_ins = g.find_nodes(op="placeholder") | ||
| prefetch_g_outs_map = [] | ||
|
|
||
| for g_in in g_ins: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
one thing I found a bit confusing when running this locally is that we are currently moving "views of inputs" into the prefetch subgraph. In my local example:
full backward graph: P2013183321
prefetch subgraph: P2013183365
remaining subgraph: P2013183449
you can see that tangents_1 is not an input in the "remaining subgraph", which is surprising because we should not be performing any FSDP collectives directly on tangents_1. And it looks like this is because there is a view(tangents_1) in the main backward, that we end up moving into the prefetch subgraph (technically harmless but confusing).
…dp graph stack-info: PR: #201, branch: IvanKobzarev/stack/9
Stacked PRs:
Pass to split all_gather prologue and reduce_scatter prologue from fsdp graph