Enabling split_dI_dW and split_fsdp_collectives passes #231
```diff
@@ -5,6 +5,7 @@
 import dataclasses
 from contextlib import contextmanager
+from copy import deepcopy
 from functools import partial
 from typing import Any
@@ -49,12 +50,19 @@ class EpilogueInput(AOTOutput):
     pass


-def split_fsdp_prefetch(g: torch.fx.Graph) -> tuple[torch.fx.Graph, torch.fx.Graph]:
-    g_ins = g.find_nodes(op="placeholder")
+def split_fsdp_prefetch(
+    gm: torch.fx.GraphModule,
+    num_params: int,
+) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
+    g = deepcopy(gm.graph)
```
Contributor:
Curious, why do you want to keep the original graph unchanged? Will it be further used?

Contributor (Author):
Since we are using its container graph module to initialize the two new graph modules, I just thought it would be safer this way; we can remove this later.
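For background on the `deepcopy` being discussed: `torch.fx` graph manipulation mutates the `Graph` in place, so copying first leaves `gm.graph` intact while `gm` itself is reused as the container for the two extracted graphs. A minimal, self-contained illustration (the `Tiny` module is hypothetical, not from this PR):

```python
import torch
from copy import deepcopy

class Tiny(torch.nn.Module):  # hypothetical stand-in module
    def forward(self, x):
        return torch.relu(x) + 1

gm = torch.fx.symbolic_trace(Tiny())
g = deepcopy(gm.graph)

# Mutating the copy does not touch the original graph, so gm remains a
# valid container module to rebuild new GraphModules from.
out_node = next(iter(g.find_nodes(op="output")))
g.erase_node(out_node)
assert len(list(gm.graph.nodes)) != len(list(g.nodes))
```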
```diff
+    all_g_ins = g.find_nodes(op="placeholder")
+    param_g_ins = all_g_ins[:num_params]
+    rem_g_ins = all_g_ins[num_params:]

     prefetch_g_outs_map = []

-    for g_in in g_ins:
-        n = g_in
+    for param_g_in in param_g_ins:
+        n = param_g_in
         last_ag = None
         while True:
             if len(n.users) != 1:
```
```diff
@@ -66,7 +74,7 @@ def split_fsdp_prefetch(g: torch.fx.Graph) -> tuple[torch.fx.Graph, torch.fx.Graph]:
             if is_all_gather_into_tensor(n):
                 last_ag = n
         if last_ag is None:
-            prefetch_g_outs_map.append(g_in)
+            prefetch_g_outs_map.append(param_g_in)
         else:
             w_n = next(iter(last_ag.users))
             prefetch_g_outs_map.append(w_n)
```
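The `is_all_gather_into_tensor` helper used in this chain walk is defined elsewhere in the file; a plausible sketch, assuming the graph was traced with PyTorch's functional collectives, would be:

```python
import torch
import torch.fx

def is_all_gather_into_tensor(n: torch.fx.Node) -> bool:
    # Assumption: collectives appear as call_function nodes targeting the
    # native functional-collective op under torch.ops._c10d_functional.
    return (
        n.op == "call_function"
        and n.target is torch.ops._c10d_functional.all_gather_into_tensor.default
    )
```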
```diff
@@ -82,34 +90,42 @@ def split_fsdp_prefetch(
     with exclude_wait_from_fx_side_effectful():
         prefetch_g = _extract_graph_with_inputs_outputs(
             g,
-            g_ins,
+            param_g_ins,
             prefetch_g_outs,
             prefetch_g_outs_descs,
             ignore_must_be_in_fw_bw=True,
         )

         main_g = _extract_graph_with_inputs_outputs(
             g,
-            prefetch_g_outs,
+            prefetch_g_outs + rem_g_ins,
             g_outs,
             g_outs_descs,
             ignore_must_be_in_fw_bw=True,
         )
-    return prefetch_g, main_g
+    prefetch_gm = torch.fx._lazy_graph_module._make_graph_module(gm, prefetch_g)
+    main_gm = torch.fx._lazy_graph_module._make_graph_module(gm, main_g)
+    return prefetch_gm, main_gm
```
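A hedged usage sketch of the new interface (names like `fw_gm`, `params`, and `rest_of_inputs` are assumptions for illustration, not code from this PR): the prefetch module consumes only the parameter placeholders and yields the all-gathered weights, which feed the main module together with the remaining inputs.

```python
# Hypothetical call site: fw_gm is a traced forward GraphModule whose
# first num_params placeholders are (sharded) parameters.
prefetch_gm, main_gm = split_fsdp_prefetch(fw_gm, num_params)

gathered_weights = prefetch_gm(*params)
out = main_gm(*gathered_weights, *rest_of_inputs)
```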
```diff


 def split_fsdp_reduce_scatters_epilogue(
-    g: torch.fx.Graph,
-) -> tuple[torch.fx.Graph, torch.fx.Graph]:
+    gm: torch.fx.GraphModule,
+    num_grads: int,
+) -> tuple[torch.fx.GraphModule, torch.fx.GraphModule]:
+    g = deepcopy(gm.graph)
     g_ins = g.find_nodes(op="placeholder")
     g_outs = pytree.arg_tree_leaves(*(n.args for n in g.find_nodes(op="output")))
-    g_outs_descs = pytree.arg_tree_leaves(
-        next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(g_outs))
+    grad_outs = g_outs[:num_grads]
+    rem_g_outs = g_outs[num_grads:]
+    out_descs = pytree.arg_tree_leaves(
+        next(iter(g.find_nodes(op="output"))).meta.get("desc", [None] * len(grad_outs))
     )
+    grad_outs_descs = out_descs[:num_grads]
+    rem_g_outs_descs = out_descs[num_grads:]

-    g_outs_map = []
-    for g_out in g_outs:
-        n = g_out
+    grad_outs_map = []
+    for grad_out in grad_outs:
+        n = grad_out
         last_rs = None
         while n is not None:
             if len(n.all_input_nodes) != 1:
```
```diff
@@ -124,27 +140,28 @@ def split_fsdp_reduce_scatters_epilogue(
                 # The reduction of gradients happen in multiple steps
                 last_rs = n
         if last_rs is not None:
-            g_outs_map.append(last_rs)
+            grad_outs_map.append(last_rs)
         else:
-            g_outs_map.append(g_out)
+            grad_outs_map.append(grad_out)

-    epi_g_ins = [n for n in g_outs_map if n is not None]
+    epi_g_ins = grad_outs_map
     epi_g_ins_descs: list[AOTOutput] = [EpilogueInput() for _ in range(len(epi_g_ins))]

     with exclude_wait_from_fx_side_effectful():
         main_g = _extract_graph_with_inputs_outputs(
             g,
             g_ins,
-            epi_g_ins,
-            epi_g_ins_descs,
+            epi_g_ins + rem_g_outs,
+            epi_g_ins_descs + rem_g_outs_descs,
             ignore_must_be_in_fw_bw=True,
         )
         epi_g = _extract_graph_with_inputs_outputs(
             g,
             epi_g_ins,
-            g_outs,
-            g_outs_descs,
+            grad_outs,
+            grad_outs_descs,
             ignore_must_be_in_fw_bw=True,
         )

-    return main_g, epi_g
+    epi_gm = torch.fx._lazy_graph_module._make_graph_module(gm, epi_g)
+    main_gm = torch.fx._lazy_graph_module._make_graph_module(gm, main_g)
+    return main_gm, epi_gm
```
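A corresponding hedged usage sketch for the backward side (`bw_gm`, `bw_inputs`, and the call sequence are illustrative assumptions): the main module produces the pre-reduce-scatter gradients plus the remaining outputs, and the epilogue performs the deferred reduce-scatters.

```python
# Hypothetical call site: bw_gm is a traced backward GraphModule whose
# first num_grads outputs are parameter gradients.
main_gm, epi_gm = split_fsdp_reduce_scatters_epilogue(bw_gm, num_grads)

outs = main_gm(*bw_inputs)
pre_rs_grads, rem_outs = outs[:num_grads], outs[num_grads:]

# The epilogue can now be scheduled later, e.g. overlapped with other
# compute, before the reduce-scattered gradients are needed.
grads = epi_gm(*pre_rs_grads)
```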
Contributor:
nit: If we use export with descriptors, `num_params` could potentially be taken from metadata.
Contributor (Author):
Yeah, for now this is easily obtainable from the graph meta. Same as above, this can be removed later.
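As a sketch of the suggestion above (hypothetical: assumes placeholder nodes carry a descriptor under `meta["desc"]` and that parameter inputs are identified by a `ParamAOTInput`-style class; the import path is an assumption, not confirmed by this PR):

```python
import torch.fx
# Assumed import path for the descriptor types referenced in this PR.
from torch._functorch._aot_autograd.descriptors import ParamAOTInput

def infer_num_params(gm: torch.fx.GraphModule) -> int:
    # Count the leading placeholders whose descriptor marks them as
    # parameters; assumes parameter inputs come first in placeholder order.
    num_params = 0
    for n in gm.graph.find_nodes(op="placeholder"):
        if not isinstance(n.meta.get("desc"), ParamAOTInput):
            break
        num_params += 1
    return num_params
```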