Enabling split_dI_dW and split_fsdp_collectives passes #231
Conversation
split_fsdp_prefetch lgtm
    gm: torch.fx.GraphModule,
    num_params: int,
) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
    g = deepcopy(gm.graph)
Curious, why do you want to keep the original graph unchanged? Will it be further used?
Since we are using its container graph module to initialize the two new graphs, I just thought it would be safer this way; we can remove this later.
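A minimal sketch of the pattern being discussed, assuming the pass carves the copied graph into two pieces and wraps each in a GraphModule rooted at the original gm (the placeholder split and the helper name _split_sketch are illustrative, not the actual pass):

    import torch.fx
    from copy import deepcopy

    def _split_sketch(gm: torch.fx.GraphModule) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
        # Work on a copy so gm.graph itself is never mutated.
        g = deepcopy(gm.graph)
        # The real pass splits g into a prefetch graph and a fwd graph here;
        # the duplicate below is only a placeholder for illustration.
        prefetch_graph, fwd_graph = g, deepcopy(g)
        # gm stays the attribute container ("root") for both new GraphModules,
        # so keeping its original graph untouched avoids accidental aliasing.
        return torch.fx.GraphModule(gm, prefetch_graph), torch.fx.GraphModule(gm, fwd_graph)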
    g_ins = g.find_nodes(op="placeholder")

def split_fsdp_prefetch(
    gm: torch.fx.GraphModule,
    num_params: int,
nit: If we use export with descriptors, potentially num_params could be taken from metadata.
Yeah, for now this is easily obtainable from the graph meta. Same as above, this can be removed later.
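As a rough illustration of the nit above, the count could be derived from placeholder metadata instead of being passed in explicitly. The "desc" key and "param" tag below are assumptions about what export-with-descriptors might record, not an actual schema:

    # Hypothetical sketch: count parameter placeholders via node metadata.
    num_params = sum(
        1
        for n in gm.graph.find_nodes(op="placeholder")
        if "param" in str(n.meta.get("desc", ""))  # assumed metadata key/tag
    )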
Force-pushed 710e1a6 to 40ad911, then 40ad911 to 19ac5cb.
Based off #227, will rebase once it lands.
Test file: examples/example_pp_graph_passes.py

Made some modifications to Ivan’s pass (#201) as follows:
The split_fsdp_prefetch pass originally required all the inputs (params + buffers + microbatch) to be passed to the prefetch graph, and all the outputs of the prefetch graph to be passed to the fwd_graph. We wanted something slightly different: only the sharded_params should go into the prefetch_graph, which produces the unsharded_params. The fwd_graph then takes these unsharded_params (obtained from the prefetch_graph) plus buffers and the microbatch. This is important because the prefetch graph must be independent of the micro-batch, since the unshard action is not associated with any particular micro-batch. A rough sketch of the intended dataflow follows.
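Illustrative usage sketch, assuming split_fsdp_prefetch returns (prefetch_gm, fwd_gm) and that sharded_params, buffers, and microbatches are provided by the caller (all of these names are hypothetical, not the pass's actual variables):

    # Split the forward graph into an unshard (prefetch) part and a compute part.
    prefetch_gm, fwd_gm = split_fsdp_prefetch(gm, num_params=len(sharded_params))

    # The prefetch graph sees only sharded params, so it carries no micro-batch
    # dependence and maps onto a single unshard action per step.
    unsharded_params = prefetch_gm(*sharded_params)

    # The fwd graph consumes the unsharded params plus buffers and one micro-batch.
    for microbatch in microbatches:
        out = fwd_gm(*unsharded_params, *buffers, microbatch)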
Similarly, the reduce_grad graph should take only the unsharded_grads from the bwd_graph and produce the sharded_grads. We cannot pass all of the bwd_graph's outputs to the reduce_grad graph, since that would make it micro-batch dependent, and we want to call the reduce_grad action only once, after accumulating the unsharded_grads across micro-batches. A sketch of that side follows.
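Illustrative sketch of the reduce_grad side; bwd_gm, reduce_grad_gm, the bwd_inputs helper, and the accumulation loop are hypothetical names for this example, not the pass's actual API:

    # Accumulate unsharded grads across micro-batches.
    acc_grads = None
    for microbatch in microbatches:
        unsharded_grads = bwd_gm(*bwd_inputs(microbatch))  # bwd_inputs: hypothetical helper
        if acc_grads is None:
            acc_grads = list(unsharded_grads)
        else:
            acc_grads = [a + g for a, g in zip(acc_grads, unsharded_grads)]

    # reduce_grad runs exactly once per step: its inputs are the accumulated
    # unsharded grads, which no longer depend on any particular micro-batch.
    sharded_grads = reduce_grad_gm(*acc_grads)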
Brian’s pass (#232) worked perfectly fine; I just integrated it differently into the end-to-end workflow.